Ce référentiel fournit les implémentations et expériences officielles pour les modèles liés à S4, notamment Hippo, LSSL, Sashimi, DSS, HTTYH, S4D et S4ND.
Des informations spécifiques au projet pour chacun de ces modèles, y compris un aperçu du code source et des reproductions spécifiques de l'expérience, peuvent être trouvées sous les modèles /.
Configuration de l'environnement et portage S4 vers des bases de code externes:
Utilisation de ce référentiel pour les modèles de formation:
Voir Changelog.md
Ce référentiel nécessite Python 3.9+ et Pytorch 1.10+. Il a été testé à Pytorch 1.13.1. D'autres packages sont répertoriés dans les exigences.txt. Certains soins peuvent être nécessaires pour rendre certaines des versions de la bibliothèque compatibles, en particulier Torch / TorchVision / Torchaudio / TorchText.
Exemple d'installation:
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install -r requirements.txt
Une opération centrale de S4 est les grains Cauchy et Vanderonde décrits dans l'article. Ce sont des multiplications matricielles très simples; Une implémentation naïve de ces opérations peut être trouvée dans la fonction de la fonction cauchy_naive et log_vandermonde_naive . Cependant, comme le décrit l'article, cela a une utilisation sous-optimale de la mémoire qui nécessite actuellement un noyau personnalisé pour surmonter à Pytorch.
Deux méthodes plus efficaces sont prises en charge. Le code détectera automatiquement si l'un ou l'autre de ceux-ci est installé et appelle le noyau approprié.
Cette version est plus rapide mais nécessite une compilation manuelle pour chaque environnement de machine. Exécutez python setup.py install à partir des extensions/kernels/ .
Cette version est fournie par la bibliothèque Pykeops. L'installation fonctionne généralement hors de la boîte avec pip install pykeops cmake qui sont également répertoriés dans le fichier d'exigences.
Des fichiers autonomes pour la couche S4 et des variantes peuvent être trouvés dans Models / S4 /, qui comprend des instructions pour appeler le module.
Voir les cahiers / pour les visualisations expliquant certains concepts derrière Hippo et S4.
Exemple.py est un script de formation autonome pour MNIST et CIFAR qui importe le fichier S4 autonome. Les paramètres par défaut python example.py atteint une précision de 88% sur le CIFAR séquentiel avec un modèle S4D très simple de paramètres 200K. Ce script peut être utilisé comme exemple d'utilisation des variantes S4 dans des référentiels externes.
Ce référentiel vise à fournir un cadre très flexible pour les modèles de séquence de formation. De nombreux modèles et ensembles de données sont pris en charge.
Le point d'entrée de base est python -m train , ou de manière équivalente
python -m train pipeline=mnist model=s4
qui forme un modèle S4 sur l'ensemble de données MNIST permuté. Cela devrait atteindre environ 90% après 1 époque, ce qui prend 1 à 3 minutes selon le GPU.
Plus d'exemples d'utilisation de ce référentiel sont documentés partout. Voir la formation pour un aperçu.
Une caractéristique importante de cette base de code consiste à prendre en charge les paramètres qui nécessitent différents hyperparamètres d'optimiseur. En particulier, le noyau SSM est particulièrement sensible au
Voir la méthode register dans le modèle (par exemple S4D.py) et la fonction setup_optimizer dans le script d'entraînement (par exemple Exemple.py) pour un exemple de comment implémenter cela dans des références externes.
L'infrastructure de formation principale de ce référentiel est basée sur la lumière de Pytorch avec un schéma de configuration basé sur HYDRA.
Le point d'entrée principal est train.py et les configurations se trouvent dans configs/ .
Les ensembles de données de base sont téléchargés automatiquement, y compris les commandes MNIST, CIFAR et Speech. Toutes les ensembles de données de création et de chargement se trouvent dans le répertoire SRC / DatalOaders. La lecture à l'intérieur de ce sous-répertoire documente comment télécharger et organiser d'autres ensembles de données.
Les modèles sont définis dans SRC / modèles. Voir la lecture de ce sous-répertoire pour un aperçu.
Des configurations prédéfinies reproduisant des expériences de bout en bout des articles sont fournies, trouvées sous des informations spécifiques au projet dans les modèles /, comme pour le papier S4 d'origine.
Les configurations peuvent également être facilement modifiées via la ligne de commande. Un exemple d'expérience est
python -m train pipeline=mnist dataset.permute=True model=s4 model.n_layers=3 model.d_model=128 model.norm=batch model.prenorm=True wandb=null
Cela utilise la tâche MNIST permutée avec un modèle S4 avec un nombre spécifié de couches, de dimension de squelette et de type de normalisation.
Voir configs / readme.md pour une documentation plus détaillée sur les configurations.
Il est recommandé de lire la documentation HYDRA pour bien comprendre le cadre de configuration. Pour aider à lancer des expériences spécifiques, veuillez déposer un problème.
Chaque expérience sera enregistrée à son propre répertoire (généré par Hydra) du formulaire ./outputs/<date>/<time>/ <time>/. Les points de contrôle seront enregistrés ici à l'intérieur de ce dossier et imprimés sur console chaque fois qu'un nouveau point de contrôle est créé. Pour reprendre la formation, indiquez simplement le fichier .ckpt souhaité (un point de contrôle Pytorch Lightning, par exemple ./outputs/<date>/<time>/checkpoints/val/loss.ckpt /<time>/Checkpoints/val/loss.ckpt) et appelez le drapeau train.ckpt=<path>/<to>/<checkpoint>.ckpt à la commande de formation originale.
La classe PTL Trainer contrôle la boucle de formation globale et fournit également de nombreux drapeaux prédéfinis utiles. Certains exemples utiles sont expliqués ci-dessous. La liste complète des drapeaux admissibles peut être trouvée dans la documentation PTL, ainsi que dans nos configurations de formateur. Voir la configuration de configuration du formateur par défaut / Trainer / Default.yaml pour les options les plus utiles.
Passez simplement dans trainer.gpus=2 pour s'entraîner avec 2 GPU.
trainer.weights_summary=full chaque couche du modèle avec leur nombre de paramètres. Utile pour déboguer les internes de modèles.
trainer.limit_{train,val}_batches={10,0.1} trains (validés) sur seulement 10 lots (0,1 fraction de tous les lots). Utile pour tester la boucle de train sans passer par toutes les données.
La connexion avec WANDB est intégrée à ce référentiel. Pour l'utiliser, définissez simplement votre variable d'environnement WANDB_API_KEY et modifiez l'attribut wandb.project de configs / config.yaml (ou passez-le sur la ligne de commande par exemple python -m train .... wandb.project=s4 ).
Définissez wandb=null pour désactiver la journalisation WANDB.
La génération autorégressive peut être effectuée avec le script generate.py. Ce script peut être utilisé de deux manières après avoir entraîné un modèle en utilisant cette base de code.
L'option la plus flexible nécessite le chemin de point de contrôle du modèle de foudre Pytorch formé. Le script de génération accepte les mêmes options de configuration que le script de train, avec quelques indicateurs supplémentaires documentés dans configs / generate.yaml. Après l'entraînement avec python -m train <train flags> , générer avec
python -m generate <train flags> checkpoint_path=<path/to/model.ckpt> <generation flags>
Tous les drapeaux trouvés dans la configuration peuvent être remplacés.
Remarque: Cette option peut être utilisée avec des points de contrôle .ckpt (Pytorch Lightning, qui comprend des informations pour le formateur) ou des points de contrôle .pt (Pytorch, qui n'est qu'un dict d'état modèle).
La deuxième option pour la génération ne nécessite pas de passage à nouveau dans les drapeaux d'entraînement et lit à la place la configuration du dossier de l'expérience HYDRA, ainsi qu'un point de contrôle Pytorch Lightning dans le dossier de l'expérience.
Téléchargez le point de contrôle du modèle wikitext-103, par exemple sur ./checkpoints/s4-wt103.pt . Ce modèle a été formé avec la commande python -m train experiment=lm/s4-wt103 . Notez que d'après la configuration, nous pouvons voir que le modèle a été formé avec un champ réceptif de longueur 8192.
Pour générer, courir
python -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=16384 l_prefix=8192 decode=text
Cela génère un échantillon de longueur 16384 conditionné sur un préfixe de longueur 8192.
Formons un petit modèle de sashimi sur l'ensemble de données SC09. Nous pouvons également réduire le nombre de lots de formation et de validation pour obtenir un point de contrôle plus rapidement:
python -m train experiment=audio/sashimi-sc09 model.n_layers=2 trainer.limit_train_batches=0.1 trainer.limit_val_batches=0.1
Une fois la première époque terminée, un message est imprimé indiquant où le point de contrôle est enregistré.
Epoch 0, global step 96: val/loss reached 3.71754 (best 3.71754), saving model to "<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt"
Option 1:
python -m generate experiment=audio/sashimi-sc09 model.n_layers=2 checkpoint_path=<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
Cette option redéfinit la configuration complète afin que le modèle et l'ensemble de données puissent être construits.
Option 2:
python -m generate experiment_path=<repository>/outputs/<date>/<time> checkpoint_path=checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
Cette option n'a besoin que du chemin vers le dossier de l'expérience HYDRA et du point de contrôle souhaité à l'intérieur.
configs/ Config files for model, data pipeline, training loop, etc.
data/ Default location of raw data
extensions/ CUDA extensions (Cauchy and Vandermonde kernels)
src/ Main source code for models, datasets, etc.
callbacks/ Training loop utilities (e.g. checkpointing)
dataloaders/ Dataset and dataloader definitions
models/ Model definitions
tasks/ Encoder/decoder modules to interface between data and model backbone
utils/
models/ Model-specific information (code, experiments, additional resources)
example.py Example training script for using S4 externally
train.py Training entrypoint for this repo
generate.py Autoregressive generation script
Si vous utilisez cette base de code, ou si vous avez trouvé notre travail de valeur, veuillez citer S4 et d'autres articles pertinents.
@inproceedings{gu2022efficiently,
title={Efficiently Modeling Long Sequences with Structured State Spaces},
author={Gu, Albert and Goel, Karan and R'e, Christopher},
booktitle={The International Conference on Learning Representations ({ICLR})},
year={2022}
}