Quickset | Instalación | Guía del usuario | Ejemplos | Convergencia FP8 | Integraciones | Notas de lanzamiento

Transformer Engine (TE) es una biblioteca para acelerar los modelos de transformadores en las GPU NVIDIA, que incluye el uso de precisión de punto flotante de 8 bits (FP8) en GPU de la tolva, para proporcionar un mejor rendimiento con una menor utilización de la memoria tanto en entrenamiento como en inferencia. TE proporciona una colección de bloques de construcción altamente optimizados para arquitecturas de transformadores populares y una API automática de precisión mixta que se puede usar sin problemas con su código específico de marco. TE también incluye una API C ++ agnóstica marco que puede integrarse con otras bibliotecas de aprendizaje profundo para permitir el soporte de FP8 para los transformadores.
A medida que el número de parámetros en los modelos de transformadores continúa creciendo, el entrenamiento y la inferencia de arquitecturas como Bert, GPT y T5 se vuelven muy memoria y de cómputo. La mayoría de los marcos de aprendizaje profundo entrenan con FP32 por defecto. Sin embargo, esto no es esencial para lograr una precisión total para muchos modelos de aprendizaje profundo. El uso de entrenamiento de precisión mixta, que combina formato de precisión única (FP32) con un formato de menor precisión (por ejemplo, FP16) al entrenar un modelo, da como resultado aceleraciones significativas con diferencias mínimas en la precisión en comparación con el entrenamiento FP32. Con la arquitectura GPU de Hopper se introdujo la precisión FP8, que ofrece un rendimiento mejorado sobre FP16 sin degradación en precisión. Aunque todos los principales marcos de aprendizaje profundo admiten FP16, el soporte FP8 no está disponible de forma nativa en los marcos hoy en día.
TE aborda el problema del soporte de FP8 al proporcionar API que se integran con las bibliotecas populares del modelo de lenguaje grande (LLM). Proporciona una API de Python que consiste en módulos para construir fácilmente una capa de transformador, así como una biblioteca de marco-agnóstico en C ++, incluidas las estructuras y los núcleos necesarios para el soporte FP8. Los módulos proporcionados por TE mantienen internamente los factores de escala y otros valores necesarios para la capacitación de FP8, simplificando enormemente la capacitación de precisión mixta para los usuarios.
import torch
import transformer_engine . pytorch as te
from transformer_engine . common import recipe
# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048
# Initialize model and inputs.
model = te . Linear ( in_features , out_features , bias = True )
inp = torch . randn ( hidden_size , in_features , device = "cuda" )
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe . DelayedScaling ( margin = 0 , fp8_format = recipe . Format . E4M3 )
# Enable autocasting for the forward pass
with te . fp8_autocast ( enabled = True , fp8_recipe = fp8_recipe ):
out = model ( inp )
loss = out . sum ()
loss . backward () import flax
import jax
import jax . numpy as jnp
import transformer_engine . jax as te
import transformer_engine . jax . flax as te_flax
from transformer_engine . common import recipe
BATCH = 32
SEQLEN = 128
HIDDEN = 1024
# Initialize RNG and inputs.
rng = jax . random . PRNGKey ( 0 )
init_rng , data_rng = jax . random . split ( rng )
inp = jax . random . normal ( data_rng , [ BATCH , SEQLEN , HIDDEN ], jnp . float32 )
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe . DelayedScaling ( margin = 0 , fp8_format = recipe . Format . HYBRID )
# Enable autocasting for the forward pass
with te . fp8_autocast ( enabled = True , fp8_recipe = fp8_recipe ):
model = te_flax . DenseGeneral ( features = HIDDEN )
def loss_fn ( params , other_vars , inp ):
out = model . apply ({ 'params' : params , ** other_vars }, inp )
return jnp . mean ( out )
# Initialize models.
variables = model . init ( init_rng , inp )
other_variables , params = flax . core . pop ( variables , 'params' )
# Construct the forward and backward function
fwd_bwd_fn = jax . value_and_grad ( loss_fn , argnums = ( 0 , 1 ))
for _ in range ( 10 ):
loss , ( param_grads , other_grads ) = fwd_bwd_fn ( params , other_variables , inp )La forma más rápida de comenzar con Transformer Engine es mediante el uso de imágenes Docker en el catálogo de NVIDIA GPU Cloud (NGC). Por ejemplo, para usar el contenedor NGC Pytorch de manera interactiva,
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3Donde 23.10 es la versión del contenedor. Por ejemplo, 23.10 para el lanzamiento de octubre de 2023.
Para instalar la última versión estable de Transformer Engine,
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stableEsto detectará automáticamente si se instalan marcos de aprendizaje profundo admitidos y creará soporte del motor del transformador para ellos. Para especificar explícitamente los marcos, establezca la variable de entorno nvte_framework en una lista separada por comas (por ejemplo, nvte_framework = jax, pytorch, paddle).
Alternativamente, el paquete se puede instalar directamente desde el PYPI del motor Transformer, por ejemplo,
pip install transformer_engine[pytorch]Para obtener las fijaciones de pitón necesarias para el motor transformador, los marcos necesarios deben especificarse explícitamente como dependencias adicionales en una lista separada por comas (por ejemplo, [Jax, Pytorch, Paddle]). Transformer Engine envía ruedas para la biblioteca de núcleo, así como las extensiones de paddlepaddle. Las distribuciones de fuente se envían para las extensiones de Jax y Pytorch.
Consulte la guía de instalación.
Transformer Engine Releep V0.11.0 agrega soporte para FlashAttention-2 en Pytorch para mejorar el rendimiento.
Es un problema conocido que la compilación FlashAttention-2 es intensiva en recursos y requiere una gran cantidad de RAM (ver error), lo que puede provocar errores fuera de la memoria durante la instalación del motor Transformer. Intente configurar max_jobs = 1 en el entorno para eludir el problema.
Tenga en cuenta que los contenedores NGC Pytorch 23.08+ incluyen Flashattention-2.
En un esfuerzo por unificar la definición y el uso de la máscara de atención en los tres marcos en el motor Transformer, la máscara de relleno ha cambiado de la inclusión de significado verdadero de la posición correspondiente en atención a la exclusión de esa posición en nuestra implementación de Pytorch. Desde V1.7, todos los tipos de máscara de atención siguen la misma definición donde verdadero significa enmascarar la posición correspondiente y los medios falsos que incluyen esa posición en el cálculo de la atención.
Un ejemplo de este cambio es,
# for a batch of 3 sequences where `a`s, `b`s and `c`s are the useful tokens
# and `0`s are the padding tokens,
[a, a, a, 0, 0,
b, b, 0, 0, 0,
c, c, c, c, 0]
# the padding mask for this batch before v1.7 is,
[ True, True, True, False, False,
True, True, False, False, False,
True, True, True, True, False]
# and for v1.7 onwards it should be,
[False, False, False, True, True,
False, False, True, True, True,
False, False, False, False, True]FP8 se ha probado ampliamente en diferentes arquitecturas y configuraciones de modelos y no encontramos diferencias significativas entre las curvas de pérdida de entrenamiento FP8 y BF16. FP8 también ha sido validado para su precisión en tareas de LLM aguas abajo (por ejemplo, Lambada y Wikitext). A continuación se presentan ejemplos de modelos probados para la convergencia en diferentes marcos.
| Modelo | Estructura | Fuente |
|---|---|---|
| T5-770m | Jax/t5x | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/projects/t5x#convergence-and-performance |
| MPT-1.3B | Compositor de mosaico | https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1 |
| GPT-5B | Jax/paxml | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results |
| GPT-5B | Marco nemo | Disponible a pedido |
| Llama2-7b | Alibaba Pai | https://mp.weixin.qq.com/s/nqt0ukxlbxyh5031zbdebq |
| T5-11B | Jax/t5x | Disponible a pedido |
| MPT-13B | Compositor de mosaico | https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8 |
| GPT-22B | Marco nemo | Disponible a pedido |
| Llama2-70b | Alibaba Pai | https://mp.weixin.qq.com/s/nqt0ukxlbxyh5031zbdebq |
| GPT-175B | Jax/paxml | https://github.com/nvidia/jax-toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results |
Transformer Engine se ha integrado con marcos LLM populares como:
¡Agradecemos contribuciones al motor Transformer! Para contribuir a Transformer Engine y realizar solicitudes de extracción, siga las pautas descritas en la guía contribuyente.