sat ( SwissArmyTransformer ) est une bibliothèque flexible et puissante pour développer vos propres variantes de transformateur.
sat est nommé d'après "Swiss Army Knife", ce qui signifie que tous les modèles (par exemple Bert, GPT, T5, GLM, CogView, Vit ...) partagent le même code d'écran et assurent les usages polyvalents avec des mélanges de poids léger supplémentaires.
sat est alimenté par deepspeed-ZeRO et le parallélisme du modèle, visant à fournir les meilleures pratiques pour la pré-formation et les grands modèles de finetun (100m ~ 20B paramètres).
pip install SwissArmyTransformer
Ajoutez des composants autochtone modèles , par exemple le préfixe-tun, en une seule ligne!
class ClassificationModel ( GLMModel ): # can also be BertModel, RobertaModel, etc.
def __init__ ( self , args , transformer = None , ** kwargs ):
super (). __init__ ( args , transformer = transformer , ** kwargs )
self . add_mixin ( 'classification_head' , MLPHeadMixin ( args . hidden_size , 2048 , 1 ))
# Arm an arbitrary model with Prefix-tuning with this line!
self . add_mixin ( 'prefix-tuning' , PrefixTuningMixin ( args . num_layers , args . hidden_size // args . num_attention_heads , args . num_attention_heads , args . prefix_len )) model , args = AutoModel . from_pretrained ( 'glm-10b-chinese' , args )
model . add_mixin ( 'auto-regressive' , CachedAutoregressiveMixin ())
# Generate a sequence with beam search
from sat . generation . autoregressive_sampling import filling_sequence
from sat . generation . sampling_strategies import BeamSearchStrategy
output , * mems = filling_sequence ( model , input_seq ,
batch_size = args . batch_size ,
strategy = BeamSearchStrategy ( args . batch_size ))Créez votre modèle basé sur le transformateur avec un minimum de codes . Nous avons mentionné le GLM, qui ne diffère que du transformateur standard (appelé Basemodel) sur l'intégration de la position (et les pertes d'entraînement). Nous n'avons qu'à nous concentrer sur la partie associée lors du codage.
class BlockPositionEmbeddingMixin ( BaseMixin ):
# Here define parameters for the mixin
def __init__ ( self , max_sequence_length , hidden_size , init_method_std = 0.02 ):
super ( BlockPositionEmbeddingMixin , self ). __init__ ()
self . max_sequence_length = max_sequence_length
self . hidden_size = hidden_size
self . block_position_embeddings = torch . nn . Embedding ( max_sequence_length , hidden_size )
torch . nn . init . normal_ ( self . block_position_embeddings . weight , mean = 0.0 , std = init_method_std )
# Here define the method for the mixin
def position_embedding_forward ( self , position_ids , ** kwargs ):
position_ids , block_position_ids = position_ids [:, 0 ], position_ids [:, 1 ]
position_embeddings = self . transformer . position_embeddings ( position_ids )
block_position_embeddings = self . block_position_embeddings ( block_position_ids )
return position_embeddings + block_position_embeddings
class GLMModel ( BaseModel ):
def __init__ ( self , args , transformer = None ):
super (). __init__ ( args , transformer = transformer )
self . add_mixin ( 'block_position_embedding' ,
BlockPositionEmbeddingMixin ( args . max_sequence_length , args . hidden_size )
) # Add the mixin for GLM Soutien complet à la formation . sat vise à fournir les meilleures pratiques pour la pré-formation et le finetuning, où vous n'avez qu'à terminer forward_step et create_dataset_function mais avec des hyperparamètres pour modifier les configurations de formation utiles.
--num_nodes , --num_gpus et un hostfile simple.memmap , extension automatique et mélange. Le fichier Python le plus typique à utiliser Bert dans SAT (pour l'inférence) est le suivant:
# @File: inference_bert.py
from sat import get_args , get_tokenizer , AutoModel
# Parse args, initialize the environment. This is necessary.
args = get_args ()
# Automatically download and load model. Will also dump model-related hyperparameters to args.
model , args = AutoModel . from_pretrained ( 'bert-base-uncased' , args )
# Get the BertTokenizer according to args.tokenizer_type (automatically set).
tokenizer = get_tokenizer ( args )
# Here to use bert as you want!
# ...Ensuite, nous pouvons exécuter le code via
SAT_HOME=/path/to/download python inference_bert.py --mode inferenceTous les noms de modèle officiellement pris en charge sont dans URLS.py.
Finetune ou présager un transformateur est également extrêmement facile!
# @File: finetune_bert.py
from sat import get_args , get_tokenizer , AutoModel
from sat . model . mixins import MLPHeadMixin
def create_dataset_function ( path , args ):
# Here to load the dataset
# ...
assert isinstance ( dataset , torch . utils . data . Dataset )
return dataset
def forward_step ( data_iterator , model , args , timers ):
inputs = next ( data_iterator ) # from the dataset of create_dataset_function.
loss , * others = model ( inputs )
return loss
# Parse args, initialize the environment. This is necessary.
args = get_args ()
model , args = AutoModel . from_pretrained ( 'bert-base-uncased' , args )
tokenizer = get_tokenizer ( args )
# Here to use bert as you want!
model . del_mixin ( 'bert-final' )
model . add_mixin ( 'classification_head' , MLPHeadMixin ( args . hidden_size , 2048 , 1 ))
# ONE LINE to train!
# args already includes hyperparams such as lr, train-iters, zero-stage ...
training_main ( args ,
model_cls = model ,
forward_step_function = forward_step , # user define
create_dataset_function = create_dataset_function # user define
)Ensuite, nous pouvons exécuter le code via
deepspeed --include localhost:0,1 finetune_bert.py
--experiment-name ftbert
--mode finetune --train-iters 1000 --save /path/to/save
--train-data /path/to/train --valid-data /path/to/valid
--lr 0.00002 --batch-size 8 --zero-stage 1 --fp16 Ici, nous utilisons les données parallèles aux GPU 0,1. Nous pouvons également lancer la formation sur de nombreuses machines interconnectées via --hostfile /path/to/hostfile . Voir le tutoriel pour plus de détails.
Pour écrire votre propre modèle, il vous suffit de considérer la différence entre le transformateur standard. Par exemple, si vous avez une idée d'améliorer le fonctionnement de l'attention:
from sat . model import BaseMixin
class MyAttention ( BaseMixin ):
def __init__ ( self , hidden_size ):
super ( MyAttention , self ). __init__ ()
# MyAttention may needs some new params, e.g. a learnable alpha.
self . learnable_alpha = torch . nn . Parameter ( torch . ones ( hidden_size ))
# This is a hook function, the name `attention_fn` is special.
def attention_fn ( q , k , v , mask , dropout = None , ** kwargs ):
# Code for my attention.
# ...
return attention_results ICI attention_fn est une fonction de crochet, remplaçant l'action par défaut par la nouvelle fonction. Tous les crochets disponibles se trouvent dans Transformer_Default.py. Nous pouvons maintenant utiliser add_mixin pour appliquer notre changement à tous les transformateurs, tels que Bert, Vit et CogView. Voir le tutoriel pour plus de détails.
Actuellement, nous n'avons pas de papier, vous n'avez donc pas besoin de nous citer formellement! ~
Si ce projet aide votre recherche ou votre ingénierie, utilisez footnote{https://github.com/THUDM/SwissArmyTransformer} pour nous mentionner et recommander SwissArmyTransformer à d'autres.
Le tutoriel pour contribuer SAT est en route!
Le projet est basé sur (un utilisateur de) Deeppeed, Megatron-LM et HuggingFace Transformers. Merci pour leur travail formidable.