QuickStart | Installation | Guide de l'utilisateur | Exemples | Convergence FP8 | Intégrations | Notes de libération

Transformateur Engine (TE) est une bibliothèque pour accélérer les modèles de transformateurs sur les GPU NVIDIA, y compris en utilisant la précision du point flottant 8 bits (FP8) sur les GPU de la trémie, pour offrir de meilleures performances avec une utilisation de la mémoire plus faible dans la formation et l'inférence. TE fournit une collection de blocs de construction hautement optimisés pour les architectures de transformateurs populaires et une API de précision automatique de précision automatique qui peut être utilisée de manière transparente avec votre code spécifique au framework. TE comprend également une API Agnostic C ++ Framework qui peut être intégrée à d'autres bibliothèques d'apprentissage en profondeur pour permettre la prise en charge de FP8 pour les transformateurs.
Alors que le nombre de paramètres dans les modèles de transformateurs continue de croître, la formation et l'inférence pour les architectures telles que Bert, GPT et T5 deviennent très mémoire et à forte intensité de calcul. La plupart des cadres d'apprentissage en profondeur s'entraînent avec FP32 par défaut. Ce n'est pas essentiel, cependant, pour atteindre la pleine précision pour de nombreux modèles d'apprentissage en profondeur. L'utilisation de la formation de précision mixte, qui combine une seule précision (FP32) avec un format de précision plus faible (par exemple FP16) lors de la formation d'un modèle, entraîne des accéléreuses significatives avec des différences minimales de précision par rapport à la formation FP32. Avec Hopper GPU Architecture, la précision FP8 a été introduite, ce qui offre des performances améliorées sur FP16 sans dégradation de précision. Bien que tous les principaux cadres d'apprentissage en profondeur prennent en charge FP16, le support FP8 n'est pas disponible nativement dans les cadres aujourd'hui.
TE aborde le problème de la prise en charge de FP8 en fournissant des API qui s'intègrent aux bibliothèques populaires du modèle de grande langue (LLM). Il fournit une API Python composée de modules pour construire facilement une couche de transformateur ainsi qu'une bibliothèque agnostique framework en C ++, y compris les structures et les noyaux nécessaires à la prise en charge FP8. Les modules fournis par TE maintiennent en interne les facteurs de mise à l'échelle et autres valeurs nécessaires à la formation FP8, simplifiant considérablement la formation de précision mixte pour les utilisateurs.
import torch
import transformer_engine . pytorch as te
from transformer_engine . common import recipe
# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048
# Initialize model and inputs.
model = te . Linear ( in_features , out_features , bias = True )
inp = torch . randn ( hidden_size , in_features , device = "cuda" )
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe . DelayedScaling ( margin = 0 , fp8_format = recipe . Format . E4M3 )
# Enable autocasting for the forward pass
with te . fp8_autocast ( enabled = True , fp8_recipe = fp8_recipe ):
out = model ( inp )
loss = out . sum ()
loss . backward () import flax
import jax
import jax . numpy as jnp
import transformer_engine . jax as te
import transformer_engine . jax . flax as te_flax
from transformer_engine . common import recipe
BATCH = 32
SEQLEN = 128
HIDDEN = 1024
# Initialize RNG and inputs.
rng = jax . random . PRNGKey ( 0 )
init_rng , data_rng = jax . random . split ( rng )
inp = jax . random . normal ( data_rng , [ BATCH , SEQLEN , HIDDEN ], jnp . float32 )
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe . DelayedScaling ( margin = 0 , fp8_format = recipe . Format . HYBRID )
# Enable autocasting for the forward pass
with te . fp8_autocast ( enabled = True , fp8_recipe = fp8_recipe ):
model = te_flax . DenseGeneral ( features = HIDDEN )
def loss_fn ( params , other_vars , inp ):
out = model . apply ({ 'params' : params , ** other_vars }, inp )
return jnp . mean ( out )
# Initialize models.
variables = model . init ( init_rng , inp )
other_variables , params = flax . core . pop ( variables , 'params' )
# Construct the forward and backward function
fwd_bwd_fn = jax . value_and_grad ( loss_fn , argnums = ( 0 , 1 ))
for _ in range ( 10 ):
loss , ( param_grads , other_grads ) = fwd_bwd_fn ( params , other_variables , inp )Le moyen le plus rapide de commencer avec Transformer Engine est d'utiliser Docker Images sur le catalogue NVIDIA GPU Cloud (NGC). Par exemple pour utiliser le conteneur NGC Pytorch de manière interactive,
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3Où 23.10 est la version conteneur. Par exemple, 23.10 pour la sortie d'octobre 2023.
Pour installer la dernière version stable de Transformer Engine,
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stableCela détectera automatiquement si des cadres d'apprentissage en profondeur pris en charge sont installés et leur construisent la prise en charge du moteur du transformateur. Pour spécifier explicitement les frameworks, définissez la variable d'environnement nvte_framework sur une liste séparée par des virgules (par exemple nvte_framework = jax, pytorch, paddle).
Alternativement, le package peut être directement installé à partir de PYPI du Transformer Engine, par exemple
pip install transformer_engine[pytorch]Pour obtenir les liaisons Python nécessaires pour le moteur de transformateur, les cadres nécessaires doivent être explicitement spécifiés comme des dépendances supplémentaires dans une liste séparée par des virgules (par exemple [Jax, pytorch, paddle]). Transformer Engine expédie des roues pour la bibliothèque de base ainsi que les extensions de palette. Les distributions de source sont expédiées pour les extensions JAX et Pytorch.
Voir le guide d'installation.
Transformer Engine Release V0.11.0 ajoute la prise en charge de FlashAtttention-2 dans Pytorch pour améliorer les performances.
C'est un problème connu que la compilation Flashattention-2 est à forte intensité de ressources et nécessite une grande quantité de RAM (voir Bug), ce qui peut entraîner des erreurs hors mémoire lors de l'installation du moteur transformateur. Veuillez essayer de définir max_jobs = 1 dans l'environnement pour contourner le problème.
Notez que les conteneurs NGC Pytorch 23.08+ incluent Flashattention-2.
Dans un effort pour unifier la définition et l'utilisation du masque d'attention sur les trois cadres dans le moteur Transformateur, le masque de rembourrage est passé de la véritable inclusion de sens de la position correspondante dans l'attention à l'exclusion de cette position dans notre implémentation Pytorch. Depuis V1.7, tous les types de masques d'attention suivent la même définition où le vrai signifie masquer la position correspondante et les faux moyens, y compris cette position dans le calcul de l'attention.
Un exemple de ce changement est,
# for a batch of 3 sequences where `a`s, `b`s and `c`s are the useful tokens
# and `0`s are the padding tokens,
[a, a, a, 0, 0,
b, b, 0, 0, 0,
c, c, c, c, 0]
# the padding mask for this batch before v1.7 is,
[ True, True, True, False, False,
True, True, False, False, False,
True, True, True, True, False]
# and for v1.7 onwards it should be,
[False, False, False, True, True,
False, False, True, True, True,
False, False, False, False, True]FP8 a été testé largement sur différentes architectures et configurations de modèle et nous n'avons trouvé aucune différence significative entre les courbes de perte de formation FP8 et BF16. FP8 a également été validé pour la précision sur les tâches LLM en aval (par exemple Lambada et Wikitext). Vous trouverez ci-dessous des exemples de modèles testés pour la convergence dans différents cadres.
| Modèle | Cadre | Source |
|---|---|---|
| T5-770m | Jax / t5x | https://github.com/nvidia/jax-toolbox/Tree/main/Rosetta/Rosetta/projects/t5x#convergence-and-performance |
| MPT-1.3b | Compositeur en mosaïque | https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1 |
| GPT-5B | JAX / PAXML | https://github.com/nvidia/jax-toolbox/tree/main/Rosetta/Rosetta/projects/pax#h100-results |
| GPT-5B | Framework NEMO | Disponible sur demande |
| Llama2-7b | Alibaba Pai | https://mp.weixin.qq.com/s/nqt0ukxlbxyh5031zbdebq |
| T5-11B | Jax / t5x | Disponible sur demande |
| MPT-13B | Compositeur en mosaïque | https://www.databricks.com/blog/turbocharged-training-optimiing-databricks-mosaic-ai-stack-fp8 |
| GPT-22B | Framework NEMO | Disponible sur demande |
| LLAMA2-70B | Alibaba Pai | https://mp.weixin.qq.com/s/nqt0ukxlbxyh5031zbdebq |
| GPT-175B | JAX / PAXML | https://github.com/nvidia/jax-toolbox/tree/main/Rosetta/Rosetta/projects/pax#h100-results |
Transformer Engine a été intégré à des cadres LLM populaires tels que:
Nous saluons les contributions à Transformer Engine! Pour contribuer au moteur Transformer et faire des demandes de traction, suivez les directives décrites dans le guide contribution.