QuickStart |安装|用户指南|示例| FP8收敛|集成|发行说明

变压器引擎(TE)是一个库,用于加速NVIDIA GPU上的变压器模型,包括在Hopper GPU上使用8位浮点(FP8)精度,以在训练和推理中提供较低的内存利用,以提供更好的性能。 TE为流行的变压器体系结构提供了一系列高度优化的构建块和一个自动混合精确的API,可以与您的框架特定代码无缝使用。 TE还包括一个框架不可知的C ++ API,该API可以与其他深度学习库集成,以启用FP8对变压器的支持。
随着变压器模型中的参数数量的数量不断增长,诸如BERT,GPT和T5等体系结构的训练和推断变得非常记忆和计算密集型。默认情况下,大多数深度学习框架都用FP32训练。但是,对于许多深度学习模型,这并不是至关重要的。与FP32训练相比,使用混合精确训练将单精度(FP32)与较低的精度(例如FP16)格式相结合,其准确性差异很小。借助Hopper GPU体系结构FP8精度,它提供了比FP16的改进性能,而准确性没有降解。尽管所有主要的深度学习框架都支持FP16,但在当今框架中,FP8支持在本地尚未获得。
TE通过提供与流行的大语言模型(LLM)库集成的API来解决FP8支持的问题。它提供了一个由模块组成的Python API,可轻松构建变压器层以及C ++中的框架 - 敏捷库,包括FP8支持所需的结构和内核。 TE内部提供的模块可维护FP8培训所需的缩放因素和其他值,从而大大简化了用户的混合精度培训。
import torch
import transformer_engine . pytorch as te
from transformer_engine . common import recipe
# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048
# Initialize model and inputs.
model = te . Linear ( in_features , out_features , bias = True )
inp = torch . randn ( hidden_size , in_features , device = "cuda" )
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe . DelayedScaling ( margin = 0 , fp8_format = recipe . Format . E4M3 )
# Enable autocasting for the forward pass
with te . fp8_autocast ( enabled = True , fp8_recipe = fp8_recipe ):
out = model ( inp )
loss = out . sum ()
loss . backward () import flax
import jax
import jax . numpy as jnp
import transformer_engine . jax as te
import transformer_engine . jax . flax as te_flax
from transformer_engine . common import recipe
BATCH = 32
SEQLEN = 128
HIDDEN = 1024
# Initialize RNG and inputs.
rng = jax . random . PRNGKey ( 0 )
init_rng , data_rng = jax . random . split ( rng )
inp = jax . random . normal ( data_rng , [ BATCH , SEQLEN , HIDDEN ], jnp . float32 )
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe . DelayedScaling ( margin = 0 , fp8_format = recipe . Format . HYBRID )
# Enable autocasting for the forward pass
with te . fp8_autocast ( enabled = True , fp8_recipe = fp8_recipe ):
model = te_flax . DenseGeneral ( features = HIDDEN )
def loss_fn ( params , other_vars , inp ):
out = model . apply ({ 'params' : params , ** other_vars }, inp )
return jnp . mean ( out )
# Initialize models.
variables = model . init ( init_rng , inp )
other_variables , params = flax . core . pop ( variables , 'params' )
# Construct the forward and backward function
fwd_bwd_fn = jax . value_and_grad ( loss_fn , argnums = ( 0 , 1 ))
for _ in range ( 10 ):
loss , ( param_grads , other_grads ) = fwd_bwd_fn ( params , other_variables , inp )开始使用变压器引擎的最快方法是在NVIDIA GPU Cloud(NGC)目录上使用Docker Images。例如,使用NGC Pytorch容器交互式使用,
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3其中23.10是容器版本。例如,2023年10月发布的23.10。
要安装最新稳定版本的变压器引擎,
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable这将自动检测是否安装了任何支持的深度学习框架并为其构建变压器引擎支持。要明确指定框架,请将环境变量NVTE_FRAMEWORK设置为逗号分隔的列表(例如NVTE_FRAMEWORK = JAX,PYTORCH,PADDLE)。
另外,可以直接从变压器引擎的PYPI中安装包装,例如
pip install transformer_engine[pytorch]为了获得变压器引擎的必要的Python绑定,必须在逗号分隔列表中明确指定所需的框架(例如[Jax,Pytorch,Paddle])。变压器发动机轮芯轮用于核心库以及桨板扩展。用于JAX和PYTORCH扩展的源分布。
请参阅“安装指南”。
变压器发动机释放v0.11.0增加了Pytorch中Flashattention-2的支持,以提高性能。
已知的问题是,Flashattention-2汇编是资源密集的,需要大量RAM(请参阅错误),这可能会导致变压器引擎安装过程中的内存错误。请尝试在环境中设置max_jobs = 1以避免问题。
请注意,NGC Pytorch 23.08+容器包括Flashattention-2。
为了统一在变压器引擎中所有三个框架中注意力面罩的定义和使用,填充面罩已从真正的含义包含相应的位置包含在我们的Pytorch实现中排除该位置。由于v1.7,所有注意性掩码类型都遵循相同的定义,其中true意味着掩盖相应的位置和错误的含义,包括注意力计算中的位置。
此更改的一个例子是,
# for a batch of 3 sequences where `a`s, `b`s and `c`s are the useful tokens
# and `0`s are the padding tokens,
[a, a, a, 0, 0,
b, b, 0, 0, 0,
c, c, c, c, 0]
# the padding mask for this batch before v1.7 is,
[ True, True, True, False, False,
True, True, False, False, False,
True, True, True, True, False]
# and for v1.7 onwards it should be,
[False, False, False, True, True,
False, False, True, True, True,
False, False, False, False, True]FP8已在不同的模型架构和配置中进行了广泛的测试,我们发现FP8和BF16训练损失曲线之间没有显着差异。 FP8还经过验证,以确保下游LLM任务的准确性(例如Lambada和Wikitext)。以下是在不同框架之间测试收敛的模型的示例。
| 模型 | 框架 | 来源 |
|---|---|---|
| T5-770m | jax/t5x | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/rosetta/projects/t5x#convergence-and-performance |
| MPT-1.3B | 马赛克作曲家 | https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1 |
| GPT-5B | JAX/PAXML | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/rosetta/projects/pax#h100-results |
| GPT-5B | NEMO框架 | 可根据要求提供 |
| Llama2-7b | 阿里巴巴·帕伊(Alibaba Pai) | https://mp.weixin.qq.com/s/nqt0ukxlbxyh5031zbdebq |
| T5-11b | jax/t5x | 可根据要求提供 |
| MPT-13B | 马赛克作曲家 | https://www.databricks.com/blog/turbocharged-training-training-timistimizing-databricks-mosaic-ai stack-stack-fp8 |
| GPT-22B | NEMO框架 | 可根据要求提供 |
| Llama2-70B | 阿里巴巴·帕伊(Alibaba Pai) | https://mp.weixin.qq.com/s/nqt0ukxlbxyh5031zbdebq |
| GPT-175B | JAX/PAXML | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/rosetta/projects/pax#h100-results |
变压器引擎已与流行的LLM框架集成在一起,例如:
我们欢迎对变压引擎的贡献!为了为变压器引擎做出贡献并提出拉力请求,请遵循贡献指南中概述的指南。