يوفر هذا المستودع التطبيقات والتجارب الرسمية للنماذج المتعلقة بـ S4 ، بما في ذلك Hippo و LSSL و Sashimi و DSS و HTTYH و S4D و S4ND.
يمكن العثور على معلومات خاصة بالمشروع لكل من هذه النماذج ، بما في ذلك نظرة عامة على التعليمات البرمجية المصدرية وتكاثر التجربة المحددة ، ضمن النماذج/.
إعداد البيئة ونقل S4 إلى bubbases الخارجية:
باستخدام هذا المستودع لنماذج التدريب:
انظر changelog.md
يتطلب هذا المستودع Python 3.9+ و Pytorch 1.10+. تم اختباره حتى Pytorch 1.13.1. يتم سرد الحزم الأخرى في المتطلبات. قد تكون هناك حاجة إلى بعض الرعاية لجعل بعض إصدارات المكتبة متوافقة ، وخاصة Torch/Torchvision/Torchaudio/Torchtext.
مثال التثبيت:
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install -r requirements.txt
العملية الأساسية لـ S4 هي نواة Cauchy و Vandermonde الموصوفة في الورقة. هذه مضاعفات مصفوفة بسيطة للغاية. يمكن العثور على تطبيق ساذج لهذه العملية في المستقل في وظيفة cauchy_naive و log_vandermonde_naive . ومع ذلك ، كما تصف الورقة ، فإن هذا له استخدام ذاكرة دون المستوى الأمثل الذي يتطلب حاليًا kernel مخصصة للتغلب عليها في Pytorch.
يتم دعم طريقتين أكثر كفاءة. سيتم اكتشاف الرمز تلقائيًا إذا تم تثبيت أي من هذه الأشياء واستدعاء النواة المناسبة.
هذا الإصدار أسرع ولكنه يتطلب تجميعًا يدويًا لكل بيئة آلة. تشغيل python setup.py install من extensions/kernels/ .
يتم توفير هذا الإصدار بواسطة مكتبة Pykeops. يعمل التثبيت عادةً خارج المربع مع pip install pykeops cmake والذي يتم سرده أيضًا في ملف المتطلبات.
يمكن العثور على الملفات المستقلة لطبقة S4 والمتغيرات في النماذج/S4/، والتي تتضمن تعليمات لاتصال الوحدة.
انظر دفاتر الملاحظات/ للاطلاع على التصورات التي تشرح بعض المفاهيم وراء فرس النهر و S4.
example.py هو نص تدريبي مستقل عن MNIST و CIFAR الذي يستورد ملف S4 المستقل. تصل الإعدادات الافتراضية python example.py إلى دقة 88 ٪ على CIFAR المتسلسل مع نموذج S4D بسيط للغاية من 200k معلمة. يمكن استخدام هذا البرنامج النصي كمثال لاستخدام المتغيرات S4 في المستودعات الخارجية.
يهدف هذا المستودع إلى توفير إطار مرن للغاية لنماذج تسلسل التدريب. يتم دعم العديد من النماذج ومجموعات البيانات.
نقطة الدخول الأساسية هي python -m train أو ما يعادلها
python -m train pipeline=mnist model=s4
الذي يدرب نموذج S4 على مجموعة بيانات MNIST المتكهبة. يجب أن يصل هذا إلى حوالي 90 ٪ بعد عصر واحد يستغرق 1-3 دقائق حسب وحدة معالجة الرسومات.
يتم توثيق المزيد من الأمثلة على استخدام هذا المستودع طوال الوقت. انظر التدريب للحصول على نظرة عامة.
تتمثل إحدى الميزات المهمة في قاعدة قاعدة الشفرة هذه في دعم المعلمات التي تتطلب أجهزة تحدٍ مختلفة. على وجه الخصوص ، نواة SSM حساسة بشكل خاص ل
راجع register الطريقة في النموذج (على سبيل المثال s4d.py) ودالة setup_optimizer في البرنامج النصي التدريبي (eg example.py) للحصول على أمثلة على كيفية تنفيذ ذلك في repos الخارجية.
تعتمد البنية التحتية التدريبية الأساسية لهذا المستودع على Lightning Pytorch مع مخطط تكوين يعتمد على Hydra.
نقطة الدخول الرئيسية هي train.py وتكوينات موجودة في configs/ .
يتم تنزيل مجموعات البيانات الأساسية ، بما في ذلك أوامر MNIST و CIFAR والكلام. جميع منطق إنشاء مجموعات البيانات وتحميلها في دليل SRC/Dataloaders. تقوم ReadMe داخل هذا الدليل الفرعي بتنزيل وتنظيم مجموعات البيانات الأخرى.
يتم تعريف النماذج في SRC/النماذج. انظر README في هذا الدليل الفرعي للحصول على نظرة عامة.
يتم توفير تكوينات محددة مسبقًا للتجارب من طرف إلى طرف من الأوراق ، والتي توجد تحت معلومات خاصة بالمشروع في النماذج/، مثل ورقة S4 الأصلية.
يمكن أيضًا تعديل التكوينات بسهولة من خلال سطر الأوامر. تجربة مثال
python -m train pipeline=mnist dataset.permute=True model=s4 model.n_layers=3 model.d_model=128 model.norm=batch model.prenorm=True wandb=null
يستخدم هذا مهمة MNIST المتكهبة مع نموذج S4 مع عدد محدد من الطبقات ، والبعد العمود الفقري ، ونوع التطبيع.
راجع التكوينات/readMe.md للحصول على المزيد من الوثائق التفصيلية حول التكوينات.
يوصى بقراءة وثائق Hydra لفهم إطار التكوين تمامًا. للمساعدة في إطلاق تجارب محددة ، يرجى تقديم مشكلة.
سيتم تسجيل كل تجربة إلى دليلها الخاص (الذي تم إنشاؤه بواسطة Hydra) من النموذج ./outputs/<date>/<time>/ <time>/. سيتم حفظ نقاط التفتيش هنا داخل هذا المجلد وطباعتها إلى وحدة التحكم كلما تم إنشاء نقطة تفتيش جديدة. لاستئناف التدريب ، ما عليك سوى الإشارة إلى ملف .ckpt المطلوب (نقطة تفتيش Lightning Pytorch ، على سبيل المثال ./outputs/<date>/<time>/checkpoints/val/loss.ckpt <time>/checkpoints/val/loss.ckpt) وإلحاق train.ckpt=<path>/<to>/<checkpoint>.ckpt <ccheckpoint>.
تتحكم فئة PTL Trainer في حلقة التدريب الشاملة وتوفر أيضًا العديد من الأعلام المحددة مسبقًا المفيدة. يتم شرح بعض الأمثلة المفيدة أدناه. يمكن العثور على القائمة الكاملة للأعلام المسموح بها في وثائق PTL ، بالإضافة إلى تكوينات المدرب لدينا. راجع تكوين تكوين Config/Trainer/Default.yaml المدرب الافتراضي للحصول على الخيارات الأكثر فائدة.
ما عليك سوى المرور في trainer.gpus=2 للتدريب مع 2 وحدات معالجة الرسومات.
trainer.weights_summary=full كل طبقة من الطراز مع تعداد المعلمات الخاص بهم. مفيد لتصحيح الأخطاء الداخلية للنماذج.
trainer.limit_{train,val}_batches={10,0.1} القطارات (التحقق من صحة) على 10 مجموعات فقط (0.1 جزء من جميع الدُفعات). مفيد لاختبار حلقة القطار دون المرور بجميع البيانات.
تم بناء تسجيل مع WANDB في هذا المستودع. من أجل استخدام هذا ، ما عليك سوى تعيين متغير بيئة WANDB_API_KEY ، وتغيير سمة wandb.project من التكوينات/config.yaml (أو تمريرها على سطر الأوامر على سبيل المثال python -m train .... wandb.project=s4 ).
تعيين wandb=null لإيقاف تسجيل WANDB.
يمكن تنفيذ توليد الانحدار التلقائي باستخدام البرنامج النصي cender.py. يمكن استخدام هذا البرنامج النصي بطريقتين بعد تدريب نموذج باستخدام قاعدة البيانات هذه.
يتطلب الخيار الأكثر مرونة مسار نقطة التفتيش لطراز Lightning Pytorch المدربة. يقبل برنامج Generation Script نفس خيارات التكوين مثل البرنامج النصي للقطار ، مع بعض الأعلام الإضافية التي تم توثيقها في configs/enderate.yaml. بعد التدريب مع python -m train <train flags> ، قم بالتوليد مع
python -m generate <train flags> checkpoint_path=<path/to/model.ckpt> <generation flags>
يمكن تجاوز أي من الأعلام الموجودة في التكوين.
ملاحظة: يمكن استخدام هذا الخيار مع نقاط التفتيش .ckpt (Lightning Pytorch ، والتي تتضمن معلومات للمدرب) أو .pt
لا يتطلب الخيار الثاني لتوليد تمرير أعلام التدريب مرة أخرى ، وبدلاً من ذلك يقرأ التكوين من مجلد تجربة Hydra ، إلى جانب نقطة تفتيش Lightning Pytorch داخل مجلد التجربة.
قم بتنزيل نقطة تفتيش نموذج Wikitext-103 ، على سبيل المثال إلى ./checkpoints/s4-wt103.pt . تم تدريب هذا النموذج مع python -m train experiment=lm/s4-wt103 . لاحظ أنه من التكوين ، يمكننا أن نرى أن النموذج قد تم تدريبه مع حقل تقبلي بطول 8192.
لتوليد ، تشغيل
python -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=16384 l_prefix=8192 decode=text
هذا يولد عينة من الطول 16384 مشروط على بادئة الطول 8192.
دعنا ندرب نموذج Sashimi صغير على مجموعة بيانات SC09. يمكننا أيضًا تقليل عدد مجموعات التدريب والتحقق للحصول على نقطة تفتيش أسرع:
python -m train experiment=audio/sashimi-sc09 model.n_layers=2 trainer.limit_train_batches=0.1 trainer.limit_val_batches=0.1
بعد اكتمال أول عصر ، تتم طباعة رسالة تشير إلى مكان حفظ نقطة التفتيش.
Epoch 0, global step 96: val/loss reached 3.71754 (best 3.71754), saving model to "<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt"
الخيار 1:
python -m generate experiment=audio/sashimi-sc09 model.n_layers=2 checkpoint_path=<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
يعيد هذا الخيار تعريف التكوين الكامل بحيث يمكن إنشاء النموذج ومجموعة البيانات.
الخيار 2:
python -m generate experiment_path=<repository>/outputs/<date>/<time> checkpoint_path=checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
يحتاج هذا الخيار فقط إلى المسار إلى مجلد تجربة Hydra ونقطة التفتيش المطلوبة في الداخل.
configs/ Config files for model, data pipeline, training loop, etc.
data/ Default location of raw data
extensions/ CUDA extensions (Cauchy and Vandermonde kernels)
src/ Main source code for models, datasets, etc.
callbacks/ Training loop utilities (e.g. checkpointing)
dataloaders/ Dataset and dataloader definitions
models/ Model definitions
tasks/ Encoder/decoder modules to interface between data and model backbone
utils/
models/ Model-specific information (code, experiments, additional resources)
example.py Example training script for using S4 externally
train.py Training entrypoint for this repo
generate.py Autoregressive generation script
إذا كنت تستخدم قاعدة بيانات الكود هذه ، أو وجدت عملنا ذا قيمة أخرى ، فيرجى الاستشهاد بـ S4 والأوراق الأخرى ذات الصلة.
@inproceedings{gu2022efficiently,
title={Efficiently Modeling Long Sequences with Structured State Spaces},
author={Gu, Albert and Goel, Karan and R'e, Christopher},
booktitle={The International Conference on Learning Representations ({ICLR})},
year={2022}
}