ที่เก็บนี้ให้การใช้งานอย่างเป็นทางการและการทดลองสำหรับแบบจำลองที่เกี่ยวข้องกับ S4 รวมถึง HIPPO, LSSL, Sashimi, DSS, HTTYH, S4D และ S4ND
ข้อมูลเฉพาะโครงการสำหรับแต่ละรุ่นเหล่านี้รวมถึงภาพรวมของซอร์สโค้ดและการทำซ้ำการทดลองเฉพาะสามารถพบได้ภายใต้โมเดล/
การตั้งค่าสภาพแวดล้อมและการพอร์ต S4 ไปยังรหัสฐานภายนอก:
การใช้ที่เก็บนี้สำหรับรูปแบบการฝึกอบรม:
ดู changelog.md
ที่เก็บนี้ต้องใช้ Python 3.9+ และ Pytorch 1.10+ มันได้รับการทดสอบถึง Pytorch 1.13.1 แพ็คเกจอื่น ๆ แสดงอยู่ในข้อกำหนด. txt อาจจำเป็นต้องใช้ความระมัดระวังเพื่อให้ห้องสมุดบางรุ่นเข้ากันได้โดยเฉพาะอย่างยิ่งคบเพลิง/Torchvision/Torchaudio/Torchtext
การติดตั้งตัวอย่าง:
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install -r requirements.txt
การดำเนินการหลักของ S4 คือเมล็ด Cauchy และ Vandermonde ที่อธิบายไว้ในกระดาษ นี่คือการคูณเมทริกซ์ที่ง่ายมาก การใช้งานที่ไร้เดียงสาของการดำเนินการเหล่านี้สามารถพบได้ในสแตนด์อโลนในฟังก์ชั่น cauchy_naive และ log_vandermonde_naive อย่างไรก็ตามตามที่กระดาษอธิบายสิ่งนี้มีการใช้หน่วยความจำที่ไม่ดีซึ่งปัจจุบันต้องใช้เคอร์เนลที่กำหนดเองเพื่อเอาชนะใน Pytorch
รองรับสองวิธีที่มีประสิทธิภาพมากขึ้น รหัสจะตรวจจับโดยอัตโนมัติหากมีการติดตั้งเหล่านี้และเรียกเคอร์เนลที่เหมาะสม
รุ่นนี้เร็วขึ้น แต่ต้องใช้การรวบรวมด้วยตนเองสำหรับแต่ละสภาพแวดล้อมของเครื่อง รัน python setup.py install จาก extensions/kernels/
รุ่นนี้จัดทำโดยไลบรารี Pykeops การติดตั้งมักจะทำงานนอกกรอบด้วย pip install pykeops cmake ซึ่งแสดงอยู่ในไฟล์ข้อกำหนด
ไฟล์ที่มีอยู่ในตัวเองสำหรับเลเยอร์ S4 และตัวแปรสามารถพบได้ใน Models/S4/ซึ่งรวมถึงคำแนะนำสำหรับการเรียกโมดูล
ดูสมุดบันทึก/ สำหรับการสร้างภาพข้อมูลอธิบายแนวคิดบางอย่างที่อยู่เบื้องหลัง Hippo และ S4
example.py เป็นสคริปต์การฝึกอบรมที่มีอยู่ในตัวเองสำหรับ MNIST และ CIFAR ที่นำเข้าไฟล์ S4 แบบสแตนด์อโลน การตั้งค่าเริ่มต้น python example.py มีความแม่นยำ 88% ใน CIFAR ตามลำดับด้วยรุ่น S4D ที่ง่ายมากของพารามิเตอร์ 200K สคริปต์นี้สามารถใช้เป็นตัวอย่างสำหรับการใช้ตัวแปร S4 ในที่เก็บภายนอก
ที่เก็บนี้มีวัตถุประสงค์เพื่อให้กรอบการทำงานที่ยืดหยุ่นมากสำหรับแบบจำลองลำดับการฝึกอบรม รองรับรุ่นและชุดข้อมูลจำนวนมาก
จุดเริ่มต้นพื้นฐานคือ python -m train หรือเทียบเท่า
python -m train pipeline=mnist model=s4
ซึ่งฝึกอบรมโมเดล S4 ในชุดข้อมูล MNIST ที่ผ่านการรับรอง สิ่งนี้ควรไปถึงประมาณ 90% หลังจาก 1 ยุคซึ่งใช้เวลา 1-3 นาทีขึ้นอยู่กับ GPU
ตัวอย่างเพิ่มเติมของการใช้พื้นที่เก็บข้อมูลนี้มีการบันทึกไว้ตลอด ดูการฝึกอบรมสำหรับภาพรวม
คุณลักษณะที่สำคัญอย่างหนึ่งของ codebase นี้คือพารามิเตอร์ที่ต้องใช้พารามิเตอร์เครื่องมือเพิ่มประสิทธิภาพที่แตกต่างกัน โดยเฉพาะอย่างยิ่งเคอร์เนล SSM นั้นมีความอ่อนไหวต่อ
ดูวิธี register ในโมเดล (เช่น S4D.PY) และฟังก์ชั่น setup_optimizer ในสคริปต์การฝึกอบรม (เช่นตัวอย่าง. py) สำหรับตัวอย่างของวิธีการใช้งานนี้ใน repos ภายนอก
โครงสร้างพื้นฐานการฝึกอบรมหลักของพื้นที่เก็บข้อมูลนี้ขึ้นอยู่กับ Pytorch-Lightning พร้อมรูปแบบการกำหนดค่าตามไฮดรา
จุดเริ่มต้นหลักคือ train.py และการกำหนดค่าพบได้ใน configs/
ชุดข้อมูลพื้นฐานคือการโหลดอัตโนมัติรวมถึง MNIST, CIFAR และคำสั่งคำพูด ตรรกะทั้งหมดสำหรับการสร้างและการโหลดชุดข้อมูลอยู่ในไดเรกทอรี SRC/Dataloaders readMe ภายในเอกสารไดเรกทอรีย่อยนี้วิธีดาวน์โหลดและจัดระเบียบชุดข้อมูลอื่น ๆ
แบบจำลองถูกกำหนดไว้ใน SRC/รุ่น ดู readMe ในไดเรกทอรีย่อยนี้สำหรับภาพรวม
การกำหนดค่าการทดลองแบบ end-to-end ที่กำหนดไว้ล่วงหน้าจากเอกสารมีให้ที่พบภายใต้ข้อมูลเฉพาะโครงการในรูปแบบ/เช่นสำหรับกระดาษ S4 ต้นฉบับ
การกำหนดค่าสามารถแก้ไขได้อย่างง่ายดายผ่านบรรทัดคำสั่ง ตัวอย่างการทดลองคือ
python -m train pipeline=mnist dataset.permute=True model=s4 model.n_layers=3 model.d_model=128 model.norm=batch model.prenorm=True wandb=null
สิ่งนี้ใช้งาน MNIST ที่ได้รับการรับรองด้วยรุ่น S4 ที่มีจำนวนเลเยอร์ที่ระบุมิติกระดูกสันหลังและประเภทการทำให้เป็นมาตรฐาน
ดู configs/readme.md สำหรับเอกสารรายละเอียดเพิ่มเติมเกี่ยวกับการกำหนดค่า
ขอแนะนำให้อ่านเอกสาร Hydra เพื่อทำความเข้าใจกรอบการกำหนดค่าอย่างเต็มที่ สำหรับความช่วยเหลือในการเปิดตัวการทดลองเฉพาะโปรดยื่นปัญหา
การทดลองแต่ละครั้งจะถูกบันทึกไปยังไดเรกทอรีของตัวเอง (สร้างโดย Hydra) ของแบบฟอร์ม ./outputs/<date>/<time>/ /<date>/<time>/ จุดตรวจจะถูกบันทึกไว้ที่นี่ภายในโฟลเดอร์นี้และพิมพ์ไปยังคอนโซลเมื่อใดก็ตามที่มีการสร้างจุดตรวจสอบใหม่ ในการฝึกอบรมต่อไปเพียงแค่ชี้ไปที่ไฟล์ .ckpt ที่ต้องการ (จุดตรวจ Lightning Pytorch เช่น ./outputs/<date>/<time>/checkpoints/val/loss.ckpt train.ckpt=<path>/<to>/<checkpoint>.ckpt
คลาส Trainer PTL ควบคุมการฝึกอบรมโดยรวมและยังมีธงที่กำหนดไว้ล่วงหน้าที่มีประโยชน์มากมาย ตัวอย่างที่มีประโยชน์บางอย่างอธิบายไว้ด้านล่าง รายการทั้งหมดของธงที่อนุญาตสามารถพบได้ในเอกสาร PTL รวมถึงการกำหนดค่าเทรนเนอร์ของเรา ดูการกำหนดค่าการกำหนดค่าเทรนเนอร์เริ่มต้น/เทรนเนอร์/default.yaml สำหรับตัวเลือกที่มีประโยชน์มากที่สุด
เพียงแค่ผ่านใน trainer.gpus=2 เพื่อฝึกด้วย 2 GPU
trainer.weights_summary=full รูปแบบทุกเลเยอร์ของโมเดลด้วยการนับพารามิเตอร์ของพวกเขา มีประโยชน์สำหรับการดีบักภายในของโมเดล
trainer.limit_{train,val}_batches={10,0.1} รถไฟ (ตรวจสอบ) เพียง 10 ชุด (0.1 ส่วนของแบทช์ทั้งหมด) มีประโยชน์สำหรับการทดสอบลูปรถไฟโดยไม่ต้องผ่านข้อมูลทั้งหมด
การเข้าสู่ระบบด้วย Wandb ถูกสร้างขึ้นในที่เก็บนี้ ในการใช้สิ่งนี้เพียงแค่ตั้งค่าตัวแปรสภาพแวดล้อม WANDB_API_KEY ของคุณและเปลี่ยนแอตทริบิวต์ wandb.project ของ configs/config.yaml (หรือส่งผ่านบรรทัดคำสั่งเช่น python -m train .... wandb.project=s4 )
ตั้งค่า wandb=null เพื่อปิดการบันทึก WANDB
รุ่น Autoregressive สามารถทำได้ด้วยสคริปต์ Generate.py สคริปต์นี้สามารถใช้งานได้สองวิธีหลังจากการฝึกอบรมแบบจำลองโดยใช้ codebase นี้
ตัวเลือกที่ยืดหยุ่นมากขึ้นต้องใช้เส้นทางจุดตรวจของรุ่น Pytorch Lightning ที่ผ่านการฝึกอบรม สคริปต์การสร้างยอมรับตัวเลือกการกำหนดค่าเดียวกันกับสคริปต์รถไฟโดยมีธงเพิ่มเติมสองสามตัวที่บันทึกไว้ใน configs/generate.yaml หลังจากฝึกอบรมกับ python -m train <train flags> ให้สร้างด้วย
python -m generate <train flags> checkpoint_path=<path/to/model.ckpt> <generation flags>
ธงใด ๆ ที่พบในการกำหนดค่าสามารถแทนที่ได้
หมายเหตุ: ตัวเลือกนี้สามารถใช้กับจุดตรวจสอบ .ckpt (Pytorch Lightning ซึ่งรวมถึงข้อมูลสำหรับผู้ฝึกสอน) หรือจุด .pt .
ตัวเลือกที่สองสำหรับการสร้างไม่จำเป็นต้องผ่านการฝึกอบรมอีกครั้งและอ่านการกำหนดค่าจากโฟลเดอร์ Hydra Experiment แทนด้วยจุดตรวจ Lightning Pytorch ภายในโฟลเดอร์ Experiment
ดาวน์โหลดจุดตรวจสอบรุ่น Wikitext-103 ตัวอย่างเช่น ./checkpoints/s4-wt103.pt โมเดลนี้ได้รับการฝึกฝนด้วยคำสั่ง python -m train experiment=lm/s4-wt103 โปรดทราบว่าจากการกำหนดค่าเราจะเห็นว่าแบบจำลองได้รับการฝึกฝนด้วยสนามที่เปิดกว้าง 8192
เพื่อสร้าง
python -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=16384 l_prefix=8192 decode=text
สิ่งนี้สร้างตัวอย่างความยาว 16384 ปรับอากาศบนคำนำหน้าความยาว 8192
มาฝึกอบรมรุ่นซาชิมิเล็ก ๆ บนชุดข้อมูล SC09 นอกจากนี้เรายังสามารถลดจำนวนการฝึกอบรมและการตรวจสอบความถูกต้องเพื่อให้ได้จุดตรวจที่เร็วขึ้น:
python -m train experiment=audio/sashimi-sc09 model.n_layers=2 trainer.limit_train_batches=0.1 trainer.limit_val_batches=0.1
หลังจากยุคแรกเสร็จสมบูรณ์ข้อความจะถูกพิมพ์เพื่อระบุว่าจุดตรวจสอบจะถูกบันทึกไว้ที่ใด
Epoch 0, global step 96: val/loss reached 3.71754 (best 3.71754), saving model to "<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt"
ตัวเลือกที่ 1:
python -m generate experiment=audio/sashimi-sc09 model.n_layers=2 checkpoint_path=<repository>/outputs/<date>/<time>/checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
ตัวเลือกนี้กำหนดค่าการกำหนดค่าแบบเต็มเพื่อให้สามารถสร้างโมเดลและชุดข้อมูลได้
ตัวเลือกที่ 2:
python -m generate experiment_path=<repository>/outputs/<date>/<time> checkpoint_path=checkpoints/val/loss.ckpt n_samples=4 l_sample=16000
ตัวเลือกนี้ต้องการเส้นทางไปยังโฟลเดอร์ Hydra Experiment และจุดตรวจสอบที่ต้องการภายใน
configs/ Config files for model, data pipeline, training loop, etc.
data/ Default location of raw data
extensions/ CUDA extensions (Cauchy and Vandermonde kernels)
src/ Main source code for models, datasets, etc.
callbacks/ Training loop utilities (e.g. checkpointing)
dataloaders/ Dataset and dataloader definitions
models/ Model definitions
tasks/ Encoder/decoder modules to interface between data and model backbone
utils/
models/ Model-specific information (code, experiments, additional resources)
example.py Example training script for using S4 externally
train.py Training entrypoint for this repo
generate.py Autoregressive generation script
หากคุณใช้ codebase นี้หรือพบว่างานของเรามีค่าโปรดอ้างอิง S4 และเอกสารอื่น ๆ ที่เกี่ยวข้อง
@inproceedings{gu2022efficiently,
title={Efficiently Modeling Long Sequences with Structured State Spaces},
author={Gu, Albert and Goel, Karan and R'e, Christopher},
booktitle={The International Conference on Learning Representations ({ICLR})},
year={2022}
}