Dieses Repository bietet die offiziellen Implementierungen und Experimente für Modelle im Zusammenhang mit S4, einschließlich Hippo, LSSL, Sashimi, DSS, Httyh, S4D und S4nd.
Projektspezifische Informationen für jedes dieser Modelle, einschließlich Überblick über den Quellcode und spezifische Experiment-Reproduktionen, finden Sie unter Models/.
Einrichten der Umgebung und Portierung von S4 in externen Codebasen:
Verwenden Sie dieses Repository für Trainingsmodelle:
Siehe ChangeLog.md
Dieses Repository erfordert Python 3.9+ und Pytorch 1.10+. Es wurde bis zu Pytorch 1.13.1 getestet. Andere Pakete sind in Anforderungen aufgeführt. Möglicherweise sind einige Sorgfalt erforderlich, um einige der Bibliotheksversionen kompatibel zu machen, insbesondere Torch/Torchvision/Torchaudio/Torchtext.
Beispiel Installation:
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install -r requirements.txt
Ein Kernbetrieb von S4 sind die im Papier beschriebenen Cauchy- und Vandermonde -Kerne. Dies sind sehr einfache Matrixmultiplikationen; Eine naive Implementierung dieser Operation finden Sie im Standalone in der Funktion cauchy_naive und log_vandermonde_naive . Wie das Papier beschreibt, hat dies jedoch eine suboptimale Speicherverwendung, für die derzeit ein benutzerdefinierter Kernel in Pytorch überwunden werden muss.
Zwei effizientere Methoden werden unterstützt. Der Code erkennt automatisch, ob eine dieser diese installiert ist, und rufen Sie den entsprechenden Kernel an.
Diese Version ist schneller, erfordert jedoch eine manuelle Zusammenstellung für jede Maschinenumgebung. Führen Sie python setup.py install aus den extensions/kernels/ .
Diese Version wird von der Pykeops Library bereitgestellt. Die Installation funktioniert normalerweise nicht mit pip install pykeops cmake , die auch in der Anforderungsdatei aufgeführt sind.
In sich geschlossene Dateien für die S4-Ebene und Varianten finden Sie in Modellen/S4/, die Anweisungen zum Aufrufen des Moduls enthalten.
Siehe Notebooks/ Für Visualisierungen, die einige Konzepte hinter Hippo und S4 erklären.
Beispiel.py ist ein in sich geschlossenes Trainingsskript für MNIST und CIFAR, das die eigenständige S4-Datei importiert. Die Standardeinstellungen python example.py erreicht 88% Genauigkeit auf sequentiellem Cifar mit einem sehr einfachen S4D -Modell von 200K -Parametern. Dieses Skript kann als Beispiel für die Verwendung von S4 -Varianten in externen Repositories verwendet werden.
Dieses Repository zielt darauf ab, einen sehr flexiblen Rahmen für Trainingssequenzmodelle zu bieten. Viele Modelle und Datensätze werden unterstützt.
Der grundlegende Einstiegspunkt ist python -m train oder gleichwertig
python -m train pipeline=mnist model=s4
Das trainiert ein S4 -Modell auf dem permutierten MNIST -Datensatz. Dies sollte nach 1 Epoche rund 90% erreichen, was je nach GPU 1-3 Minuten dauert.
Weitere Beispiele für die Verwendung dieses Repositorys sind überall dokumentiert. Siehe Training für einen Überblick.
Ein wichtiges Merkmal dieser Codebasis ist die Unterstützung von Parametern, die unterschiedliche Optimierer -Hyperparameter erfordern. Insbesondere der SSM -Kernel ist besonders empfindlich gegenüber dem
Weitere Beispiele finden Sie im register im Modell (z. B. S4d.py) und im setup_optimizer der Funktionskript (z. B. Beispiel.PY), um Beispiele für die Implementierung dieser in externen Repos zu implementieren.
Die Kerntrainingsinfrastruktur dieses Repositorys basiert auf Pytorch-Lightning mit einem Konfigurationsschema, das auf Hydra basiert.
Der Haupteinstiegspunkt ist train.py und Konfigurationen finden Sie in configs/ .
Grundlegende Datensätze werden automatisch heruntergeladen, einschließlich MNIST-, CIFAR- und Sprachbefehle. Alle Logik zum Erstellen und Laden von Datensätzen finden Sie im Verzeichnis SRC/Dataloader. Das Readme in diesem Unterverzeichnis dokumentiert, wie Sie andere Datensätze herunterladen und organisieren.
Modelle sind in SRC/Modellen definiert. Eine Übersicht finden Sie in der Readme in diesem Unterverzeichnis.
Vordefinierte Konfigurationen reproduzieren End-to-End-Experimente aus den Papieren, die unter projektspezifischen Informationen in Modellen/für das ursprüngliche S4-Papier zu finden sind.
Konfigurationen können auch einfach über die Befehlszeile geändert werden. Ein Beispielexperiment ist
python -m train pipeline=mnist dataset.permute=True model=s4 model.n_layers=3 model.d_model=128 model.norm=batch model.prenorm=True wandb=null
Dies verwendet die permutierte MNIST -Aufgabe mit einem S4 -Modell mit einer bestimmten Anzahl von Schichten, Backbone -Dimension und Normalisierungstyp.
Weitere detailliertere Dokumentation zu den Konfigurationen finden Sie in Configs/Readme.md.
Es wird empfohlen, die Hydra -Dokumentation zu lesen, um das Konfigurationsframework vollständig zu verstehen. Für Hilfe bei der Einführung bestimmter Experimente stellen Sie bitte ein Problem ein.
Jedes Experiment wird an seinem eigenen Verzeichnis (von Hydra generiert) des Formulars protokolliert ./outputs/<date>/<time>/ <Date>/<Time>/. Checkpoints werden hier in diesem Ordner gespeichert und in der Konsole gedruckt, wenn ein neuer Kontrollpunkt erstellt wird. Um das Training wieder aufzunehmen, verweisen Sie einfach auf die gewünschte .ckpt -Datei (ein Pytorch Lightning Checkpoint, z ./outputs/<date>/<time>/checkpoints/val/loss.ckpt <Time>/checkpoints/val/loss.ckpt) und fügen Sie den Flag train.ckpt=<path>/<to>/<checkpoint>.ckpt an den Original -Trainingsbefehl hinzu.
Die PTL Trainer-Klasse kontrolliert die Gesamttrainingsschleife und bietet auch viele nützliche vordefinierte Flags. Einige nützliche Beispiele werden unten erläutert. Die vollständige Liste der zulässigen Flags finden Sie in der PTL -Dokumentation sowie in unseren Trainerkonfigurationen. Die nützlichsten Optionen finden Sie im Standard -Trainer -Konfigurationskonfigurationen/Trainer/Standard.yaml.
Einfach in trainer.gpus=2 geben, um mit 2 GPUs zu trainieren.
trainer.weights_summary=full Nützlich für das Debuggen von Interna von Models.
trainer.limit_{train,val}_batches={10,0.1} Züge (validiert) für nur 10 Chargen (0,1 Bruchteil aller Chargen). Nützlich zum Testen der Zugschleife, ohne alle Daten zu durchlaufen.
Die Anmeldung mit Wandb ist in dieses Repository integriert. Um dies zu verwenden, setzen Sie einfach Ihre wandb.project WANDB_API_KEY und ändern Sie das Attribut von Configs/config.yaml (oder übergeben Sie es in der Befehlszeile python -m train .... wandb.project=s4 .
Stellen Sie wandb=null , um die Wandb -Protokollierung auszuschalten.
Die autoregressive Generation kann mit dem Skript generate.py durchgeführt werden. Dieses Skript kann nach dem Training eines Modells mit dieser Codebasis auf zwei Arten verwendet werden.
Die flexiblere Option erfordert den Checkpoint -Pfad des trainierten Pytorch -Blitzmodells. Das Generationsskript akzeptiert dieselben Konfigurationsoptionen wie das Zugskript mit einigen zusätzlichen Flags, die in configs/generate.yaml dokumentiert sind. Nach dem Training mit python -m train <train flags> erzeugen Sie mit
python -m generate <train flags> checkpoint_path=<path/to/model.ckpt> <generation flags>
Jedes der in der Konfiguration gefundenen Flags kann überschrieben werden.
HINWEIS: Diese Option kann .pt mit .ckpt
Die zweite Option für die Generation erfordert nicht erneut die Trainingsflags und liest stattdessen die Konfiguration aus dem Hydra -Experiment -Ordner zusammen mit einem Pytorch -Lightning -Checkpoint im Experiment -Ordner.
Laden Sie den Wikitext-103-Modell-Checkpoint herunter, zum Beispiel auf ./checkpoints/s4-wt103.pt . Dieses Modell wurde mit dem Kommando python -m train experiment=lm/s4-wt103 trainiert. Beachten Sie, dass wir aus der Konfiguration sehen können, dass das Modell mit einem Empfangsfeld von Länge 8192 trainiert wurde.
Zu generieren, rennen
python -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=16384 l_prefix=8192 decode=text
Dies erzeugt eine Probe mit einer Länge 16384, die auf einem Präfix der Länge 8192 konditioniert ist.
Lassen Sie uns ein kleines Sashimi -Modell im SC09 -Datensatz trainieren. Wir können auch die Anzahl der Trainings- und Validierungsstapel reduzieren, um einen Kontrollpunkt schneller zu erhalten:
python -m train experiment=audio/sashimi-sc09 model.n_layers=2 trainer.limit_train_batches=0.1 trainer.limit_val_batches=0.1
Nach Abschluss der ersten Epoche wird eine Nachricht gedruckt, die angibt, wo der Kontrollpunkt gespeichert ist.
Epoch 0, global step 96: val/loss reached 3.71754 (best 3.71754), saving model to "<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt"
Option 1:
python -m generate experiment=audio/sashimi-sc09 model.n_layers=2 checkpoint_path=<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
Diese Option definiert die vollständige Konfiguration neu, damit das Modell und der Datensatz konstruiert werden können.
Option 2:
python -m generate experiment_path=<repository>/outputs/<date>/<time> checkpoint_path=checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
Diese Option benötigt nur den Pfad zum Hydra Experiment -Ordner und den gewünschten Checkpoint im Inneren.
configs/ Config files for model, data pipeline, training loop, etc.
data/ Default location of raw data
extensions/ CUDA extensions (Cauchy and Vandermonde kernels)
src/ Main source code for models, datasets, etc.
callbacks/ Training loop utilities (e.g. checkpointing)
dataloaders/ Dataset and dataloader definitions
models/ Model definitions
tasks/ Encoder/decoder modules to interface between data and model backbone
utils/
models/ Model-specific information (code, experiments, additional resources)
example.py Example training script for using S4 externally
train.py Training entrypoint for this repo
generate.py Autoregressive generation script
Wenn Sie diese Codebasis verwenden oder unsere Arbeit auf andere Weise wertvoll gefunden haben, zitieren Sie bitte S4 und andere relevante Artikel.
@inproceedings{gu2022efficiently,
title={Efficiently Modeling Long Sequences with Structured State Spaces},
author={Gu, Albert and Goel, Karan and R'e, Christopher},
booktitle={The International Conference on Learning Representations ({ICLR})},
year={2022}
}