Kertas: https://arxiv.org/abs/2405.04517
XLSTM adalah arsitektur jaringan saraf berulang baru berdasarkan ide -ide LSTM asli. Melalui gating eksponensial dengan teknik normalisasi dan stabilisasi yang tepat dan memori matriks baru itu mengatasi keterbatasan LSTM asli dan menunjukkan kinerja yang menjanjikan pada pemodelan bahasa bila dibandingkan dengan transformator atau model ruang negara.
Kami melatih model bahasa XLSTM parameter 7b
Kami telah mengoptimalkan arsitektur XLSTM dalam hal throughput dan stabilitas pelatihan. Kode untuk arsitektur yang diperbarui terletak di xlstm/xlstm_large .
Bobot model tersedia di Huggingface di https://huggingface.co/nx-ai/xlstm-7b.
Buat lingkungan conda dari file environment_pt220cu121.yaml . Instal Kode Model saja (yaitu Modul xlstm ) sebagai paket:
Instal via PIP:
pip install xlstmKlon dari GitHub:
git clone https://github.com/NX-AI/xlstm.git
cd xlstm
pip install -e . Untuk menggunakan model 7b XLSTM Instal mlstm_kernels via:
pip install mlstm_kernels
Paket ini didasarkan pada pytorch dan diuji untuk versi >=1.8 . Untuk SLSTM versi CUDA, Anda memerlukan kemampuan menghitung> = 8.0, lihat https://developer.nvidia.com/cuda-gpus. Untuk lingkungan yang teruji dengan baik, pasang environment_pt220cu121.yaml sebagai:
conda env create -n xlstm -f environment_pt220cu121.yaml
conda activate xlstm Untuk model XLSTM Besar 7B, kami memerlukan paket mlstm_kernels (TODO ADD GITHUB Link) kami, yang menyediakan kernel cepat untuk XLSTM.
Bagian ini menjelaskan cara menggunakan model dari kertas XLSTM.
Untuk aplikasi non bahasa atau untuk mengintegrasikan arsitektur lain, Anda dapat menggunakan xLSTMBlockStack dan untuk pemodelan bahasa atau aplikasi berbasis token lainnya, Anda dapat menggunakan xLSTMLMModel .
xLSTMBLockStack dimaksudkan untuk digunakan sebagai tulang punggung alternatif dalam proyek yang ada. Ini mirip dengan tumpukan blok transformator, tetapi menggunakan blok XLSTM:
import torch
from xlstm import (
xLSTMBlockStack ,
xLSTMBlockStackConfig ,
mLSTMBlockConfig ,
mLSTMLayerConfig ,
sLSTMBlockConfig ,
sLSTMLayerConfig ,
FeedForwardConfig ,
)
cfg = xLSTMBlockStackConfig (
mlstm_block = mLSTMBlockConfig (
mlstm = mLSTMLayerConfig (
conv1d_kernel_size = 4 , qkv_proj_blocksize = 4 , num_heads = 4
)
),
slstm_block = sLSTMBlockConfig (
slstm = sLSTMLayerConfig (
backend = "cuda" ,
num_heads = 4 ,
conv1d_kernel_size = 4 ,
bias_init = "powerlaw_blockdependent" ,
),
feedforward = FeedForwardConfig ( proj_factor = 1.3 , act_fn = "gelu" ),
),
context_length = 256 ,
num_blocks = 7 ,
embedding_dim = 128 ,
slstm_at = [ 1 ],
)
xlstm_stack = xLSTMBlockStack ( cfg )
x = torch . randn ( 4 , 256 , 128 ). to ( "cuda" )
xlstm_stack = xlstm_stack . to ( "cuda" )
y = xlstm_stack ( x )
y . shape == x . shapeJika Anda bekerja dengan string / file YAML untuk konfigurasi, Anda juga dapat menggunakan Dacite untuk membuat Config Dataclasses. Ini sama dengan cuplikan di atas:
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
from xlstm import xLSTMBlockStack , xLSTMBlockStackConfig
xlstm_cfg = """
mlstm_block:
mlstm:
conv1d_kernel_size: 4
qkv_proj_blocksize: 4
num_heads: 4
slstm_block:
slstm:
backend: cuda
num_heads: 4
conv1d_kernel_size: 4
bias_init: powerlaw_blockdependent
feedforward:
proj_factor: 1.3
act_fn: gelu
context_length: 256
num_blocks: 7
embedding_dim: 128
slstm_at: [1]
"""
cfg = OmegaConf . create ( xlstm_cfg )
cfg = from_dict ( data_class = xLSTMBlockStackConfig , data = OmegaConf . to_container ( cfg ), config = DaciteConfig ( strict = True ))
xlstm_stack = xLSTMBlockStack ( cfg )
x = torch . randn ( 4 , 256 , 128 ). to ( "cuda" )
xlstm_stack = xlstm_stack . to ( "cuda" )
y = xlstm_stack ( x )
y . shape == x . shape xLSTMLMModel adalah pembungkus di sekitar xLSTMBlockStack yang menambahkan token embedding dan LM head.
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
from xlstm import xLSTMLMModel , xLSTMLMModelConfig
xlstm_cfg = """
vocab_size: 50304
mlstm_block:
mlstm:
conv1d_kernel_size: 4
qkv_proj_blocksize: 4
num_heads: 4
slstm_block:
slstm:
backend: cuda
num_heads: 4
conv1d_kernel_size: 4
bias_init: powerlaw_blockdependent
feedforward:
proj_factor: 1.3
act_fn: gelu
context_length: 256
num_blocks: 7
embedding_dim: 128
slstm_at: [1]
"""
cfg = OmegaConf . create ( xlstm_cfg )
cfg = from_dict ( data_class = xLSTMLMModelConfig , data = OmegaConf . to_container ( cfg ), config = DaciteConfig ( strict = True ))
xlstm_stack = xLSTMLMModel ( cfg )
x = torch . randint ( 0 , 50304 , size = ( 4 , 256 )). to ( "cuda" )
xlstm_stack = xlstm_stack . to ( "cuda" )
y = xlstm_stack ( x )
y . shape [ 1 :] == ( 256 , 50304 )Eksperimen sintetis menunjukkan manfaat dari SLSTM dibandingkan MLSTM dan sebaliknya adalah tugas paritas dan tugas penarikan asosiatif multi-kuerinya. Tugas paritas hanya dapat diselesaikan dengan kemampuan pelacakan negara yang disediakan oleh pencampuran memori SLSTM. Tugas penarikan asosiatif multi-kuerinya mengukur kemampuan menghafal, di mana ekspansi matriks-memori dan negara bagian MLSTM sangat bermanfaat. Dalam kombinasi mereka melakukannya dengan baik pada kedua tugas.
Untuk menjalankan masing -masing, jalankan main.py di folder Eksperimen seperti:
python experiments/main.py --config experiments/parity_xLSTM01.yaml # xLSTM[0:1], sLSTM only
python experiments/main.py --config experiments/parity_xLSTM10.yaml # xLSTM[1:0], mLSTM only
python experiments/main.py --config experiments/parity_xLSTM11.yaml # xLSTM[1:1], mLSTM and sLSTM
Perhatikan bahwa loop pelatihan tidak mengandung penghentian awal atau evaluasi tes.
Jika Anda menggunakan basis kode ini, atau temukan pekerjaan kami yang berharga, silakan kutip kertas XLSTM:
@inproceedings{beck:24xlstm,
title={xLSTM: Extended Long Short-Term Memory},
author={Maximilian Beck and Korbinian Pöppel and Markus Spanring and Andreas Auer and Oleksandra Prudnikova and Michael Kopp and Günter Klambauer and Johannes Brandstetter and Sepp Hochreiter},
booktitle = {Thirty-eighth Conference on Neural Information Processing Systems},
year={2024},
url={https://arxiv.org/abs/2405.04517},
}