このリポジトリは、Hippo、LSSL、Sashimi、DSS、Httyh、S4D、S4NDなど、S4に関連するモデルの公式実装と実験を提供します。
ソースコードの概要や特定の実験複製の概要を含む、これらの各モデルのプロジェクト固有の情報は、モデル/にあります。
環境をセットアップし、S4を外部コードベースに移植する:
トレーニングモデルにこのリポジトリを使用してください:
changelog.mdを参照してください
このリポジトリには、Python 3.9+およびPytorch 1.10+が必要です。 Pytorch 1.13.1までテストされています。その他のパッケージは、recumporation.txtにリストされています。ライブラリバージョンの一部、特にTorch/Torchvision/Torchaudio/TorchTextを互換性のあるものにするために、ある程度の注意が必要になる場合があります。
インストールの例:
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install -r requirements.txt
S4のコア操作は、論文で説明されているコーシーとバンデルモンドのカーネルです。これらは非常に単純なマトリックス乗算です。これらの操作の素朴な実装は、関数cauchy_naiveおよびlog_vandermonde_naiveのスタンドアロンにあります。ただし、論文が説明しているように、これは現在、Pytorchで克服するためにカスタムカーネルを必要とする最適ではないメモリ使用法を持っています。
さらに2つの効率的な方法がサポートされています。コードは、これらのいずれかがインストールされているかどうかを自動的に検出し、適切なカーネルを呼び出します。
このバージョンはより速いですが、各マシン環境の手動編集が必要です。ディレクトリextensions/kernels/からpython setup.py install実行します。
このバージョンは、Pykeops Libraryによって提供されます。通常、インストールは、要件ファイルにもリストされているpip install pykeops cmake箱から出して動作します。
S4レイヤーとバリアントの自己完結型ファイルは、モジュールを呼び出すための指示を含むモデル/S4/にあります。
HippoとS4の背後にあるいくつかの概念を説明する視覚化については、ノートブック/を参照してください。
example.pyは、スタンドアロンS4ファイルをインポートするMnistとCIFAR向けの自己完結型トレーニングスクリプトです。デフォルトの設定python example.pyは、200Kパラメーターの非常に単純なS4Dモデルで、シーケンシャルCIFARで88%の精度に達します。このスクリプトは、外部リポジトリでS4バリアントを使用するための例として使用できます。
このリポジトリは、トレーニングシーケンスモデルのための非常に柔軟なフレームワークを提供することを目的としています。多くのモデルとデータセットがサポートされています。
基本的なエントリポイントはpython -m train 、または同等です
python -m train pipeline=mnist model=s4
これは、順序付けられたMNISTデータセットでS4モデルをトレーニングします。これにより、GPUに応じて1〜3分かかる1エポックの後、約90%に達するはずです。
このリポジトリを使用する例は、全体を通して文書化されています。概要については、トレーニングを参照してください。
このコードベースの重要な機能の1つは、異なるオプティマイザーハイパーパラメーターを必要とするパラメーターをサポートすることです。特に、SSMカーネルは特に
これを外部レポで実装する方法の例については、モデルのメソッドregister (s4d.pyなど)とトレーニングスクリプト(例:example.py)の関数setup_optimizer参照してください。
このリポジトリのコアトレーニングインフラストラクチャは、Hydraに基づく構成スキームを使用したPytorch-Lightningに基づいています。
メインエントリポイントはtrain.pyで、構成はconfigs/にあります。
基本的なデータセットは、MNIST、CIFAR、および音声コマンドを含む自動ダウンロードされています。データセットの作成とロードのすべてのロジックは、SRC/DataloAdersディレクトリにあります。このサブディレクトリ内のreadmeは、他のデータセットをダウンロードして整理する方法を文書化しています。
モデルはSRC/モデルで定義されています。概要については、このサブディレクトリのreadmeを参照してください。
元のS4ペーパーなどのモデル/モデルのプロジェクト固有の情報の下にある、論文からのエンドツーエンドの実験を再現する事前定義された構成が提供されます。
構成は、コマンドラインを介して簡単に変更することもできます。実験の例です
python -m train pipeline=mnist dataset.permute=True model=s4 model.n_layers=3 model.d_model=128 model.norm=batch model.prenorm=True wandb=null
これは、指定された数のレイヤー、バックボーン寸法、および正規化タイプを備えたS4モデルを使用した順なMNISTタスクを使用します。
構成に関する詳細なドキュメントについては、configs/readme.mdを参照してください。
構成フレームワークを完全に理解するために、HYDRAドキュメントを読むことをお勧めします。特定の実験を開始するのに役立ちますが、問題を提出してください。
各実験は、フォームの独自のディレクトリ(Hydraによって生成)にログインします./outputs/<date>/<time>/ <time>/。チェックポイントは、このフォルダー内でここに保存され、新しいチェックポイントが作成されるたびにコンソールに印刷されます。トレーニングを再開するには、希望する.ckptファイル(pytorch lightningチェックポイントなど、 ./outputs/<date>/<time>/checkpoints/val/loss.ckpt outputs/<date>//<time>//checkpoints/val/loss.ckpt)を指すだけで、flag train.ckpt=<path>/<to>/<checkpoint>.ckptポイント> .ckptを指定します。
PTLトレーナークラスは、全体的なトレーニングループを制御し、多くの有用な事前定義されたフラグを提供します。いくつかの有用な例を以下に説明します。許容フラグの完全なリストは、PTLドキュメントとトレーナーの構成に記載されています。最も有用なオプションについては、デフォルトのトレーナーConfigs/Trainer/Default.YAMLを参照してください。
trainer.gpus=2を渡すだけで、2つのGPUでトレーニングします。
trainer.weights_summary=fullは、パラメーターカウントを使用して、モデルのすべてのレイヤーを印刷します。モデルの内部のデバッグに役立ちます。
trainer.limit_{train,val}_batches={10,0.1}トレイン(検証)のみの10バッチ(すべてのバッチの0.1分率)。すべてのデータを通過せずにトレインループをテストするのに役立ちます。
WandBでのロギングは、このリポジトリに組み込まれています。これを使用するには、 WANDB_API_KEY環境変数を設定し、configs/config.yamlのwandb.project属性を変更するだけです(または、 python -m train .... wandb.project=s4 )コマンドラインに渡します)。
wandb=nullを設定して、wandbロギングをオフにします。
generate.pyスクリプトを使用して、自動網性生成を実行できます。このスクリプトは、このコードベースを使用してモデルをトレーニングした後、2つの方法で使用できます。
より柔軟なオプションには、訓練されたPytorch Lightningモデルのチェックポイントパスが必要です。 Generation Scriptは、configs/generate.yamlで文書化されたいくつかの追加フラグを使用して、トレインスクリプトと同じ構成オプションを受け入れます。 python -m train <train flags>でトレーニングした後、
python -m generate <train flags> checkpoint_path=<path/to/model.ckpt> <generation flags>
構成にあるフラグのいずれかは、オーバーライドできます。
注:このオプションは、 .ckptチェックポイント(トレーナーの情報を含むpytorch lightning)または.ptチェックポイント(Pytorch、これは単なるモデル状態DICT)のいずれかで使用できます。
生成の2番目のオプションは、再びフラグを再度渡す必要はなく、代わりに、実験フォルダー内のPytorch Lightningチェックポイントとともに、Hydra Experimentフォルダーの構成を読み取ります。
wikitext-103モデルチェックポイント、たとえば./checkpoints/s4-wt103.ptをダウンロードしてください。このモデルは、コマンドpython -m train experiment=lm/s4-wt103でトレーニングされました。構成から、モデルが長さ8192の受容フィールドでトレーニングされていることがわかります。
生成するには、実行します
python -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=16384 l_prefix=8192 decode=text
これにより、長さ8192のプレフィックスを条件付けられた長さ16384のサンプルが生成されます。
SC09データセットで小さな刺身モデルをトレーニングしましょう。また、トレーニングと検証バッチの数を減らして、チェックポイントをより速く取得することもできます。
python -m train experiment=audio/sashimi-sc09 model.n_layers=2 trainer.limit_train_batches=0.1 trainer.limit_val_batches=0.1
最初のエポックが完了した後、チェックポイントが保存される場所を示すメッセージが印刷されます。
Epoch 0, global step 96: val/loss reached 3.71754 (best 3.71754), saving model to "<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt"
オプション1:
python -m generate experiment=audio/sashimi-sc09 model.n_layers=2 checkpoint_path=<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
このオプションは、モデルとデータセットを構築できるように完全な構成を再定義します。
オプション2:
python -m generate experiment_path=<repository>/outputs/<date>/<time> checkpoint_path=checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
このオプションには、Hydra実験フォルダーへのパスと内部の目的のチェックポイントのみが必要です。
configs/ Config files for model, data pipeline, training loop, etc.
data/ Default location of raw data
extensions/ CUDA extensions (Cauchy and Vandermonde kernels)
src/ Main source code for models, datasets, etc.
callbacks/ Training loop utilities (e.g. checkpointing)
dataloaders/ Dataset and dataloader definitions
models/ Model definitions
tasks/ Encoder/decoder modules to interface between data and model backbone
utils/
models/ Model-specific information (code, experiments, additional resources)
example.py Example training script for using S4 externally
train.py Training entrypoint for this repo
generate.py Autoregressive generation script
このコードベースを使用している場合、またはその他の作業が価値があると判断した場合は、S4やその他の関連する論文を引用してください。
@inproceedings{gu2022efficiently,
title={Efficiently Modeling Long Sequences with Structured State Spaces},
author={Gu, Albert and Goel, Karan and R'e, Christopher},
booktitle={The International Conference on Learning Representations ({ICLR})},
year={2022}
}