Papel: https://arxiv.org/abs/2405.04517
O XLSTM é uma nova arquitetura de rede neural recorrente com base em idéias do LSTM original. Através do bloqueio exponencial com as técnicas apropriadas de normalização e estabilização e uma nova memória da matriz, supera as limitações do LSTM original e mostra desempenho promissor na modelagem de idiomas quando comparado aos modelos de transformadores ou espaço de estado.
Treinamos um modelo de linguagem XLSTM de parâmetros 7B
Otimizamos a arquitetura XLSTM em termos de taxa de transferência e estabilidade de treinamento. O código da arquitetura atualizado está localizado em xlstm/xlstm_large .
Os pesos do modelo estão disponíveis no Huggingface em https://huggingface.co/nx-ai/xlstm-7b.
Crie um ambiente do CONDA a partir do arquivo environment_pt220cu121.yaml . Instale apenas o código do modelo (ou seja, o módulo xlstm ) como pacote:
Instale via PIP:
pip install xlstmClone do Github:
git clone https://github.com/NX-AI/xlstm.git
cd xlstm
pip install -e . Para usar o modelo 7B XLSTM, instale mlstm_kernels via:
pip install mlstm_kernels
Este pacote é baseado em pytorch e foi testado para versões >=1.8 . Para a versão CUDA do SLSTM, você precisa de computação> = 8.0, consulte https://developer.nvidia.com/cuda-gpus. Para um ambiente bem testado, instale o environment_pt220cu121.yaml como:
conda env create -n xlstm -f environment_pt220cu121.yaml
conda activate xlstm Para o modelo XLSTM grande 7B, precisamos do nosso pacote mlstm_kernels (TODO Add Github Link), que fornece kernels rápidos para o XLSTM.
Esta seção explica como usar os modelos do papel XLSTM.
Para aplicativos não idiomas ou para integrar em outras arquiteturas, você pode usar o xLSTMBlockStack e para modelagem de idiomas ou outros aplicativos baseados em token, você pode usar o xLSTMLMModel .
O xLSTMBLockStack é destinado a ser usado como backbone alternativo em projetos existentes. É semelhante a uma pilha de blocos de transformadores, mas usa blocos XLSTM:
import torch
from xlstm import (
xLSTMBlockStack ,
xLSTMBlockStackConfig ,
mLSTMBlockConfig ,
mLSTMLayerConfig ,
sLSTMBlockConfig ,
sLSTMLayerConfig ,
FeedForwardConfig ,
)
cfg = xLSTMBlockStackConfig (
mlstm_block = mLSTMBlockConfig (
mlstm = mLSTMLayerConfig (
conv1d_kernel_size = 4 , qkv_proj_blocksize = 4 , num_heads = 4
)
),
slstm_block = sLSTMBlockConfig (
slstm = sLSTMLayerConfig (
backend = "cuda" ,
num_heads = 4 ,
conv1d_kernel_size = 4 ,
bias_init = "powerlaw_blockdependent" ,
),
feedforward = FeedForwardConfig ( proj_factor = 1.3 , act_fn = "gelu" ),
),
context_length = 256 ,
num_blocks = 7 ,
embedding_dim = 128 ,
slstm_at = [ 1 ],
)
xlstm_stack = xLSTMBlockStack ( cfg )
x = torch . randn ( 4 , 256 , 128 ). to ( "cuda" )
xlstm_stack = xlstm_stack . to ( "cuda" )
y = xlstm_stack ( x )
y . shape == x . shapeSe você estiver trabalhando com strings / arquivos da YAML para configuração, também poderá usar o Dacite para criar o Dataclasses Config. É o mesmo que o trecho acima:
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
from xlstm import xLSTMBlockStack , xLSTMBlockStackConfig
xlstm_cfg = """
mlstm_block:
mlstm:
conv1d_kernel_size: 4
qkv_proj_blocksize: 4
num_heads: 4
slstm_block:
slstm:
backend: cuda
num_heads: 4
conv1d_kernel_size: 4
bias_init: powerlaw_blockdependent
feedforward:
proj_factor: 1.3
act_fn: gelu
context_length: 256
num_blocks: 7
embedding_dim: 128
slstm_at: [1]
"""
cfg = OmegaConf . create ( xlstm_cfg )
cfg = from_dict ( data_class = xLSTMBlockStackConfig , data = OmegaConf . to_container ( cfg ), config = DaciteConfig ( strict = True ))
xlstm_stack = xLSTMBlockStack ( cfg )
x = torch . randn ( 4 , 256 , 128 ). to ( "cuda" )
xlstm_stack = xlstm_stack . to ( "cuda" )
y = xlstm_stack ( x )
y . shape == x . shape O xLSTMLMModel é um invólucro ao redor do xLSTMBlockStack que adiciona a incorporação de token e a cabeça LM.
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
from xlstm import xLSTMLMModel , xLSTMLMModelConfig
xlstm_cfg = """
vocab_size: 50304
mlstm_block:
mlstm:
conv1d_kernel_size: 4
qkv_proj_blocksize: 4
num_heads: 4
slstm_block:
slstm:
backend: cuda
num_heads: 4
conv1d_kernel_size: 4
bias_init: powerlaw_blockdependent
feedforward:
proj_factor: 1.3
act_fn: gelu
context_length: 256
num_blocks: 7
embedding_dim: 128
slstm_at: [1]
"""
cfg = OmegaConf . create ( xlstm_cfg )
cfg = from_dict ( data_class = xLSTMLMModelConfig , data = OmegaConf . to_container ( cfg ), config = DaciteConfig ( strict = True ))
xlstm_stack = xLSTMLMModel ( cfg )
x = torch . randint ( 0 , 50304 , size = ( 4 , 256 )). to ( "cuda" )
xlstm_stack = xlstm_stack . to ( "cuda" )
y = xlstm_stack ( x )
y . shape [ 1 :] == ( 256 , 50304 )Os experimentos sintéticos que mostram os benefícios do SLSTM sobre o MLSTM e o vice-versa Best são a tarefa de paridade e a tarefa de recall associativa multi-query. A tarefa de paridade só pode ser resolvida com recursos de rastreamento de estado fornecidos pela mistura de memória do SLSTM. A tarefa de recordação associativa multi-query mede os recursos de memorização, onde a expansão da memória da matriz e do estado do MLSTM é muito benéfica. Em combinação, eles se saem bem nas duas tarefas.
Para executar cada um, execute o main.py na pasta Experimentos como:
python experiments/main.py --config experiments/parity_xLSTM01.yaml # xLSTM[0:1], sLSTM only
python experiments/main.py --config experiments/parity_xLSTM10.yaml # xLSTM[1:0], mLSTM only
python experiments/main.py --config experiments/parity_xLSTM11.yaml # xLSTM[1:1], mLSTM and sLSTM
Observe que o loop de treinamento não contém parada precoce ou avaliação de teste.
Se você usar esta base de código ou encontrar nosso trabalho valioso, cite o papel XLSTM:
@inproceedings{beck:24xlstm,
title={xLSTM: Extended Long Short-Term Memory},
author={Maximilian Beck and Korbinian Pöppel and Markus Spanring and Andreas Auer and Oleksandra Prudnikova and Michael Kopp and Günter Klambauer and Johannes Brandstetter and Sepp Hochreiter},
booktitle = {Thirty-eighth Conference on Neural Information Processing Systems},
year={2024},
url={https://arxiv.org/abs/2405.04517},
}