QuickStart |安裝|用戶指南|示例| FP8收斂|集成|發行說明

變壓器引擎(TE)是一個庫,用於加速NVIDIA GPU上的變壓器模型,包括在Hopper GPU上使用8位浮點(FP8)精度,以在訓練和推理中提供較低的內存利用,以提供更好的性能。 TE為流行的變壓器體系結構提供了一系列高度優化的構建塊和一個自動混合精確的API,可以與您的框架特定代碼無縫使用。 TE還包括一個框架不可知的C ++ API,該API可以與其他深度學習庫集成,以啟用FP8對變壓器的支持。
隨著變壓器模型中的參數數量的數量不斷增長,諸如BERT,GPT和T5等體系結構的訓練和推斷變得非常記憶和計算密集型。默認情況下,大多數深度學習框架都用FP32訓練。但是,對於許多深度學習模型,這並不是至關重要的。與FP32訓練相比,使用混合精確訓練將單精度(FP32)與較低的精度(例如FP16)格式相結合,其準確性差異很小。借助Hopper GPU體系結構FP8精度,它提供了比FP16的改進性能,而準確性沒有降解。儘管所有主要的深度學習框架都支持FP16,但在當今框架中,FP8支持在本地尚未獲得。
TE通過提供與流行的大語言模型(LLM)庫集成的API來解決FP8支持的問題。它提供了一個由模塊組成的Python API,可輕鬆構建變壓器層以及C ++中的框架 - 敏捷庫,包括FP8支持所需的結構和內核。 TE內部提供的模塊可維護FP8培訓所需的縮放因素和其他值,從而大大簡化了用戶的混合精度培訓。
import torch
import transformer_engine . pytorch as te
from transformer_engine . common import recipe
# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048
# Initialize model and inputs.
model = te . Linear ( in_features , out_features , bias = True )
inp = torch . randn ( hidden_size , in_features , device = "cuda" )
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe . DelayedScaling ( margin = 0 , fp8_format = recipe . Format . E4M3 )
# Enable autocasting for the forward pass
with te . fp8_autocast ( enabled = True , fp8_recipe = fp8_recipe ):
out = model ( inp )
loss = out . sum ()
loss . backward () import flax
import jax
import jax . numpy as jnp
import transformer_engine . jax as te
import transformer_engine . jax . flax as te_flax
from transformer_engine . common import recipe
BATCH = 32
SEQLEN = 128
HIDDEN = 1024
# Initialize RNG and inputs.
rng = jax . random . PRNGKey ( 0 )
init_rng , data_rng = jax . random . split ( rng )
inp = jax . random . normal ( data_rng , [ BATCH , SEQLEN , HIDDEN ], jnp . float32 )
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe . DelayedScaling ( margin = 0 , fp8_format = recipe . Format . HYBRID )
# Enable autocasting for the forward pass
with te . fp8_autocast ( enabled = True , fp8_recipe = fp8_recipe ):
model = te_flax . DenseGeneral ( features = HIDDEN )
def loss_fn ( params , other_vars , inp ):
out = model . apply ({ 'params' : params , ** other_vars }, inp )
return jnp . mean ( out )
# Initialize models.
variables = model . init ( init_rng , inp )
other_variables , params = flax . core . pop ( variables , 'params' )
# Construct the forward and backward function
fwd_bwd_fn = jax . value_and_grad ( loss_fn , argnums = ( 0 , 1 ))
for _ in range ( 10 ):
loss , ( param_grads , other_grads ) = fwd_bwd_fn ( params , other_variables , inp )開始使用變壓器引擎的最快方法是在NVIDIA GPU Cloud(NGC)目錄上使用Docker Images。例如,使用NGC Pytorch容器交互式使用,
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3其中23.10是容器版本。例如,2023年10月發布的23.10。
要安裝最新穩定版本的變壓器引擎,
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable這將自動檢測是否安裝了任何支持的深度學習框架並為其構建變壓器引擎支持。要明確指定框架,請將環境變量NVTE_FRAMEWORK設置為逗號分隔的列表(例如NVTE_FRAMEWORK = JAX,PYTORCH,PADDLE)。
另外,可以直接從變壓器引擎的PYPI中安裝包裝,例如
pip install transformer_engine[pytorch]為了獲得變壓器引擎的必要的Python綁定,必須在逗號分隔列表中明確指定所需的框架(例如[Jax,Pytorch,Paddle])。變壓器發動機輪芯輪用於核心庫以及槳板擴展。用於JAX和PYTORCH擴展的源分佈。
請參閱“安裝指南”。
變壓器發動機釋放v0.11.0增加了Pytorch中Flashattention-2的支持,以提高性能。
已知的問題是,Flashattention-2彙編是資源密集的,需要大量RAM(請參閱錯誤),這可能會導致變壓器引擎安裝過程中的內存錯誤。請嘗試在環境中設置max_jobs = 1以避免問題。
請注意,NGC Pytorch 23.08+容器包括Flashattention-2。
為了統一在變壓器引擎中所有三個框架中註意力面罩的定義和使用,填充面罩已從真正的含義包含相應的位置包含在我們的Pytorch實現中排除該位置。由於v1.7,所有註意性掩碼類型都遵循相同的定義,其中true意味著掩蓋相應的位置和錯誤的含義,包括注意力計算中的位置。
此更改的一個例子是,
# for a batch of 3 sequences where `a`s, `b`s and `c`s are the useful tokens
# and `0`s are the padding tokens,
[a, a, a, 0, 0,
b, b, 0, 0, 0,
c, c, c, c, 0]
# the padding mask for this batch before v1.7 is,
[ True, True, True, False, False,
True, True, False, False, False,
True, True, True, True, False]
# and for v1.7 onwards it should be,
[False, False, False, True, True,
False, False, True, True, True,
False, False, False, False, True]FP8已在不同的模型架構和配置中進行了廣泛的測試,我們發現FP8和BF16訓練損失曲線之間沒有顯著差異。 FP8還經過驗證,以確保下游LLM任務的準確性(例如Lambada和Wikitext)。以下是在不同框架之間測試收斂的模型的示例。
| 模型 | 框架 | 來源 |
|---|---|---|
| T5-770m | jax/t5x | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/rosetta/projects/t5x#convergence-and-performance |
| MPT-1.3B | 馬賽克作曲家 | https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1 |
| GPT-5B | JAX/PAXML | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/rosetta/projects/pax#h100-results |
| GPT-5B | NEMO框架 | 可根據要求提供 |
| Llama2-7b | 阿里巴巴·帕伊(Alibaba Pai) | https://mp.weixin.qq.com/s/nqt0ukxlbxyh5031zbdebq |
| T5-11b | jax/t5x | 可根據要求提供 |
| MPT-13B | 馬賽克作曲家 | https://www.databricks.com/blog/turbocharged-training-training-timistimizing-databricks-mosaic-ai stack-stack-fp8 |
| GPT-22B | NEMO框架 | 可根據要求提供 |
| Llama2-70B | 阿里巴巴·帕伊(Alibaba Pai) | https://mp.weixin.qq.com/s/nqt0ukxlbxyh5031zbdebq |
| GPT-175B | JAX/PAXML | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/rosetta/projects/pax#h100-results |
變壓器引擎已與流行的LLM框架集成在一起,例如:
我們歡迎對變壓引擎的貢獻!為了為變壓器引擎做出貢獻並提出拉力請求,請遵循貢獻指南中概述的指南。