sat ( SwissArmyTransformer ) adalah perpustakaan yang fleksibel dan kuat untuk mengembangkan varian transformator Anda sendiri.
sat dinamai "Pisau Angkatan Darat Swiss", yang berarti bahwa semua model (misalnya Bert, GPT, T5, GLM, Cogview, Vit ...) berbagi kode tulang punggung yang sama dan memenuhi penggunaan serbaguna dengan beberapa mixin ringan ekstra.
sat didukung oleh paralelisme deepspeed-ZeRO dan model, yang bertujuan untuk memberikan praktik terbaik untuk pretraining dan finetuning model besar (parameter 100m ~ 20b).
pip install SwissArmyTransformer
Tambahkan komponen model-agnostik , misalnya tuning awalan, hanya dalam satu baris!
class ClassificationModel ( GLMModel ): # can also be BertModel, RobertaModel, etc.
def __init__ ( self , args , transformer = None , ** kwargs ):
super (). __init__ ( args , transformer = transformer , ** kwargs )
self . add_mixin ( 'classification_head' , MLPHeadMixin ( args . hidden_size , 2048 , 1 ))
# Arm an arbitrary model with Prefix-tuning with this line!
self . add_mixin ( 'prefix-tuning' , PrefixTuningMixin ( args . num_layers , args . hidden_size // args . num_attention_heads , args . num_attention_heads , args . prefix_len )) model , args = AutoModel . from_pretrained ( 'glm-10b-chinese' , args )
model . add_mixin ( 'auto-regressive' , CachedAutoregressiveMixin ())
# Generate a sequence with beam search
from sat . generation . autoregressive_sampling import filling_sequence
from sat . generation . sampling_strategies import BeamSearchStrategy
output , * mems = filling_sequence ( model , input_seq ,
batch_size = args . batch_size ,
strategy = BeamSearchStrategy ( args . batch_size ))Bangun model berbasis transformator Anda dengan kode minimal . Kami menyebutkan GLM, yang hanya berbeda dari transformator standar (disebut basemodel) pada penyematan posisi (dan kerugian pelatihan). Kita hanya perlu fokus pada bagian terkait saat pengkodean.
class BlockPositionEmbeddingMixin ( BaseMixin ):
# Here define parameters for the mixin
def __init__ ( self , max_sequence_length , hidden_size , init_method_std = 0.02 ):
super ( BlockPositionEmbeddingMixin , self ). __init__ ()
self . max_sequence_length = max_sequence_length
self . hidden_size = hidden_size
self . block_position_embeddings = torch . nn . Embedding ( max_sequence_length , hidden_size )
torch . nn . init . normal_ ( self . block_position_embeddings . weight , mean = 0.0 , std = init_method_std )
# Here define the method for the mixin
def position_embedding_forward ( self , position_ids , ** kwargs ):
position_ids , block_position_ids = position_ids [:, 0 ], position_ids [:, 1 ]
position_embeddings = self . transformer . position_embeddings ( position_ids )
block_position_embeddings = self . block_position_embeddings ( block_position_ids )
return position_embeddings + block_position_embeddings
class GLMModel ( BaseModel ):
def __init__ ( self , args , transformer = None ):
super (). __init__ ( args , transformer = transformer )
self . add_mixin ( 'block_position_embedding' ,
BlockPositionEmbeddingMixin ( args . max_sequence_length , args . hidden_size )
) # Add the mixin for GLM Dukungan komprehensif untuk pelatihan . sat bertujuan untuk memberikan praktik terbaik untuk pretraining dan finetuning, di mana Anda hanya perlu menyelesaikan forward_step dan create_dataset_function tetapi dengan hyperparameters untuk mengubah konfigurasi pelatihan yang bermanfaat.
--num_nodes , --num_gpus dan hostfile sederhana.memmap . File Python yang paling khas untuk menggunakan Bert di SAT (untuk inferensi) adalah sebagai berikut:
# @File: inference_bert.py
from sat import get_args , get_tokenizer , AutoModel
# Parse args, initialize the environment. This is necessary.
args = get_args ()
# Automatically download and load model. Will also dump model-related hyperparameters to args.
model , args = AutoModel . from_pretrained ( 'bert-base-uncased' , args )
# Get the BertTokenizer according to args.tokenizer_type (automatically set).
tokenizer = get_tokenizer ( args )
# Here to use bert as you want!
# ...Kemudian kita dapat menjalankan kode melalui
SAT_HOME=/path/to/download python inference_bert.py --mode inferenceSemua nama model yang didukung secara resmi ada di urls.py.
Untuk finetune atau pretrain transformator juga sangat mudah!
# @File: finetune_bert.py
from sat import get_args , get_tokenizer , AutoModel
from sat . model . mixins import MLPHeadMixin
def create_dataset_function ( path , args ):
# Here to load the dataset
# ...
assert isinstance ( dataset , torch . utils . data . Dataset )
return dataset
def forward_step ( data_iterator , model , args , timers ):
inputs = next ( data_iterator ) # from the dataset of create_dataset_function.
loss , * others = model ( inputs )
return loss
# Parse args, initialize the environment. This is necessary.
args = get_args ()
model , args = AutoModel . from_pretrained ( 'bert-base-uncased' , args )
tokenizer = get_tokenizer ( args )
# Here to use bert as you want!
model . del_mixin ( 'bert-final' )
model . add_mixin ( 'classification_head' , MLPHeadMixin ( args . hidden_size , 2048 , 1 ))
# ONE LINE to train!
# args already includes hyperparams such as lr, train-iters, zero-stage ...
training_main ( args ,
model_cls = model ,
forward_step_function = forward_step , # user define
create_dataset_function = create_dataset_function # user define
)Kemudian kita dapat menjalankan kode melalui
deepspeed --include localhost:0,1 finetune_bert.py
--experiment-name ftbert
--mode finetune --train-iters 1000 --save /path/to/save
--train-data /path/to/train --valid-data /path/to/valid
--lr 0.00002 --batch-size 8 --zero-stage 1 --fp16 Di sini kami menggunakan data-paralel pada GPU 0,1. Kami juga dapat meluncurkan pelatihan di banyak mesin yang saling terhubung melalui --hostfile /path/to/hostfile . Lihat tutorial untuk detail lebih lanjut.
Untuk menulis model Anda sendiri, Anda hanya perlu mempertimbangkan perbedaan antara transformator standar. Misalnya, jika Anda memiliki ide untuk meningkatkan operasi perhatian:
from sat . model import BaseMixin
class MyAttention ( BaseMixin ):
def __init__ ( self , hidden_size ):
super ( MyAttention , self ). __init__ ()
# MyAttention may needs some new params, e.g. a learnable alpha.
self . learnable_alpha = torch . nn . Parameter ( torch . ones ( hidden_size ))
# This is a hook function, the name `attention_fn` is special.
def attention_fn ( q , k , v , mask , dropout = None , ** kwargs ):
# Code for my attention.
# ...
return attention_results Di sini attention_fn adalah fungsi kait, mengganti tindakan default oleh fungsi baru. Semua kait yang tersedia ada di transformer_defaults.py. Sekarang kita dapat menggunakan add_mixin untuk menerapkan perubahan kita pada semua transformator, seperti Bert, Vit dan Cogview. Lihat tutorial untuk detail lebih lanjut.
Saat ini kami tidak memiliki kertas, jadi Anda tidak perlu mengutip kami secara resmi! ~
Jika proyek ini membantu penelitian atau rekayasa Anda, gunakan footnote{https://github.com/THUDM/SwissArmyTransformer} untuk menyebut kami dan merekomendasikan SwissArmyTransformer kepada orang lain.
Tutorial untuk SAT yang berkontribusi sedang dalam perjalanan!
Proyek ini didasarkan pada (pengguna) Deepspeed, Megatron-LM dan Huggingface Transformers. Terima kasih atas pekerjaan luar biasa mereka.