QuickStart |インストール|ユーザーガイド|例| FP8収束|統合|ノートをリリースします

Transformer Engine(TE)は、NVIDIA GPUのトランスモデルを加速するライブラリであり、ホッパーGPUで8ビットフローティングポイント(FP8)精度を使用して、トレーニングと推論の両方でより低いメモリ利用によりパフォーマンスを向上させます。 TEは、人気のある変圧器アーキテクチャ用の高度に最適化されたビルディングブロックのコレクションと、フレームワーク固有のコードでシームレスに使用できる自動混合精度のようなAPIを提供します。 TEには、他の深い学習ライブラリと統合して、トランスのFP8サポートを有効にすることができるフレームワークの不可知論C ++ APIも含まれています。
変圧器モデルのパラメーターの数が成長し続けるにつれて、BERT、GPT、T5などのアーキテクチャのトレーニングと推論は非常にメモリと計算集約型になります。ほとんどのディープラーニングフレームワークは、デフォルトでFP32でトレーニングします。ただし、これは多くの深い学習モデルの完全な精度を達成するために不可欠ではありません。モデルのトレーニング時に単一精度(FP32)と低精度(FP16)形式の低い形式を組み合わせた混合精度トレーニングを使用すると、FP32トレーニングと比較して精度が最小限の大きなスピードアップをもたらします。 Hopper GPU Architecture FP8 Precisionが導入され、精度が低下せずにFP16よりもパフォーマンスが向上します。すべての主要なディープラーニングフレームワークはFP16をサポートしていますが、FP8サポートは今日のフレームワークでネイティブに利用できません。
TEは、一般的な大手言語モデル(LLM)ライブラリと統合するAPIを提供することにより、FP8サポートの問題に対処します。これは、FP8サポートに必要な構造体とカーネルを含むC ++のフレームワークと存在するライブラリを簡単に構築するモジュールで構成されるPython APIを提供します。 TEが提供するモジュールは、FP8トレーニングに必要なスケーリング要因やその他の価値を内部的に維持し、ユーザー向けの混合精度トレーニングを大幅に簡素化します。
import torch
import transformer_engine . pytorch as te
from transformer_engine . common import recipe
# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048
# Initialize model and inputs.
model = te . Linear ( in_features , out_features , bias = True )
inp = torch . randn ( hidden_size , in_features , device = "cuda" )
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe . DelayedScaling ( margin = 0 , fp8_format = recipe . Format . E4M3 )
# Enable autocasting for the forward pass
with te . fp8_autocast ( enabled = True , fp8_recipe = fp8_recipe ):
out = model ( inp )
loss = out . sum ()
loss . backward () import flax
import jax
import jax . numpy as jnp
import transformer_engine . jax as te
import transformer_engine . jax . flax as te_flax
from transformer_engine . common import recipe
BATCH = 32
SEQLEN = 128
HIDDEN = 1024
# Initialize RNG and inputs.
rng = jax . random . PRNGKey ( 0 )
init_rng , data_rng = jax . random . split ( rng )
inp = jax . random . normal ( data_rng , [ BATCH , SEQLEN , HIDDEN ], jnp . float32 )
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe . DelayedScaling ( margin = 0 , fp8_format = recipe . Format . HYBRID )
# Enable autocasting for the forward pass
with te . fp8_autocast ( enabled = True , fp8_recipe = fp8_recipe ):
model = te_flax . DenseGeneral ( features = HIDDEN )
def loss_fn ( params , other_vars , inp ):
out = model . apply ({ 'params' : params , ** other_vars }, inp )
return jnp . mean ( out )
# Initialize models.
variables = model . init ( init_rng , inp )
other_variables , params = flax . core . pop ( variables , 'params' )
# Construct the forward and backward function
fwd_bwd_fn = jax . value_and_grad ( loss_fn , argnums = ( 0 , 1 ))
for _ in range ( 10 ):
loss , ( param_grads , other_grads ) = fwd_bwd_fn ( params , other_variables , inp )トランスエンジンを始める最も簡単な方法は、NVIDIA GPU Cloud(NGC)カタログでDocker画像を使用することです。たとえば、NGC Pytorchコンテナをインタラクティブに使用するには、
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3ここで、23.10はコンテナバージョンです。たとえば、2023年10月のリリースで23.10。
トランスエンジンの最新バージョンをインストールするには、
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stableこれにより、サポートされているディープラーニングフレームワークがインストールされ、トランスエンジンのサポートが組み込まれているかどうかが自動的に検出されます。フレームワークを明示的に指定するには、環境変数nvte_frameworkをコンマ分離リスト(nvte_framework = jax、pytorch、paddleなど)に設定します。
または、パッケージはトランスエンジンのピピから直接インストールできます。
pip install transformer_engine[pytorch]トランスエンジンに必要なPythonバインディングを取得するには、必要なフレームワークを、コンマ分離リスト([Jax、Pytorch、Paddle]など)の追加依存関係として明示的に指定する必要があります。トランスエンジンは、コアライブラリとパドルパドル拡張機能にホイールを搭載しています。ソース分布は、JAXおよびPytorch拡張機能のために出荷されます。
インストールガイドを参照してください。
トランスエンジンリリースv0.11.0は、パフォーマンスを向上させるために、PytorchのFlashattention-2のサポートを追加します。
Flashattention-2コンピレーションはリソース集約型であり、大量のRAM(バグを参照)を必要とすることは既知の問題です。これは、変圧器エンジンの設置中にメモリエラーが発生する可能性があります。問題を回避するために、環境でmax_jobs = 1を設定してみてください。
NGC Pytorch 23.08+コンテナにはFlashattention-2が含まれていることに注意してください。
トランスエンジンの3つのフレームワークすべてにわたる注意マスクの定義と使用を統合するために、パディングマスクは、Pytorchの実装におけるその位置を排除するために、対応する位置の真の意味を含めることから変化しました。 v1.7以来、すべての注意マスクのタイプは同じ定義に従います。これは、真の意味が対応する位置をマスクすることと、注意計算におけるその位置を含む虚偽とを意味します。
この変更の例は、
# for a batch of 3 sequences where `a`s, `b`s and `c`s are the useful tokens
# and `0`s are the padding tokens,
[a, a, a, 0, 0,
b, b, 0, 0, 0,
c, c, c, c, 0]
# the padding mask for this batch before v1.7 is,
[ True, True, True, False, False,
True, True, False, False, False,
True, True, True, True, False]
# and for v1.7 onwards it should be,
[False, False, False, True, True,
False, False, True, True, True,
False, False, False, False, True]FP8は、さまざまなモデルアーキテクチャと構成で広範囲にテストされており、FP8とBF16トレーニング損失曲線の間に有意差は見つかりませんでした。 FP8は、下流のLLMタスク(LambadaやWikitextなど)の精度についても検証されています。以下は、さまざまなフレームワーク間の収束についてテストされたモデルの例です。
| モデル | フレームワーク | ソース |
|---|---|---|
| T5-770M | Jax/T5x | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/projects/t5x#convergence-and-performance |
| MPT-1.3B | モザイク作曲家 | https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1 |
| GPT-5B | jax/paxml | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results |
| GPT-5B | NEMOフレームワーク | リクエストに応じて利用できます |
| llama2-7b | アリババパイ | https://mp.weixin.qq.com/s/nqt0ukxlbxyh5031zbdebq |
| T5-11b | Jax/T5x | リクエストに応じて利用できます |
| MPT-13B | モザイク作曲家 | https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8 |
| GPT-22B | NEMOフレームワーク | リクエストに応じて利用できます |
| llama2-70b | アリババパイ | https://mp.weixin.qq.com/s/nqt0ukxlbxyh5031zbdebq |
| GPT-175B | jax/paxml | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results |
トランスエンジンは、次のような一般的なLLMフレームワークと統合されています。
トランスエンジンへの貢献を歓迎します!トランスエンジンに貢献し、プルリクエストを行うには、Contributing.RSTガイドに概説されているガイドラインに従ってください。