Este repositorio proporciona las implementaciones y experimentos oficiales para los modelos relacionados con S4, incluidos Hippo, LSSL, Sashimi, DSS, HTTYH, S4D y S4ND.
La información específica del proyecto para cada uno de estos modelos, incluida la descripción general del código fuente y las reproducciones de experimentos específicas, se puede encontrar en los modelos/.
Configuración del entorno y portar S4 a bases de código externas:
Uso de este repositorio para modelos de entrenamiento:
Ver ChangeLog.md
Este repositorio requiere Python 3.9+ y Pytorch 1.10+. Se ha probado hasta Pytorch 1.13.1. Otros paquetes se enumeran en requisitos.txt. Es posible que se necesiten algo de cuidado para que algunas de las versiones de la biblioteca sean compatibles, particularmente la antorcha/atorchvision/torthaudio/antorchtext.
Instalación de ejemplo:
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install -r requirements.txt
Una operación central de S4 son los núcleos Cauchy y Vandermonde descritos en el documento. Estas son multiplicaciones de matriz muy simples; Se puede encontrar una implementación ingenua de esta operación en el independiente en la función cauchy_naive y log_vandermonde_naive . Sin embargo, como describe el documento, esto tiene un uso de memoria subóptima que actualmente requiere un núcleo personalizado para superar en Pytorch.
Se admiten dos métodos más eficientes. El código detectará automáticamente si se instala alguno de estos y llame al kernel apropiado.
Esta versión es más rápida pero requiere una compilación manual para cada entorno de la máquina. Ejecute python setup.py install desde las extensions/kernels/ .
Esta versión es proporcionada por la Biblioteca Pykeops. La instalación generalmente funciona fuera de la caja con pip install pykeops cmake que también se enumera en el archivo de requisitos.
Los archivos autónomos para la capa y variantes S4 se pueden encontrar en los modelos/S4/, que incluye instrucciones para llamar al módulo.
Vea cuadernos/ para visualizaciones que explican algunos conceptos detrás de Hippo y S4.
Ejemplo.py es un script de entrenamiento autónomo para MNIST y CIFAR que importa el archivo S4 independiente. La configuración predeterminada python example.py alcanza el 88% de precisión en CIFAR secuencial con un modelo S4D muy simple de 200K parámetros. Este script se puede usar como un ejemplo para usar variantes S4 en repositorios externos.
Este repositorio tiene como objetivo proporcionar un marco muy flexible para los modelos de secuencia de capacitación. Se admiten muchos modelos y conjuntos de datos.
El punto de entrada básico es python -m train , o de manera equivalente
python -m train pipeline=mnist model=s4
que entrena un modelo S4 en el conjunto de datos MNIST Permutado. Esto debería llegar a alrededor del 90% después de 1 época, que toma 1-3 minutos dependiendo de la GPU.
Más ejemplos de uso de este repositorio se documentan en todo momento. Vea la capacitación para una descripción general.
Una característica importante de esta base de código es admitir parámetros que requieren diferentes hiperparámetros de optimizador. En particular, el kernel SSM es particularmente sensible al
Consulte el register de métodos en el modelo (por ejemplo, S4D.PY) y la función setup_optimizer en el script de entrenamiento (por ejemplo, ejemplo.py) para ver un ejemplo de cómo implementar esto en reposteros externos.
La infraestructura de capacitación central de este repositorio se basa en Pytorch-Lightning con un esquema de configuración basado en Hydra.
El principal punto de entrada es train.py y las configuraciones se encuentran en configs/ .
Los conjuntos de datos básicos se descargan automáticamente, incluidos los comandos MNIST, CIFAR y del habla. Toda la lógica para crear y cargar conjuntos de datos se encuentra en el directorio SRC/DataLoaders. El ReadMe dentro de este subdirectorio documenta cómo descargar y organizar otros conjuntos de datos.
Los modelos se definen en SRC/modelos. Vea el ReadMe en este subdirectorio para obtener una visión general.
Se proporcionan configuraciones predefinidas que reproducen experimentos de extremo a extremo de los documentos, que se encuentran bajo información específica del proyecto en modelos/, como el documento S4 original.
Las configuraciones también se pueden modificar fácilmente a través de la línea de comando. Un ejemplo de experimento es
python -m train pipeline=mnist dataset.permute=True model=s4 model.n_layers=3 model.d_model=128 model.norm=batch model.prenorm=True wandb=null
Esto utiliza la tarea MNIST permutada con un modelo S4 con un número especificado de capas, dimensión de esqueleto y tipo de normalización.
Consulte Configs/Readme.md para una documentación más detallada sobre las configuraciones.
Se recomienda leer la documentación de Hydra para comprender completamente el marco de configuración. Para obtener ayuda para lanzar experimentos específicos, presente un problema.
Cada experimento se registrará a su propio directorio (generado por Hydra) del formulario ./outputs/<date>/<time>/ /<Time>/. Los puntos de control se guardarán aquí dentro de esta carpeta e se imprimirán en la consola siempre que se cree un nuevo punto de control. Para reanudar la capacitación, simplemente apunte al archivo .ckpt deseado (un punto de control Pytorch Lightning, por ejemplo ./outputs/<date>/<time>/checkpoints/val/loss.ckpt <DATE>/<Time>/checkpoints/val/loss.ckpt) y agregue el flager train.ckpt=<path>/<to>/<checkpoint>.ckpt al comando de entrenamiento original.
La clase de entrenador PTL controla el bucle de entrenamiento general y también proporciona muchas banderas predefinidas útiles. Algunos ejemplos útiles se explican a continuación. La lista completa de indicadores permitidos se puede encontrar en la documentación PTL, así como en nuestras configuraciones de entrenador. Consulte la configuración de entrenador predeterminada Configs/Trainer/Default.yaml para ver las opciones más útiles.
Simplemente pase en trainer.gpus=2 para entrenar con 2 GPU.
trainer.weights_summary=full imprime cada capa del modelo con sus recuentos de parámetros. Útil para depurar las partes internas de modelos.
trainer.limit_{train,val}_batches={10,0.1} trenes (valida) en solo 10 lotes (0.1 fracción de todos los lotes). Útil para probar el bucle del tren sin pasar por todos los datos.
El registro con Wandb está integrado en este repositorio. Para usar esto, simplemente establezca su variable de entorno WANDB_API_KEY y cambie el atributo wandb.project de config/config.yaml (o pase en la línea de comandos, por ejemplo, python -m train .... wandb.project=s4 ).
Establezca wandb=null para apagar el registro de wandb.
La generación autorregresiva se puede realizar con el script Generate.py. Este script se puede usar de dos maneras después de entrenar un modelo que usa esta base de código.
La opción más flexible requiere la ruta del punto de control del modelo de rayo Pytorch entrenado. El script de generación acepta las mismas opciones de configuración que el script del tren, con algunos indicadores adicionales que se documentan en config/generate.yaml. Después de entrenar con python -m train <train flags> , genere con
python -m generate <train flags> checkpoint_path=<path/to/model.ckpt> <generation flags>
Cualquiera de los indicadores que se encuentran en la configuración se pueden anular.
NOTA: Esta opción se puede usar con puntos de control .ckpt (Pytorch Lightning, que incluye información para el entrenador) o .pt de los puntos de control (Pytorch, que es solo un Model State DICT).
La segunda opción para la generación no requiere pasar las banderas de entrenamiento nuevamente, y en su lugar lee la configuración de la carpeta de experimentos HYDRA, junto con un punto de control Pytorch Lightning dentro de la carpeta del experimento.
Descargue el punto de control del modelo wikitext-103, por ejemplo, a ./checkpoints/s4-wt103.pt . Este modelo fue entrenado con el Comando python -m train experiment=lm/s4-wt103 . Tenga en cuenta que desde la configuración podemos ver que el modelo fue entrenado con un campo receptivo de longitud 8192.
Para generar, ejecutar
python -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=16384 l_prefix=8192 decode=text
Esto genera una muestra de longitud 16384 condicionada en un prefijo de longitud 8192.
Entrenemos un pequeño modelo de sashimi en el conjunto de datos SC09. También podemos reducir el número de lotes de entrenamiento y validación para obtener un punto de control más rápido:
python -m train experiment=audio/sashimi-sc09 model.n_layers=2 trainer.limit_train_batches=0.1 trainer.limit_val_batches=0.1
Después de que se completa la primera época, se imprime un mensaje que indica dónde se guarda el punto de control.
Epoch 0, global step 96: val/loss reached 3.71754 (best 3.71754), saving model to "<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt"
Opción 1:
python -m generate experiment=audio/sashimi-sc09 model.n_layers=2 checkpoint_path=<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
Esta opción redefine la configuración completa para que se pueda construir el modelo y el conjunto de datos.
Opción 2:
python -m generate experiment_path=<repository>/outputs/<date>/<time> checkpoint_path=checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
Esta opción solo necesita la ruta a la carpeta del experimento HYDRA y el punto de control deseado dentro.
configs/ Config files for model, data pipeline, training loop, etc.
data/ Default location of raw data
extensions/ CUDA extensions (Cauchy and Vandermonde kernels)
src/ Main source code for models, datasets, etc.
callbacks/ Training loop utilities (e.g. checkpointing)
dataloaders/ Dataset and dataloader definitions
models/ Model definitions
tasks/ Encoder/decoder modules to interface between data and model backbone
utils/
models/ Model-specific information (code, experiments, additional resources)
example.py Example training script for using S4 externally
train.py Training entrypoint for this repo
generate.py Autoregressive generation script
Si usa esta base de código, o de otra manera encuentra nuestro trabajo valioso, cite S4 y otros documentos relevantes.
@inproceedings{gu2022efficiently,
title={Efficiently Modeling Long Sequences with Structured State Spaces},
author={Gu, Albert and Goel, Karan and R'e, Christopher},
booktitle={The International Conference on Learning Representations ({ICLR})},
year={2022}
}