該存儲庫為與S4相關的模型提供了官方實現和實驗,包括河馬,LSSL,Sashimi,DSS,HTTYH,S4D和S4ND。
這些模型中的每個模型的特定於項目的信息,包括源代碼的概述和特定的實驗複製品,可以在模型/下找到。
設置環境並將S4移植到外部代碼庫:
將此存儲庫用於培訓模型:
請參閱ChangElog.md
該存儲庫需要Python 3.9+和Pytorch 1.10+。它已被測試到Pytorch 1.13.1。其他軟件包在unignts.txt中列出。可能需要一些小心來使某些庫版本兼容,尤其是火炬/火炬/火炬/火炬手。
示例安裝:
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install -r requirements.txt
S4的核心操作是論文中描述的Cauchy和Vandermonde內核。這些是非常簡單的矩陣乘法;這些操作的幼稚實現可以在函數cauchy_naive和log_vandermonde_naive中的獨立中找到。但是,正如本文所描述的那樣,這具有次優的內存使用情況,目前需要一個自定義內核才能在Pytorch中克服。
支持兩種更有效的方法。代碼將自動檢測是否已安裝這些代碼並調用適當的內核。
此版本更快,但需要為每個機器環境進行手動編譯。從目錄extensions/kernels/運行python setup.py install 。
此版本由Pykeops庫提供。安裝通常可以使用pip install pykeops cmake進行操作。
S4層和變體的自包含文件可以在模型/S4/中找到,其中包括調用模塊的說明。
有關解釋河馬和S4背後的一些概念的可視化,請參見筆記本/。
example.py是MNIST和CIFAR的一個獨立培訓腳本,可導入獨立的S4文件。默認設置python example.py具有200K參數的非常簡單的S4D模型在順序CIFAR上達到88%的精度。該腳本可以用作在外部存儲庫中使用S4變體的示例。
該存儲庫旨在為訓練序列模型提供一個非常靈活的框架。支持許多模型和數據集。
基本入口點是python -m train或等效的
python -m train pipeline=mnist model=s4
它在排列的MNIST數據集上訓練S4模型。這應該在1個時期後達到90%左右,取決於GPU,需要1-3分鐘。
在整個過程中都記錄了更多使用此存儲庫的示例。請參閱培訓以獲取概述。
該代碼庫的一個重要特徵是支持需要不同優化器超級法的參數。特別是,SSM內核特別敏感
有關如何在外部存儲庫中實現此功能的示例,請參見模型(例如S4D.PY)和訓練腳本(例如示例)中的函數setup_optimizer中的方法register 。
該存儲庫的核心訓練基礎架構是基於基於Hydra的配置方案的Pytorch-Lighting。
主入口點是train.py和配置在configs/中找到。
基本數據集自動下載,包括MNIST,CIFAR和語音命令。創建和加載數據集的所有邏輯都在SRC/DataLoaders目錄中。此子目錄文檔中的讀數如何下載和組織其他數據集。
模型是在SRC/模型中定義的。有關概述,請參見此子目錄中的讀數。
提供了從論文中復制端到端實驗的預定義配置,根據模型/(例如原始S4紙)的項目特定信息找到。
配置也可以通過命令行輕鬆修改。一個示例實驗是
python -m train pipeline=mnist dataset.permute=True model=s4 model.n_layers=3 model.d_model=128 model.norm=batch model.prenorm=True wandb=null
這將使用置換的MNIST任務,其中具有指定數量的圖層,骨幹尺寸和歸一化類型的S4模型。
有關配置的更多詳細文檔,請參見Configs/Readme.md。
建議閱讀Hydra文檔以充分了解配置框架。為了啟動特定實驗,請提出問題。
每個實驗將記錄到形式的形式的目錄(由hydra生成) ./outputs/<date>/<time>/ <date>/<Time>/。檢查點將在此文件夾中保存在此文件夾中,並在創建新的檢查點時將其打印到主機。要恢復培訓,只需指向所需的.ckpt文件(pytorch Lightning檢查點,例如,例如./outputs/<date>/<time>/checkpoints/val/loss.ckpt outputs/<date>/<time> time>/checkpoints/val/loss.ckpt),然後附加flag train.ckpt=<path>/<to>/<checkpoint>.ckpt .ckpt到原始培訓命令。
PTL培訓師類控制整體訓練循環,還提供了許多有用的預定標誌。下面說明了一些有用的示例。允許標誌的完整列表可在PTL文檔以及我們的教練配置中找到。有關最有用的選項,請參見默認教練配置/Trainer/Default.yaml。
只需傳遞trainer.gpus=2即可用2 GPU訓練。
trainer.weights_summary=full打印模型的每一層都具有其參數計數。可用於調試模型內部。
trainer.limit_{train,val}_batches={10,0.1}火車(驗證)僅10批(所有批次的0.1分)。對於在不瀏覽所有數據的情況下測試火車循環時有用。
使用WandB登錄位於此存儲庫中。為了使用此功能,只需設置您的WANDB_API_KEY環境變量,然後更改wandb.project屬性/configs/config.yaml(或在命令行上傳遞它,例如python -m train .... wandb.project=s4 )。
設置wandb=null以關閉wandb記錄。
可以使用generate.py腳本執行自回歸生成。使用此代碼庫訓練模型後,可以以兩種方式使用此腳本。
更靈活的選項需要訓練有素的Pytorch Lightning模型的檢查點路徑。該Generation腳本接受與火車腳本相同的配置選項,並在configs/generate.yaml中進行了一些其他標誌。在使用python -m train <train flags>培訓後,生成
python -m generate <train flags> checkpoint_path=<path/to/model.ckpt> <generation flags>
配置中發現的任何標誌都可以覆蓋。
注意:此選項可以與.ckpt檢查點(Pytorch Lightning,包括培訓師的信息)或.pt檢查點(Pytorch,Pytorch,這只是模型狀態命令)。
生成的第二種選擇不需要再次傳遞訓練標誌,而是從Hydra實驗文件夾中讀取配置以及實驗文件夾中的Pytorch Lightning檢查點。
下載Wikitext-103型號檢查點,例如./checkpoints/s4-wt103.pt該模型通過命令python -m train experiment=lm/s4-wt103進行了訓練。請注意,從配置中,我們可以看到該模型是用長度為8192的接受場訓練的。
要生成,請運行
python -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=16384 l_prefix=8192 decode=text
這會生成長度16384的樣本,該樣本在長度為8192的前綴上。
讓我們在SC09數據集上訓練一個小型生魚片模型。我們還可以減少培訓和驗證批次的數量以更快地獲得檢查站:
python -m train experiment=audio/sashimi-sc09 model.n_layers=2 trainer.limit_train_batches=0.1 trainer.limit_val_batches=0.1
第一個時期完成後,打印一條消息,指示保存檢查點的位置。
Epoch 0, global step 96: val/loss reached 3.71754 (best 3.71754), saving model to "<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt"
選項1:
python -m generate experiment=audio/sashimi-sc09 model.n_layers=2 checkpoint_path=<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
此選項重新定義完整的配置,以便可以構建模型和數據集。
選項2:
python -m generate experiment_path=<repository>/outputs/<date>/<time> checkpoint_path=checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
此選項只需要通往Hydra實驗文件夾和所需檢查點的路徑。
configs/ Config files for model, data pipeline, training loop, etc.
data/ Default location of raw data
extensions/ CUDA extensions (Cauchy and Vandermonde kernels)
src/ Main source code for models, datasets, etc.
callbacks/ Training loop utilities (e.g. checkpointing)
dataloaders/ Dataset and dataloader definitions
models/ Model definitions
tasks/ Encoder/decoder modules to interface between data and model backbone
utils/
models/ Model-specific information (code, experiments, additional resources)
example.py Example training script for using S4 externally
train.py Training entrypoint for this repo
generate.py Autoregressive generation script
如果您使用此代碼庫,或者以其他方式發現我們的工作很有價值,請引用S4和其他相關論文。
@inproceedings{gu2022efficiently,
title={Efficiently Modeling Long Sequences with Structured State Spaces},
author={Gu, Albert and Goel, Karan and R'e, Christopher},
booktitle={The International Conference on Learning Representations ({ICLR})},
year={2022}
}