Este repositório fornece as implementações e experimentos oficiais para modelos relacionados ao S4, incluindo Hippo, LSSL, Sashimi, DSS, Httyh, S4D e S4nd.
Informações específicas do projeto para cada um desses modelos, incluindo a visão geral do código-fonte e as reproduções específicas de experimentos, podem ser encontradas em modelos/.
Configurando o ambiente e portar S4 para bases de código externas:
Usando este repositório para modelos de treinamento:
Veja Changelog.md
Este repositório requer Python 3.9+ e Pytorch 1.10+. Foi testado até Pytorch 1.13.1. Outros pacotes estão listados no requisitos.txt. Alguns cuidados podem ser necessários para tornar algumas das versões da biblioteca compatíveis, principalmente a tocha/Torchvision/Torchaudio/TorchText.
Exemplo de instalação:
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install -r requirements.txt
Uma operação central do S4 são os grãos Cauchy e Vandermonde descritos no artigo. Estas são multiplicações de matriz muito simples; Uma implementação ingênua dessa operação pode ser encontrada na função cauchy_naive e log_vandermonde_naive . No entanto, como o artigo descreve, isso possui uso de memória abaixo do ideal que atualmente exige que um kernel personalizado seja superado em Pytorch.
Dois métodos mais eficientes são suportados. O código detectará automaticamente se um deles estiver instalado e chamar o kernel apropriado.
Esta versão é mais rápida, mas requer compilação manual para cada ambiente de máquina. Execute python setup.py install nas extensions/kernels/ .
Esta versão é fornecida pela Biblioteca Pykeops. A instalação geralmente funciona fora da caixa com pip install pykeops cmake , que também estão listados no arquivo de requisitos.
Arquivos independentes para a camada S4 e variantes podem ser encontrados nos modelos/S4/, que incluem instruções para chamar o módulo.
Veja notebooks/ para visualizações explicando alguns conceitos por trás de Hippo e S4.
Exemplo.py é um script de treinamento independente para MNIST e CIFAR que importa o arquivo S4 independente. As configurações padrão python example.py Esse script pode ser usado como exemplo para usar variantes S4 em repositórios externos.
Este repositório visa fornecer uma estrutura muito flexível para modelos de sequência de treinamento. Muitos modelos e conjuntos de dados são suportados.
O ponto de entrada básico é python -m train , ou equivalentemente
python -m train pipeline=mnist model=s4
que treina um modelo S4 no conjunto de dados MNIST permutado. Isso deve chegar a cerca de 90% após 1 época, que leva de 1 a 3 minutos, dependendo da GPU.
Mais exemplos de uso deste repositório estão documentados por toda parte. Veja o treinamento para uma visão geral.
Uma característica importante desta base de código é suportar parâmetros que requerem diferentes hiperparâmetros de otimizadores. Em particular, o kernel do SSM é particularmente sensível ao
Consulte o register do método no modelo (por exemplo, s4d.py) e a função setup_optimizer no script de treinamento (por exemplo, exemplo.py) para obter exemplos de como implementar isso em repositórios externos.
A infraestrutura principal de treinamento deste repositório é baseada no pytorch-Lightning com um esquema de configuração baseado no HYDRA.
O ponto de entrada principal é train.py e as configurações são encontradas nas configs/ .
Os conjuntos de dados básicos são baixados automaticamente, incluindo comandos MNIST, CIFAR e SOEEN. Toda a lógica para criar e carregar conjuntos de dados está no diretório SRC/Dataloaders. O ReadMe dentro deste subdiretório documenta como baixar e organizar outros conjuntos de dados.
Os modelos são definidos em SRC/modelos. Veja o ReadMe neste subdiretório para uma visão geral.
São fornecidas configurações predefinidas que reproduzem experimentos de ponta a ponta dos artigos, encontrados sob informações específicas do projeto em modelos/, como para o papel S4 original.
As configurações também podem ser facilmente modificadas através da linha de comando. Um exemplo de experimento é
python -m train pipeline=mnist dataset.permute=True model=s4 model.n_layers=3 model.d_model=128 model.norm=batch model.prenorm=True wandb=null
Isso usa a tarefa MNIST permutada com um modelo S4 com um número especificado de camadas, dimensão do backbone e tipo de normalização.
Consulte Configs/Readme.md para obter uma documentação mais detalhada sobre as configurações.
Recomenda -se ler a documentação do Hydra para entender completamente a estrutura de configuração. Para obter ajuda para lançar experimentos específicos, registre um problema.
Cada experimento será registrado em seu próprio diretório (gerado pela hidra) do formulário ./outputs/<date>/<time>/ <date>/<time>/. Os pontos de verificação serão salvos aqui dentro desta pasta e impressos no console sempre que um novo ponto de verificação for criado. Para retomar o treinamento, basta apontar para o arquivo .ckpt desejado (um ponto de verificação de Lightning Pytorch, por exemplo ./outputs/<date>/<time>/checkpoints/val/loss.ckpt /<time>/checkpoints/val/loss.ckpt) e apesse -se à bandeira train.ckpt=<path>/<to>/<checkpoint>.ckpt .ckpt à comando de treinamento original.
A classe de treinador PTL controla o loop geral de treinamento e também fornece muitos sinalizadores predefinidos úteis. Alguns exemplos úteis são explicados abaixo. A lista completa de sinalizadores permitidos pode ser encontrada na documentação do PTL, bem como em nossas configurações de treinador. Consulte a configuração do treinador padrão Configs/treinador/default.yaml para obter as opções mais úteis.
Basta passar no trainer.gpus=2 para treinar com 2 GPUs.
trainer.weights_summary=full todas as camadas do modelo com suas contagens de parâmetros. Útil para depurar internos de modelos.
trainer.limit_{train,val}_batches={10,0.1} trens (valida) em apenas 10 lotes (0,1 fração de todos os lotes). Útil para testar o loop do trem sem passar por todos os dados.
O registro com WandB é incorporado neste repositório. Para usá -lo, basta definir sua variável WANDB_API_KEY AIMBORAL e alterar o atributo wandb.project de configs/config.yaml (ou passe -o na linha de comando, por exemplo, o python -m train .... wandb.project=s4 ).
Definir wandb=null para desativar o log wandb.
A geração autoregressiva pode ser realizada com o script generate.py. Esse script pode ser usado de duas maneiras depois de treinar um modelo usando esta base de código.
A opção mais flexível requer o caminho do ponto de verificação do modelo de Lightning Pytorch treinado. O script de geração aceita as mesmas opções de configuração que o script de trem, com alguns sinalizadores adicionais documentados em configurações/geneate.yaml. Depois de treinar com python -m train <train flags> , gerar com
python -m generate <train flags> checkpoint_path=<path/to/model.ckpt> <generation flags>
Qualquer uma das bandeiras encontradas na configuração pode ser substituída.
Nota: Esta opção pode ser usada com pontos de verificação .ckpt (Pytorch Lightning, que inclui informações para o treinador) ou pontos de verificação .pt (Pytorch, que é apenas um ditado de estado modelo).
A segunda opção para geração não requer a passagem nos sinalizadores de treinamento novamente e, em vez disso, lê a configuração da pasta Hydra Experiment, juntamente com um ponto de verificação de Lightning Pytorch na pasta Experimento.
Faça o download do ponto de verificação do modelo Wikitext-103, por exemplo, para ./checkpoints/s4-wt103.pt . Este modelo foi treinado com o comando python -m train experiment=lm/s4-wt103 . Observe que, na configuração, podemos ver que o modelo foi treinado com um campo receptivo de comprimento 8192.
Para gerar, execute
python -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=16384 l_prefix=8192 decode=text
Isso gera uma amostra de comprimento 16384 condicionada em um prefixo de comprimento 8192.
Vamos treinar um pequeno modelo Sashimi no conjunto de dados SC09. Também podemos reduzir o número de lotes de treinamento e validação para obter um ponto de verificação mais rápido:
python -m train experiment=audio/sashimi-sc09 model.n_layers=2 trainer.limit_train_batches=0.1 trainer.limit_val_batches=0.1
Após a conclusão da primeira época, uma mensagem é impressa, indicando onde o ponto de verificação é salvo.
Epoch 0, global step 96: val/loss reached 3.71754 (best 3.71754), saving model to "<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt"
Opção 1:
python -m generate experiment=audio/sashimi-sc09 model.n_layers=2 checkpoint_path=<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
Esta opção redefine a configuração completa para que o modelo e o conjunto de dados possam ser construídos.
Opção 2:
python -m generate experiment_path=<repository>/outputs/<date>/<time> checkpoint_path=checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
Esta opção precisa apenas do caminho para a pasta de experimentos Hydra e o ponto de verificação desejado dentro.
configs/ Config files for model, data pipeline, training loop, etc.
data/ Default location of raw data
extensions/ CUDA extensions (Cauchy and Vandermonde kernels)
src/ Main source code for models, datasets, etc.
callbacks/ Training loop utilities (e.g. checkpointing)
dataloaders/ Dataset and dataloader definitions
models/ Model definitions
tasks/ Encoder/decoder modules to interface between data and model backbone
utils/
models/ Model-specific information (code, experiments, additional resources)
example.py Example training script for using S4 externally
train.py Training entrypoint for this repo
generate.py Autoregressive generation script
Se você usar esta base de código ou achar nosso trabalho valioso, cite o S4 e outros artigos relevantes.
@inproceedings{gu2022efficiently,
title={Efficiently Modeling Long Sequences with Structured State Spaces},
author={Gu, Albert and Goel, Karan and R'e, Christopher},
booktitle={The International Conference on Learning Representations ({ICLR})},
year={2022}
}