이 저장소는 Hippo, LSSL, Sashimi, DSS, Httyh, S4D 및 S4nd를 포함한 S4와 관련된 모델에 대한 공식 구현 및 실험을 제공합니다.
소스 코드의 개요 및 특정 실험 재생산을 포함한 각 모델에 대한 프로젝트 별 정보는 모델에서 찾을 수 있습니다.
환경 설정 및 S4 포팅 : 외부 코드베이스 :
교육 모델 에이 저장소 사용 :
changelog.md를 참조하십시오
이 저장소에는 Python 3.9+ 및 Pytorch 1.10+가 필요합니다. Pytorch 1.13.1까지 테스트되었습니다. 다른 패키지는 요구 사항에 나열되어 있습니다 .txt. 라이브러리 버전 중 일부, 특히 Torch/Torchvision/Torchaudio/Torchtext를 만들기 위해서는 일부주의가 필요할 수 있습니다.
예제 설치 :
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install -r requirements.txt
S4의 핵심 작동은 논문에 설명 된 Cauchy 및 Vandermonde 커널입니다. 이것들은 매우 간단한 매트릭스 곱셈입니다. 이러한 작업의 순진한 구현은 cauchy_naive 및 log_vandermonde_naive 기능의 독립형에서 찾을 수 있습니다. 그러나이 논문에서 설명한 바와 같이, 이것은 현재 Pytorch에서 극복하기 위해 사용자 정의 커널이 필요한 차선책의 메모리 사용이 있습니다.
두 가지 더 효율적인 방법이 지원됩니다. 코드는이 중 하나가 설치되어 있는지 자동으로 감지하고 적절한 커널을 호출합니다.
이 버전은 더 빠르지 만 각 기계 환경에 대한 수동 편집이 필요합니다. python setup.py install 디렉토리 extensions/kernels/ 에서 설치를 실행하십시오.
이 버전은 Pykeops 라이브러리에서 제공합니다. 설치는 일반적으로 pip install pykeops cmake 사용하여 요구 사항 파일에 나열되어 있습니다.
S4 레이어 및 변형에 대한 자체 포함 파일은 Models/S4/에서 찾을 수 있으며 여기에는 모듈 호출 지침이 포함되어 있습니다.
Hippo 및 S4 뒤에있는 몇 가지 개념을 설명하는 시각화는 노트북/를 참조하십시오.
예제. 기본 설정 python example.py 200k 매개 변수의 매우 간단한 S4D 모델로 순차적 CIFAR에서 88% 정확도에 도달합니다. 이 스크립트는 외부 리포지토리에서 S4 변형을 사용하는 예로 사용할 수 있습니다.
이 저장소는 훈련 시퀀스 모델을위한 매우 유연한 프레임 워크를 제공하는 것을 목표로합니다. 많은 모델과 데이터 세트가 지원됩니다.
기본 엔트리 포인트는 python -m train 또는 동등한 것입니다.
python -m train pipeline=mnist model=s4
순열 된 MNIST 데이터 세트에서 S4 모델을 훈련시킵니다. 이는 GPU에 따라 1-3 분이 걸리는 1 개의 에포크 후 약 90%가되어야합니다.
이 저장소 사용의 더 많은 예가 전체적으로 문서화되어 있습니다. 개요는 교육을 참조하십시오.
이 코드베이스의 중요한 특징 중 하나는 다른 최적화기 초 파라미터가 필요한 매개 변수를 지원하는 것입니다. 특히 SSM 커널은 특히
외부 리포지토에서이를 구현하는 방법에 대한 예제는 모델 (예 : S4D.Py)의 메소드 register (예 : S4D.Py)와 함수 setup_optimizer 참조하십시오.
이 저장소의 핵심 교육 인프라는 Hydra를 기반으로 한 구성 체계와 함께 Pytorch-lightning을 기반으로합니다.
메인 엔트리 포인트는 train.py 이며 구성은 configs/ 에서 찾을 수 있습니다.
MNIST, CIFAR 및 음성 명령을 포함한 기본 데이터 세트가 자동 다운로드됩니다. 데이터 세트를 작성하고로드하기위한 모든 논리는 SRC/Dataloaders 디렉토리에 있습니다. 이 하위 디렉토리 내부의 README는 다른 데이터 세트를 다운로드하고 구성하는 방법을 문서화합니다.
모델은 SRC/모델로 정의됩니다. 개요는이 하위 디렉토리의 readme를 참조하십시오.
사전 정의 된 구성 논문에서 엔드 투 엔드 실험을 재현하는 것은 원래 S4 용지와 같은 모델/의 프로젝트 별 정보에 따라 제공됩니다.
명령 줄을 통해 구성을 쉽게 수정할 수도 있습니다. 실험은 예입니다
python -m train pipeline=mnist dataset.permute=True model=s4 model.n_layers=3 model.d_model=128 model.norm=batch model.prenorm=True wandb=null
이것은 지정된 수의 레이어, 백본 치수 및 정규화 유형의 S4 모델과 함께 구분 된 MNIST 작업을 사용합니다.
구성에 대한 자세한 내용은 configs/readme.md를 참조하십시오.
구성 프레임 워크를 완전히 이해하기 위해 Hydra 문서를 읽는 것이 좋습니다. 특정 실험을 시작하는 데 도움이 되려면 문제를 제출하십시오.
각 실험은 ./outputs/<date>/<time>/ /<time>/ 형식의 자체 디렉토리 (Hydra에 의해 생성)에 기록됩니다. 검사 점은이 폴더 내부에 여기에 저장되고 새로운 체크 포인트가 생성 될 때마다 콘솔에 인쇄됩니다. 훈련을 재개하려면 원하는 .ckpt 파일 (Pytorch Lightning Checkpoint, 예 : ./outputs/<date>/<time>/checkpoints/val/loss.ckpt train.ckpt=<path>/<to>/<checkpoint>.ckpt 추가하십시오.
PTL 트레이너 클래스는 전체 교육 루프를 제어하고 유용한 미리 정의 된 많은 플래그를 제공합니다. 유용한 몇 가지 예는 아래에 설명되어 있습니다. 허용 가능한 플래그의 전체 목록은 PTL 문서와 트레이너 구성에서 찾을 수 있습니다. 가장 유용한 옵션은 기본 트레이너 구성 구성/트레이너/Default.yaml을 참조하십시오.
트레이너를 전달하기 만하면 2 gpus로 훈련하기 위해 trainer.gpus=2 를 통과하십시오.
trainer.weights_summary=full 매개 변수 카운트로 모델의 모든 계층을 인쇄합니다. 모델의 내부를 디버깅하는 데 유용합니다.
trainer.limit_{train,val}_batches={10,0.1} 10 개 배치 (모든 배치의 0.1 분율)에만 열차 (검증). 모든 데이터를 통과하지 않고도 열차 루프를 테스트하는 데 유용합니다.
WANDB 로의 로깅은이 저장소에 내장되어 있습니다. 이것을 사용하려면 WANDB_API_KEY 환경 변수를 설정하고 configs/config.yaml의 wandb.project 속성을 변경하십시오 (또는 python -m train .... wandb.project=s4 ).
wandb=null 설정하여 Wandb 로깅을 끄십시오.
자가 회귀 생성은 generate.py 스크립트로 수행 할 수 있습니다. 이 스크립트는이 코드베이스를 사용하여 모델을 훈련 한 후 두 가지 방식으로 사용할 수 있습니다.
보다 유연한 옵션에는 훈련 된 Pytorch Lightning 모델의 체크 포인트 경로가 필요합니다. Generation Script는 Configs/Generate.yaml에 문서화 된 몇 가지 추가 플래그를 사용하여 Train 스크립트와 동일한 구성 옵션을 수용합니다. python -m train <train flags> 로 훈련 한 후
python -m generate <train flags> checkpoint_path=<path/to/model.ckpt> <generation flags>
구성에있는 플래그 중 하나를 재정의 할 수 있습니다.
참고 :이 옵션은 .ckpt 체크 포인트 (트레이너를위한 정보가 포함 된 Pytorch Lightning) 또는 .pt 체크 포인트 (Pytorch, 단지 모델 상태 dict)와 함께 사용할 수 있습니다.
세대의 두 번째 옵션은 교육 플래그를 다시 전달할 필요가 없으며 실험 폴더 내의 Pytorch Lightning Checkpoint와 함께 Hydra Experiment 폴더의 구성을 읽습니다.
예를 들어 ./checkpoints/s4-wt103.pt 로 Wikitext-103 모델 체크 포인트를 다운로드하십시오. 이 모델은 python -m train experiment=lm/s4-wt103 명령으로 교육을 받았습니다. 구성에서 우리는 모델이 길이 8192의 수용 필드로 훈련되었음을 알 수 있습니다.
생성하려면 실행하십시오
python -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=16384 l_prefix=8192 decode=text
이것은 길이 8192의 접두사에 조절 된 길이 16384의 샘플을 생성합니다.
SC09 데이터 세트에서 작은 사시미 모델을 훈련합시다. 또한 검사 점을 더 빨리 얻기 위해 교육 및 검증 배치 수를 줄일 수 있습니다.
python -m train experiment=audio/sashimi-sc09 model.n_layers=2 trainer.limit_train_batches=0.1 trainer.limit_val_batches=0.1
첫 번째 시대가 완료된 후에는 체크 포인트가 저장되는 위치를 나타내는 메시지가 인쇄됩니다.
Epoch 0, global step 96: val/loss reached 3.71754 (best 3.71754), saving model to "<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt"
Option 1:
python -m generate experiment=audio/sashimi-sc09 model.n_layers=2 checkpoint_path=<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
이 옵션은 모델 및 데이터 세트를 구성 할 수 있도록 전체 구성을 재정의합니다.
Option 2:
python -m generate experiment_path=<repository>/outputs/<date>/<time> checkpoint_path=checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
이 옵션은 Hydra Experiment 폴더와 원하는 체크 포인트로가는 경로 만 있으면됩니다.
configs/ Config files for model, data pipeline, training loop, etc.
data/ Default location of raw data
extensions/ CUDA extensions (Cauchy and Vandermonde kernels)
src/ Main source code for models, datasets, etc.
callbacks/ Training loop utilities (e.g. checkpointing)
dataloaders/ Dataset and dataloader definitions
models/ Model definitions
tasks/ Encoder/decoder modules to interface between data and model backbone
utils/
models/ Model-specific information (code, experiments, additional resources)
example.py Example training script for using S4 externally
train.py Training entrypoint for this repo
generate.py Autoregressive generation script
이 코드베이스를 사용하거나 다른 방식으로 우리의 작업이 가치있는 것으로 판명되면 S4 및 기타 관련 논문을 인용하십시오.
@inproceedings{gu2022efficiently,
title={Efficiently Modeling Long Sequences with Structured State Spaces},
author={Gu, Albert and Goel, Karan and R'e, Christopher},
booktitle={The International Conference on Learning Representations ({ICLR})},
year={2022}
}