QuickStart | Installation | Benutzerhandbuch | Beispiele | FP8 -Konvergenz | Integrationen | Versionshinweise

Transformator Engine (TE) ist eine Bibliothek zur Beschleunigung von Transformatormodellen für NVIDIA-GPUs, einschließlich der Verwendung von 8-Bit-Gleitkomma-Präzision (FP8) auf dem Hopper-GPUs, um eine bessere Leistung mit geringerer Speicherauslastung sowohl in der Schulung als auch bei der Inferenz zu bieten. TE bietet eine Sammlung hochoptimierter Bausteine für beliebte Transformatorarchitekturen und eine automatische API mit gemischter Präzision, die mit Ihrem Framework-spezifischen Code nahtlos verwendet werden kann. TE enthält auch eine Agnostische C ++ - API von Framework, die in andere Deep -Learning -Bibliotheken integriert werden kann, um FP8 -Unterstützung für Transformatoren zu ermöglichen.
Wenn die Anzahl der Parameter in Transformatormodellen weiter wächst, werden das Training und die Schlussfolgerung für Architekturen wie Bert, GPT und T5 sehr Erinnerung und rechenintensiv. Die meisten tiefen Lernrahmen trainieren standardmäßig mit FP32. Dies ist jedoch nicht wichtig, um für viele tiefe Lernmodelle die volle Genauigkeit zu erreichen. Unter Verwendung von Mischprezisionstraining, das beim Training eines Modells ein einzelner Präzision (FP32) mit einem niedrigeren Präzisionsformat (EG FP16) kombiniert, führt zu signifikanten Beschleunigungen mit minimalen Unterschieden in der Genauigkeit im Vergleich zum FP32-Training. Mit der Hopper GPU Architecture wurde FP8 Precision eingeführt, was eine verbesserte Leistung gegenüber FP16 ohne Verschlechterung der Genauigkeit bietet. Obwohl alle großen Deep -Learning -Frameworks FP16 unterstützen, ist die FP8 -Unterstützung heute in Frameworks nicht nativ verfügbar.
TE befasst sich mit dem Problem der FP8 -Unterstützung, indem sie APIs bereitstellen, die in LLM -Bibliotheken (Language Language Model) integriert werden. Es bietet eine Python-API, die aus Modulen besteht, um einfach eine Transformatorschicht sowie eine Framework-Agnostic-Bibliothek in C ++ zu erstellen, einschließlich Strukturen und Kernel, die für die Unterstützung von FP8 benötigt werden. Module, die von TE bereitgestellt werden, halten die für das FP8 -Training erforderlichen Skalierungsfaktoren und andere Werte bei, wodurch ein gemischtes Präzisionstraining für Benutzer erheblich vereinfacht wird.
import torch
import transformer_engine . pytorch as te
from transformer_engine . common import recipe
# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048
# Initialize model and inputs.
model = te . Linear ( in_features , out_features , bias = True )
inp = torch . randn ( hidden_size , in_features , device = "cuda" )
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe . DelayedScaling ( margin = 0 , fp8_format = recipe . Format . E4M3 )
# Enable autocasting for the forward pass
with te . fp8_autocast ( enabled = True , fp8_recipe = fp8_recipe ):
out = model ( inp )
loss = out . sum ()
loss . backward () import flax
import jax
import jax . numpy as jnp
import transformer_engine . jax as te
import transformer_engine . jax . flax as te_flax
from transformer_engine . common import recipe
BATCH = 32
SEQLEN = 128
HIDDEN = 1024
# Initialize RNG and inputs.
rng = jax . random . PRNGKey ( 0 )
init_rng , data_rng = jax . random . split ( rng )
inp = jax . random . normal ( data_rng , [ BATCH , SEQLEN , HIDDEN ], jnp . float32 )
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe . DelayedScaling ( margin = 0 , fp8_format = recipe . Format . HYBRID )
# Enable autocasting for the forward pass
with te . fp8_autocast ( enabled = True , fp8_recipe = fp8_recipe ):
model = te_flax . DenseGeneral ( features = HIDDEN )
def loss_fn ( params , other_vars , inp ):
out = model . apply ({ 'params' : params , ** other_vars }, inp )
return jnp . mean ( out )
# Initialize models.
variables = model . init ( init_rng , inp )
other_variables , params = flax . core . pop ( variables , 'params' )
# Construct the forward and backward function
fwd_bwd_fn = jax . value_and_grad ( loss_fn , argnums = ( 0 , 1 ))
for _ in range ( 10 ):
loss , ( param_grads , other_grads ) = fwd_bwd_fn ( params , other_variables , inp )Der schnellste Weg, um mit Transformator Engine zu beginnen, besteht darin, Docker -Bilder im NVIDIA GPU Cloud (NGC) -Katalog zu verwenden. Zum Beispiel, um den NGC Pytorch Container interaktiv zu verwenden,
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3Wobei 23.10 die Containerversion ist. Zum Beispiel 23.10 für die Veröffentlichung im Oktober 2023.
So installieren Sie die neueste stabile Version von Transformer Engine,
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stableDadurch wird automatisch festgestellt, ob unterstützte Deep -Learning -Frameworks installiert werden und die Unterstützung der Transformer -Engine für sie erstellt. Um Frameworks explizit anzugeben, setzen Sie die Umgebungsvariable nvte_framework auf eine von Kommas getrennte Liste (z. B. nvte_framework = jax, pytorch, paddle).
Alternativ kann das Paket direkt von PYPI des Transformator Engine, z. B.
pip install transformer_engine[pytorch]Um die notwendigen Python-Bindungen für die Transformator-Engine zu erhalten, müssen die erforderlichen Rahmenbedingungen in einer von Kommas getrennten Liste explizit als zusätzliche Abhängigkeiten angegeben werden (z. B. [Jax, Pytorch, Paddle]). Transformator Engine verschifft Räder für die Kernbibliothek sowie die Paddlepaddle -Erweiterungen. Quellverteilungen werden für die Erweiterungen von JAX und Pytorch versendet.
Siehe Installationshandbuch.
Transformator Engine Release V0.11.0 fügt Flashattention-2 in Pytorch für eine verbesserte Leistung zu unterstützen.
Es ist ein bekanntes Problem, dass die Flashattention-2-Kompilierung ressourcenintensiv ist und eine große Menge RAM (siehe Fehler) erfordert, was zu den Speicherfehlern während der Installation der Transformator-Engine führen kann. Bitte versuchen Sie, max_jobs = 1 in der Umgebung festzulegen, um das Problem zu umgehen.
Beachten Sie, dass NGC Pytorch 23.08+ Container Flashattention-2 enthalten.
Um die Definition und Verwendung der Aufmerksamkeitsmaske in allen drei Frameworks in der Transformator -Engine zu vereinen, hat sich die Polstermaske von der wahren Bedeutung der entsprechenden Position in der Aufmerksamkeit bis zum Ausschluss dieser Position in unserer Pytorch -Implementierung geändert. Da V1.7, folgen alle Aufmerksamkeitsmaskentypen derselben Definition, bei der wahre Bedeutung bedeutet, die entsprechende Position zu maskieren, und falsche Mittel, einschließlich dieser Position in der Aufmerksamkeitsberechnung.
Ein Beispiel für diese Änderung ist,,
# for a batch of 3 sequences where `a`s, `b`s and `c`s are the useful tokens
# and `0`s are the padding tokens,
[a, a, a, 0, 0,
b, b, 0, 0, 0,
c, c, c, c, 0]
# the padding mask for this batch before v1.7 is,
[ True, True, True, False, False,
True, True, False, False, False,
True, True, True, True, False]
# and for v1.7 onwards it should be,
[False, False, False, True, True,
False, False, True, True, True,
False, False, False, False, True]FP8 wurde ausführlich über verschiedene Modellarchitekturen und Konfigurationen hinweg getestet, und wir fanden keinen signifikanten Unterschied zwischen FP8- und BF16 -Trainingsverlustkurven. FP8 wurde auch für die Genauigkeit bei nachgeschalteten LLM -Aufgaben (z. B. Lambada und Wikitext) validiert. Im Folgenden finden Sie Beispiele für Modelle, die für die Konvergenz über verschiedene Frameworks hinweg getestet wurden.
| Modell | Rahmen | Quelle |
|---|---|---|
| T5-770m | JAX/T5X | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/projects/t5x#convergence-t-performance |
| MPT-1.3B | Mosaikkomponist | https://www.mosaicml.com/blog/coreweave-nvidia-h100-partner-1 |
| GPT-5B | JAX/PAXML | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results |
| GPT-5B | Nemo -Framework | Auf Anfrage erhältlich |
| LAMA2-7B | Alibaba Pai | https://mp.weixin.qq.com/s/nqt0ukxlbxyh5031zbdebq |
| T5-11b | JAX/T5X | Auf Anfrage erhältlich |
| MPT-13B | Mosaikkomponist | https://www.databricks.com/blog/turboarged-training-optimizing-databricks-mosaic-ai-stack-fp8 |
| GPT-22B | Nemo -Framework | Auf Anfrage erhältlich |
| LAMA2-70B | Alibaba Pai | https://mp.weixin.qq.com/s/nqt0ukxlbxyh5031zbdebq |
| GPT-175B | JAX/PAXML | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results |
Transformator Engine wurde in beliebte LLM -Frameworks integriert wie:
Wir begrüßen Beiträge zum Transformer Engine! Befolgen Sie die Richtlinien, die in der Leitfaden für den Beitrag zu transportieren.