该存储库为与S4相关的模型提供了官方实现和实验,包括河马,LSSL,Sashimi,DSS,HTTYH,S4D和S4ND。
这些模型中的每个模型的特定于项目的信息,包括源代码的概述和特定的实验复制品,可以在模型/下找到。
设置环境并将S4移植到外部代码库:
将此存储库用于培训模型:
请参阅ChangElog.md
该存储库需要Python 3.9+和Pytorch 1.10+。它已被测试到Pytorch 1.13.1。其他软件包在unignts.txt中列出。可能需要一些小心来使某些库版本兼容,尤其是火炬/火炬/火炬/火炬手。
示例安装:
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install -r requirements.txt
S4的核心操作是论文中描述的Cauchy和Vandermonde内核。这些是非常简单的矩阵乘法;这些操作的幼稚实现可以在函数cauchy_naive和log_vandermonde_naive中的独立中找到。但是,正如本文所描述的那样,这具有次优的内存使用情况,目前需要一个自定义内核才能在Pytorch中克服。
支持两种更有效的方法。代码将自动检测是否已安装这些代码并调用适当的内核。
此版本更快,但需要为每个机器环境进行手动编译。从目录extensions/kernels/运行python setup.py install 。
此版本由Pykeops库提供。安装通常可以使用pip install pykeops cmake进行操作。
S4层和变体的自包含文件可以在模型/S4/中找到,其中包括调用模块的说明。
有关解释河马和S4背后的一些概念的可视化,请参见笔记本/。
example.py是MNIST和CIFAR的一个独立培训脚本,可导入独立的S4文件。默认设置python example.py具有200K参数的非常简单的S4D模型在顺序CIFAR上达到88%的精度。该脚本可以用作在外部存储库中使用S4变体的示例。
该存储库旨在为训练序列模型提供一个非常灵活的框架。支持许多模型和数据集。
基本入口点是python -m train或等效的
python -m train pipeline=mnist model=s4
它在排列的MNIST数据集上训练S4模型。这应该在1个时期后达到90%左右,取决于GPU,需要1-3分钟。
在整个过程中都记录了更多使用此存储库的示例。请参阅培训以获取概述。
该代码库的一个重要特征是支持需要不同优化器超级法的参数。特别是,SSM内核特别敏感
有关如何在外部存储库中实现此功能的示例,请参见模型(例如S4D.PY)和训练脚本(例如示例)中的函数setup_optimizer中的方法register 。
该存储库的核心训练基础架构是基于基于Hydra的配置方案的Pytorch-Lighting。
主入口点是train.py和配置在configs/中找到。
基本数据集自动下载,包括MNIST,CIFAR和语音命令。创建和加载数据集的所有逻辑都在SRC/DataLoaders目录中。此子目录文档中的读数如何下载和组织其他数据集。
模型是在SRC/模型中定义的。有关概述,请参见此子目录中的读数。
提供了从论文中复制端到端实验的预定义配置,根据模型/(例如原始S4纸)的项目特定信息找到。
配置也可以通过命令行轻松修改。一个示例实验是
python -m train pipeline=mnist dataset.permute=True model=s4 model.n_layers=3 model.d_model=128 model.norm=batch model.prenorm=True wandb=null
这将使用置换的MNIST任务,其中具有指定数量的图层,骨干尺寸和归一化类型的S4模型。
有关配置的更多详细文档,请参见Configs/Readme.md。
建议阅读Hydra文档以充分了解配置框架。为了启动特定实验,请提出问题。
每个实验将记录到形式的形式的目录(由hydra生成) ./outputs/<date>/<time>/ <date>/<Time>/。检查点将在此文件夹中保存在此文件夹中,并在创建新的检查点时将其打印到主机。要恢复培训,只需指向所需的.ckpt文件(pytorch Lightning检查点,例如,例如./outputs/<date>/<time>/checkpoints/val/loss.ckpt outputs/<date>/<time> time>/checkpoints/val/loss.ckpt),然后附加flag train.ckpt=<path>/<to>/<checkpoint>.ckpt .ckpt到原始培训命令。
PTL培训师类控制整体训练循环,还提供了许多有用的预定标志。下面说明了一些有用的示例。允许标志的完整列表可在PTL文档以及我们的教练配置中找到。有关最有用的选项,请参见默认教练配置/Trainer/Default.yaml。
只需传递trainer.gpus=2即可用2 GPU训练。
trainer.weights_summary=full打印模型的每一层都具有其参数计数。可用于调试模型内部。
trainer.limit_{train,val}_batches={10,0.1}火车(验证)仅10批(所有批次的0.1分)。对于在不浏览所有数据的情况下测试火车循环时有用。
使用WandB登录位于此存储库中。为了使用此功能,只需设置您的WANDB_API_KEY环境变量,然后更改wandb.project属性/configs/config.yaml(或在命令行上传递它,例如python -m train .... wandb.project=s4 )。
设置wandb=null以关闭wandb记录。
可以使用generate.py脚本执行自回归生成。使用此代码库训练模型后,可以以两种方式使用此脚本。
更灵活的选项需要训练有素的Pytorch Lightning模型的检查点路径。该Generation脚本接受与火车脚本相同的配置选项,并在configs/generate.yaml中进行了一些其他标志。在使用python -m train <train flags>培训后,生成
python -m generate <train flags> checkpoint_path=<path/to/model.ckpt> <generation flags>
配置中发现的任何标志都可以覆盖。
注意:此选项可以与.ckpt检查点(Pytorch Lightning,包括培训师的信息)或.pt检查点(Pytorch,Pytorch,这只是模型状态命令)。
生成的第二种选择不需要再次传递训练标志,而是从Hydra实验文件夹中读取配置以及实验文件夹中的Pytorch Lightning检查点。
下载Wikitext-103型号检查点,例如./checkpoints/s4-wt103.pt该模型通过命令python -m train experiment=lm/s4-wt103进行了训练。请注意,从配置中,我们可以看到该模型是用长度为8192的接受场训练的。
要生成,请运行
python -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=16384 l_prefix=8192 decode=text
这会生成长度16384的样本,该样本在长度为8192的前缀上。
让我们在SC09数据集上训练一个小型生鱼片模型。我们还可以减少培训和验证批次的数量以更快地获得检查站:
python -m train experiment=audio/sashimi-sc09 model.n_layers=2 trainer.limit_train_batches=0.1 trainer.limit_val_batches=0.1
第一个时期完成后,打印一条消息,指示保存检查点的位置。
Epoch 0, global step 96: val/loss reached 3.71754 (best 3.71754), saving model to "<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt"
选项1:
python -m generate experiment=audio/sashimi-sc09 model.n_layers=2 checkpoint_path=<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
此选项重新定义完整的配置,以便可以构建模型和数据集。
选项2:
python -m generate experiment_path=<repository>/outputs/<date>/<time> checkpoint_path=checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
此选项只需要通往Hydra实验文件夹和所需检查点的路径。
configs/ Config files for model, data pipeline, training loop, etc.
data/ Default location of raw data
extensions/ CUDA extensions (Cauchy and Vandermonde kernels)
src/ Main source code for models, datasets, etc.
callbacks/ Training loop utilities (e.g. checkpointing)
dataloaders/ Dataset and dataloader definitions
models/ Model definitions
tasks/ Encoder/decoder modules to interface between data and model backbone
utils/
models/ Model-specific information (code, experiments, additional resources)
example.py Example training script for using S4 externally
train.py Training entrypoint for this repo
generate.py Autoregressive generation script
如果您使用此代码库,或者以其他方式发现我们的工作很有价值,请引用S4和其他相关论文。
@inproceedings{gu2022efficiently,
title={Efficiently Modeling Long Sequences with Structured State Spaces},
author={Gu, Albert and Goel, Karan and R'e, Christopher},
booktitle={The International Conference on Learning Representations ({ICLR})},
year={2022}
}