Этот репозиторий предоставляет официальные реализации и эксперименты для моделей, связанных с S4, включая Hippo, LSSL, Sashimi, DSS, HTTYH, S4D и S4ND.
Информация о проекте для каждой из этих моделей, включая обзор исходного кода и конкретные экспериментальные воспроизведения, можно найти в моделях/.
Настройка среды и портирование S4 на внешние кодовые базы:
Использование этого репозитория для тренировочных моделей:
Смотрите Changelog.md
Этот репозиторий требует Python 3.9+ и Pytorch 1.10+. Он был проверен до Pytorch 1.13.1. Другие пакеты перечислены в требованиях.txt. Некоторые заботы могут потребоваться, чтобы сделать некоторые библиотечные версии совместимыми, в частности, Torch/Torchvision/Torchaudio/Torchtext.
Пример установки:
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install -r requirements.txt
Основными операциями S4 являются ядра Cauchy и Vandermonde, описанные в статье. Это очень простые умножения матрицы; Наивная реализация этой операции можно найти в автономном режиме в функции cauchy_naive и log_vandermonde_naive . Однако, как описывает статья, это имеет неоптимальное использование памяти, которое в настоящее время требует от пользовательского ядра для преодоления в Pytorch.
Поддерживаются более эффективные методы. Код будет автоматически обнаружить, если один из них будет установлен, и вызовет соответствующее ядро.
Эта версия быстрее, но требует ручной компиляции для каждой машины. Запустите python setup.py install из extensions/kernels/ .
Эта версия предоставлена библиотекой Pykeops. Установка обычно работает из коробки с pip install pykeops cmake , которые также перечислены в файле требований.
Автономные файлы для слоя S4 и вариантов можно найти в моделях/S4/, которые включают инструкции по вызову модуля.
См. Записные книжки/ для визуализаций, объясняющих некоторые концепции, лежащие в основе гиппока и S4.
Пример. Настройки по умолчанию python example.py достигает 88% точности на последовательном CIFAR с очень простой моделью S4D 200K параметров. Этот скрипт можно использовать в качестве примера для использования вариантов S4 во внешних репозиториях.
Этот репозиторий направлен на обеспечение очень гибкой основы для моделей обучения последовательностей. Многие модели и наборы данных поддерживаются.
Основной точкой входа является python -m train или эквивалентно
python -m train pipeline=mnist model=s4
который обучает модель S4 на набор данных MNIST. Это должно добраться до 90% после 1 эпохи, которая занимает 1-3 минуты в зависимости от графического процессора.
Больше примеров использования этого репозитория задокументировано повсюду. Смотрите обучение для обзора.
Одной из важных функций этой кодовой базы являются поддерживающие параметры, которые требуют различных гиперпараметров оптимизатора. В частности, ядро SSM особенно чувствительно к
См. register методов в модели (например, S4D.py) и функцию setup_optimizer в обучающем скрипте (например, пример.py) для примеров того, как реализовать это во внешних репо.
Основная обучающая инфраструктура этого репозитория основана на пирожек-флининге с схемой конфигурации, основанной на HYDRA.
Основная точка входа - train.py , а конфигурации обнаруживаются в configs/ .
Основные наборы данных автоматически загружаются, включая MNIST, CIFAR и речевые команды. Вся логика для создания и загрузки наборов данных находится в каталоге SRC/DataLoaders. Readme внутри этого подкаталерительного документа, как загружать и организовать другие наборы данных.
Модели определены в SRC/моделях. Смотрите Readme в этом подкаталоге для обзора.
Предварительно определенные конфигурации, воспроизводящие сквозные эксперименты из бумаг, предоставлены, обнаруженные в разделе «Модели/» в моделях/, например, для оригинальной бумаги S4.
Конфигурации также могут быть легко изменены через командную строку. Примером эксперимента является
python -m train pipeline=mnist dataset.permute=True model=s4 model.n_layers=3 model.d_model=128 model.norm=batch model.prenorm=True wandb=null
Это использует перестроченную задачу MNIST с моделью S4 с указанным количеством слоев, размерным измерением и типом нормализации.
См. Configs/readme.md для более подробной документации о конфигурациях.
Рекомендуется прочитать документацию HYDRA, чтобы полностью понять структуру конфигурации. Для запуска конкретных экспериментов, пожалуйста, подайте проблему.
Каждый эксперимент будет вошел в свой собственный каталог (сгенерированный Hydra) формы ./outputs/<date>/<time>/ <date>/<time>/. Контрольные точки будут сохранены здесь в этой папке и напечатаны в консоли всякий раз, когда создается новый контрольно -пропускной пункт. Чтобы возобновить обучение, просто укажите на желаемый файл .ckpt (контрольная точка Pytorch Lightning, например ./outputs/<date>/<time>/checkpoints/val/loss.ckpt train.ckpt=<path>/<to>/<checkpoint>.ckpt .
Класс Trainer PTL контролирует общую петлю обучения, а также предоставляет много полезных предварительно определенных флагов. Некоторые полезные примеры объясняются ниже. Полный список допустимых флагов можно найти в документации PTL, а также конфигурации тренеров. См. Конфигурации конфигурации тренера по умолчанию/Trainer/default.yaml для наиболее полезных параметров.
Просто пройдите в trainer.gpus=2 чтобы тренироваться с 2 графическими процессорами.
trainer.weights_summary=full отпечаток каждого уровня модели с количеством их параметров. Полезно для отладки внутренних пунктов моделей.
trainer.limit_{train,val}_batches={10,0.1} поезда (проверка) только на 10 партии (0,1 фракции всех партий). Полезно для тестирования цикла поезда без прохождения всех данных.
Регистрация с Wandb встроена в этот репозиторий. Чтобы использовать это, просто установите свою переменную среды WANDB_API_KEY и измените атрибут wandb.project of configs/config.yaml (или передайте ее в командную строку, например, python -m train .... wandb.project=s4 ).
Установите wandb=null , чтобы отключить журнал Wandb.
Авторегрессивная генерация может быть выполнена с помощью сценария Generate.py. Этот скрипт можно использовать двумя способами после обучения модели с использованием этой кодовой базы.
Более гибкий вариант требует пути контрольной точки обученной модели молнии. Сценарий генерации принимает те же параметры конфигурации, что и сценарий поезда, с несколькими дополнительными флагами, которые задокументированы в configs/Generate.yaml. После тренировки с python -m train <train flags> , генерируйте с
python -m generate <train flags> checkpoint_path=<path/to/model.ckpt> <generation flags>
Любой из флагов, найденных в конфигурации, может быть переопределен.
ПРИМЕЧАНИЕ .pt Эта опция может использоваться с помощью .ckpt
Второй вариант для генерации не требует снова прохождения в обучающих флагах, а вместо этого считывает конфигурацию из папки эксперимента Hydra, а также контрольную точку Lightning Pytorch в папке эксперимента.
Загрузите контрольную точку модели Wikitext-103, например, на ./checkpoints/s4-wt103.pt . Эта модель была обучена командным python -m train experiment=lm/s4-wt103 . Обратите внимание, что из конфигурации мы видим, что модель была обучена рецептивному полю длины 8192.
Чтобы генерировать, запустить
python -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=16384 l_prefix=8192 decode=text
Это генерирует образец длины 16384, кондиционированный на префиксе длины 8192.
Давайте тренируем небольшую модель сашими на наборе данных SC09. Мы также можем сократить количество партий обучения и валидации, чтобы быстрее получить контрольную точку:
python -m train experiment=audio/sashimi-sc09 model.n_layers=2 trainer.limit_train_batches=0.1 trainer.limit_val_batches=0.1
После завершения первой эпохи печатается сообщение, указывающее, где сохраняется контрольная точка.
Epoch 0, global step 96: val/loss reached 3.71754 (best 3.71754), saving model to "<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt"
Вариант 1:
python -m generate experiment=audio/sashimi-sc09 model.n_layers=2 checkpoint_path=<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
Эта опция переопределяет полную конфигурацию так, чтобы модель и набор данных могли быть построены.
Вариант 2:
python -m generate experiment_path=<repository>/outputs/<date>/<time> checkpoint_path=checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
Эта опция нуждается только в пути к папке эксперимента Hydra и желаемой контрольной точке внутри.
configs/ Config files for model, data pipeline, training loop, etc.
data/ Default location of raw data
extensions/ CUDA extensions (Cauchy and Vandermonde kernels)
src/ Main source code for models, datasets, etc.
callbacks/ Training loop utilities (e.g. checkpointing)
dataloaders/ Dataset and dataloader definitions
models/ Model definitions
tasks/ Encoder/decoder modules to interface between data and model backbone
utils/
models/ Model-specific information (code, experiments, additional resources)
example.py Example training script for using S4 externally
train.py Training entrypoint for this repo
generate.py Autoregressive generation script
Если вы используете эту кодовую базу или иным образом нашли нашу работу ценной, пожалуйста, цитируйте S4 и другие соответствующие документы.
@inproceedings{gu2022efficiently,
title={Efficiently Modeling Long Sequences with Structured State Spaces},
author={Gu, Albert and Goel, Karan and R'e, Christopher},
booktitle={The International Conference on Learning Representations ({ICLR})},
year={2022}
}