sat ( SwissArmyTransformer ) es una biblioteca flexible y poderosa para desarrollar sus propias variantes de transformador.
sat lleva el nombre de "Swiss Army Shife", lo que significa que todos los modelos (por ejemplo, Bert, GPT, T5, GLM, COGVIEW, VIT ...) comparten el mismo código de columna vertebral y atienden a usos versátiles con algunas mixins adicionales de peso ligero.
sat funciona con deepspeed-ZeRO y un paralelismo del modelo, con el objetivo de proporcionar las mejores prácticas para el pretrete y el fino modelos grandes (parámetros de 100 m ~ 20b).
pip install SwissArmyTransformer
Agregue componentes del modelo-agnóstico , por ejemplo, ajuste de prefijo, ¡en una sola línea!
class ClassificationModel ( GLMModel ): # can also be BertModel, RobertaModel, etc.
def __init__ ( self , args , transformer = None , ** kwargs ):
super (). __init__ ( args , transformer = transformer , ** kwargs )
self . add_mixin ( 'classification_head' , MLPHeadMixin ( args . hidden_size , 2048 , 1 ))
# Arm an arbitrary model with Prefix-tuning with this line!
self . add_mixin ( 'prefix-tuning' , PrefixTuningMixin ( args . num_layers , args . hidden_size // args . num_attention_heads , args . num_attention_heads , args . prefix_len )) model , args = AutoModel . from_pretrained ( 'glm-10b-chinese' , args )
model . add_mixin ( 'auto-regressive' , CachedAutoregressiveMixin ())
# Generate a sequence with beam search
from sat . generation . autoregressive_sampling import filling_sequence
from sat . generation . sampling_strategies import BeamSearchStrategy
output , * mems = filling_sequence ( model , input_seq ,
batch_size = args . batch_size ,
strategy = BeamSearchStrategy ( args . batch_size ))Construya su modelo basado en transformadores con códigos mínimos . Mencionamos GLM, que solo difiere del transformador estándar (llamado BaseModel) en la incrustación de posición (y pérdidas de entrenamiento). Solo necesitamos centrarnos en la parte relacionada al codificar.
class BlockPositionEmbeddingMixin ( BaseMixin ):
# Here define parameters for the mixin
def __init__ ( self , max_sequence_length , hidden_size , init_method_std = 0.02 ):
super ( BlockPositionEmbeddingMixin , self ). __init__ ()
self . max_sequence_length = max_sequence_length
self . hidden_size = hidden_size
self . block_position_embeddings = torch . nn . Embedding ( max_sequence_length , hidden_size )
torch . nn . init . normal_ ( self . block_position_embeddings . weight , mean = 0.0 , std = init_method_std )
# Here define the method for the mixin
def position_embedding_forward ( self , position_ids , ** kwargs ):
position_ids , block_position_ids = position_ids [:, 0 ], position_ids [:, 1 ]
position_embeddings = self . transformer . position_embeddings ( position_ids )
block_position_embeddings = self . block_position_embeddings ( block_position_ids )
return position_embeddings + block_position_embeddings
class GLMModel ( BaseModel ):
def __init__ ( self , args , transformer = None ):
super (). __init__ ( args , transformer = transformer )
self . add_mixin ( 'block_position_embedding' ,
BlockPositionEmbeddingMixin ( args . max_sequence_length , args . hidden_size )
) # Add the mixin for GLM Apoyos integrales para la capacitación . sat tiene como objetivo proporcionar la mejor práctica para el pretratamiento y la Finetuning, donde solo necesita finalizar forward_step y create_dataset_function pero con hiperparámetros para alterar configuraciones de entrenamiento útiles.
--num_nodes , --num_gpus y un hostfile simple.memmap . El archivo de Python más típico para usar Bert en SAT (por inferencia) es el siguiente:
# @File: inference_bert.py
from sat import get_args , get_tokenizer , AutoModel
# Parse args, initialize the environment. This is necessary.
args = get_args ()
# Automatically download and load model. Will also dump model-related hyperparameters to args.
model , args = AutoModel . from_pretrained ( 'bert-base-uncased' , args )
# Get the BertTokenizer according to args.tokenizer_type (automatically set).
tokenizer = get_tokenizer ( args )
# Here to use bert as you want!
# ...Entonces podemos ejecutar el código a través de
SAT_HOME=/path/to/download python inference_bert.py --mode inferenceTodos los nombres de modelos oficialmente compatibles están en URLS.py.
¡Finetune o Pretrate, un transformador también es extremadamente fácil!
# @File: finetune_bert.py
from sat import get_args , get_tokenizer , AutoModel
from sat . model . mixins import MLPHeadMixin
def create_dataset_function ( path , args ):
# Here to load the dataset
# ...
assert isinstance ( dataset , torch . utils . data . Dataset )
return dataset
def forward_step ( data_iterator , model , args , timers ):
inputs = next ( data_iterator ) # from the dataset of create_dataset_function.
loss , * others = model ( inputs )
return loss
# Parse args, initialize the environment. This is necessary.
args = get_args ()
model , args = AutoModel . from_pretrained ( 'bert-base-uncased' , args )
tokenizer = get_tokenizer ( args )
# Here to use bert as you want!
model . del_mixin ( 'bert-final' )
model . add_mixin ( 'classification_head' , MLPHeadMixin ( args . hidden_size , 2048 , 1 ))
# ONE LINE to train!
# args already includes hyperparams such as lr, train-iters, zero-stage ...
training_main ( args ,
model_cls = model ,
forward_step_function = forward_step , # user define
create_dataset_function = create_dataset_function # user define
)Entonces podemos ejecutar el código a través de
deepspeed --include localhost:0,1 finetune_bert.py
--experiment-name ftbert
--mode finetune --train-iters 1000 --save /path/to/save
--train-data /path/to/train --valid-data /path/to/valid
--lr 0.00002 --batch-size 8 --zero-stage 1 --fp16 Aquí usamos datos-paralelo en las GPU 0,1. También podemos lanzar el entrenamiento en muchas máquinas interconectadas a través de --hostfile /path/to/hostfile . Vea el tutorial para más detalles.
Para escribir su propio modelo, solo debe considerar la diferencia entre el transformador estándar. Por ejemplo, si tiene una idea para mejorar la operación de atención:
from sat . model import BaseMixin
class MyAttention ( BaseMixin ):
def __init__ ( self , hidden_size ):
super ( MyAttention , self ). __init__ ()
# MyAttention may needs some new params, e.g. a learnable alpha.
self . learnable_alpha = torch . nn . Parameter ( torch . ones ( hidden_size ))
# This is a hook function, the name `attention_fn` is special.
def attention_fn ( q , k , v , mask , dropout = None , ** kwargs ):
# Code for my attention.
# ...
return attention_results Aquí attention_fn es una función de gancho, reemplazando la acción predeterminada por la nueva función. Todos los ganchos disponibles están en transformer_defaults.py. Ahora podemos usar add_mixin para aplicar nuestro cambio a todos los transformadores, como Bert, Vit y Cogview. Vea el tutorial para más detalles.
Actualmente no tenemos un papel, ¡así que no es necesario citarnos formalmente! ~
Si este proyecto ayuda a su investigación o ingeniería, use footnote{https://github.com/THUDM/SwissArmyTransformer} para mencionarnos y recomendar SwissArmyTransformer a otros.
¡El tutorial para contribuir SAT está en camino!
El proyecto se basa en (un usuario de) Deepeed, Megatron-LM y Huggingface Transformers. Gracias por su increíble trabajo.