Le cache de gradient est une technique simple pour la mise à l'échelle d'un lot d'apprentissage contrasté illimité bien au-delà de la contrainte de mémoire GPU / TPU. Cela signifie que la formation qui prenait du matériel lourd, par exemple GPU V100, peut être effectuée sur un seul GPU. De plus, le cache de gradient permet aux utilisateurs de remplacer le grand GPU / TPU RAM avec des systèmes de RAM faibles à haut débit beaucoup plus rentables.
Ce repo contient une implémentation générique du cache de gradient décrit dans notre échelle de papier à l'échelle de la taille du lot d'apprentissage en profondeur sous la configuration limitée de mémoire. Les cadres Pytorch et Jax sont pris en charge.
@inproceedings{gao2021scaling,
title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
author={Luyu Gao, Yunyi Zhang, Jiawei Han, Jamie Callan},
booktitle ={Proceedings of the 6th Workshop on Representation Learning for NLP},
year={2021},
}
NOUVEAU: Nous soutenons maintenant Jax et TPU!
Le cache de gradient a également été intégré à la récupération de passage dense (DPR). Découvrez notre boîte à outils GC-DPR.
Installez d'abord le backend en profondeur souhaité, soit Pytorch ou Jax. Pour installer GradCache, clone ce repo et exécuter PIP.
git clone https://github.com/luyug/GradCache
cd GradCache
pip install .
Pour le développement,
pip install --editable .
Les fonctionnalités de mise en cache de gradient sont implémentées en classe GradCache . Si vous développez un nouveau projet au lieu de corriger un ancien, consultez également notre approche fonctionnelle pour une approche réduite en effort.
Pour l'utilisateur de Jax / Flax, jetez un œil à une fonction de train simple ici.
La méthode __init__ de la classe définit le cache et a plusieurs paramètres fonctionnels *_fn pour un ajustement facile des comportements du modèle. Alternativement, vous pouvez également sous-classe GradCache.
grad_cache.GradCache(
models: List[nn.Module],
chunk_sizes: Union[int, List[int]],
loss_fn: Callable[..., Tensor],
split_input_fn: Callable[[Any, int], Any] = None,
get_rep_fn: Callable[..., Tensor] = None,
fp16: bool = False,
scaler: GradScaler = None,
)
Modèles - Une liste de modèles d'encodeur à mettre à jour avec le cache de gradient.
Chunk_sizes - un entier indiquant la taille du morceau. Ou une liste d'entiers de taille de morceau pour chaque modèle. Cet contrôle pour chaque modèle de la taille de sous-lots pour exécuter le passage vers l'avant et doit être défini en fonction de la mémoire GPU disponible. Une valeur trop petite laissera le GPU sous utilisation.
Loss_fn - une fonction de perte qui prend les tenseurs de représentation du nombre égal au nombre de modèles dans models et les nombres arbitraires d'arguments de mots clés. Il doit calculer la perte en fonction des tenseurs d'entrée et en aucun cas, modifiez les relations des tenseurs d'entrée dans le graphique Autograd, qui sont ultérieurement invoquées pour créer le cache de gradient.
Split_input_fn - Une fonction facultative qui divise l'entrée du modèle générique en morceaux basés sur Chunk_sizes définis. S'il n'est pas fourni, cette classe fera de son mieux pour diviser les entrées des types pris en charge. Voir Fonction split_inputs .
get_rep_fn - Une fonction facultative qui prend les tenseurs de sortie de modèle générique et le retour. S'il n'est pas fourni, la sortie générique est supposée être le tenseur de représentation.
FP16 - Si vrai, exécutez une formation de précision mixte, qui nécessite également que SCAUR soit également défini.
SCALER - Un objet GradScaleur pour une formation automatique de précision mixte.
Pour exécuter une étape de compromis de gradient caché, appelez la fonction cache_step ,
cache_step(
*model_inputs,
no_sync_except_last: bool = False,
**loss_kwargs
)
Exécutez une étape de cache de gradient unique. Lors du retour de la fonction, les mises à jour sont calculées pour chaque modèle de self.models avec gradient peuplé sur les poids, comme si les model_inputs sont exécutés comme un énorme lot unique sur du matériel suffisamment grand. Appeler un objet GradCache avec __call__ invoque également cette fonction.
Model_Inputs - Liste des entrées de chaque modèle d'encodeur. Devrait être dans un ordre similaire à self.models .
NO_SYNC_EXT_LAST - Si vrai, sous la configuration distribuée, pour chaque modèle, déclenchez uniquement la réduction du gradient à l'autre à travers les processus pour la passe avant-arrière du dernier sous-lot. Cela pourrait être utile lorsqu'il s'agit de a) grand modèle, et / ou b) un nombre non trivial de sous-dossiers.
Loss_kwargs - Arguments de mots clés supplémentaires à la fonction de perte loss_fn . Ceci est destiné à permettre un calcul de perte flexible (grâce au graphique dynamique dans Pytorch) tels que la réduction, la pondération, etc. Potentiellement, en utilisant loss_kwargs , vous pouvez incorporer les sorties de ces modèles d'encodeur non suivis par le cache.
Retour - Perte, le tenseur de balayage de perte d'étapes de courant (détaché du graphique).
model(x)model(*x)model(**x)model(*x[0], **x[1])D'autres entrées génériques ne sont pas entièrement prises en charge, nous effectuons un modèle de modèle en utilisant l'heuristique suivante,
model(*x)model(**x)model(*x[0], **x[1]) Pour s'exécuter avec eux, split_input_fn doit être spécifié lors de l'initialisation du cache pour diviser ces entrées en lots plus petits. Dans certains cas rares, vous devrez peut-être également remplacer get_input_tensors lorsque son heuristique ne peut pas saisir suffisamment de tenseurs qui couvrent tous les appareils CUDA qui contiennent certains tenseurs dans l'entrée.
Disons que nous voulons apprendre un espace d'incorporation d'étiquettes et de texte. Considérez les quatre paires suivantes. (En pratique, vous aurez des entrées de texte de bien plus et de bien plus longues.)
labels = ['fruit', 'meat', 'school', 'company']
texts = [
'this is an apple',
'steak should be cooked medium rare',
'cmu is pittsburgh',
'apple sells laptop'
]
Initialiser nos modèles d'encodeur,
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder1 = AutoModel.from_pretrained("bert-base-uncased").cuda()
encoder2 = AutoModel.from_pretrained("bert-base-uncased").cuda()
Initialiser l'objet GradCache,
from grad_cache import GradCache
from grad_cache.loss import SimpleContrastiveLoss
loss_fn = SimpleContrastiveLoss()
gc = GradCache(
models=[encoder1, encoder2],
chunk_sizes=2,
loss_fn=loss_fn,
get_rep_fn=lambda v: v.pooler_output
)
Ici, nous utilisons l'argument GET_REP_FN pour spécifier une fonction qui prend la sortie du modèle Generic HuggingFace et renvoie le tenseur de représentation réel.
Créer une entrée de modèle,
xx = tokenizer(tt, return_tensors='pt', padding=True)
yy = tokenizer(tt2, return_tensors='pt', padding=True)
Exécutez une étape de cache,
gc(xx, yy, reduction='mean')
Ici, nous utilisons reduction='mean' comme perte_kwargs pour contrôler le comportement de perte. Avec un optimizer défini, la mise à jour complète du gradient peut être effectuée comme,
optimizer.zero_grad()
gc(xx, yy, reduction='mean')
optimizer.step()
Ceci est naturellement géré par le graphique dynamique (magique de). Vous passez des copies peu profondes du même modèle d'encodeur à la méthode GradCache init.
tied_encoder = AutoModel.from_pretrained("bert-base-uncased").cuda()
gc = GradCache(
models=[tied_encoder , tied_encoder],
chunk_sizes=2,
loss_fn=loss_fn,
get_rep_fn=lambda v: v.pooler_output
)
Sous le capot, des crochets distincts seront enregistrés pour faire un calcul de gradient correct.
Nous nous attendons à ce que la communication transversale des représentations soit gérée par la loss_fn .
from grad_cache.loss import DistributedContrastiveLoss
loss_fn_dist = DistributedContrastiveLoss()
Enveloppez correctement les modèles d'encodeur pour la réduction du gradient,
encoder1_ddp = DistributedDataParallel(
encoder1, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
encoder2_ddp = DistributedDataParallel(
encoder2, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
Vous pouvez initialiser le cache à utiliser la perte distribuée et les modèles DDP,
gc = GradCache(
models=[encoder1_ddp, encoder2_ddp],
chunk_sizes=2,
loss_fn=loss_fn_dist,
get_rep_fn=lambda v: v.pooler_output
)
Exécutez une étape de cache,
gc(xx, yy, no_sync_except_last=True, reduction='mean')
Définissez no_sync_except_last=True pour éviter la réduction du gradient inutile.
Si vous développez un nouveau projet, nous vous recommandons également de consulter les décorateurs que nous avons fournis pour créer des fonctions d'ordre supérieur pour le cache.
grad_cache.functional.cached(func: Callable[..., Tensor])
Un décorateur qui prend une fonction d'appel modèle dans une version compatible en cache.
Func - Une fonction qui appelle le modèle de représentation du modèle et de retour.
Retour - Une fonction qui retourne 1) Tenseurs de feuilles de représentation pour la construction du cache, 2) une fonction de fermeture pour le 2ème avant et la mise en cache vers l'arrière. Appelez 2) avec 1) comme argument après avoir appelé en arrière sur le tenseur de perte.
grad_cache.functional.cat_input_tensor(func: Callable[..., Tensor])
Un décorateur qui concaténe les arguments positionnels et de mots clés de la liste des types [tenseur] en un seul tenseur sur la 0e dimension. Cela peut être utile traitant des résultats des tenseurs de représentation de plusieurs en avant en cache.
Func - une fonction de perte
Retour - Fonction de perte décorée pour les résultats mis en cache.
grad_cache.functional.gather_input_tensor(func: Callable[..., Tensor], axis=0)
Un décorateur qui a des arguments positionnels et de mots clés entiers de type tenseur et les concaténer sur l'axe. Destiné à être utilisé pour créer une perte d'apprentissage contrastée distribuée.
Func - une fonction de perte
Retour - Fonction de perte décorée pour la formation distribuée.
Les décorateurs fonctionnels sont particulièrement utiles si votre chargeur de données émet de petits lots, à partir desquels vous pouvez construire le grand lot. Disons que vous souhaitez également faire une précision mixte automatique, nous définissons d'abord la fonction d'appel du modèle et la fonction de perte,
from grad_cache.functional import cached, cat_input_tensor
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
@cached
@autocast()
def call_model(model, input):
return model(**input).pooler_output
@cat_input_tensor
@autocast()
def contrastive_loss(x, y):
target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
scores = torch.matmul(x, y.transpose(0, 1))
return F.cross_entropy(scores, target=target)
Dites que vous avez un loader de dataloader émettant de petits lots de tuple (xx, yy) de taille (m * n) et que vous voulez vous entraîner en agrégeant 16 petits lots pour obtenir un lot de (16m * 16n),
cache_x = []
cache_y = []
closures_x = []
closures_y = []
for step, sub_batch in enumerate(loader):
xx, yy = sub_batch
rx, cx = call_model(bert, xx)
ry, cy = call_model(bert, yy)
cache_x.append(rx)
cache_y.append(ry)
closuresx.append(cx)
closuresy.append(cy)
if (step + 1) % 16 == 0:
loss = contrastive_loss(cache_x, cache_y)
scaler.scale(loss).backward()
for f, r in zip(closuresx, cache_x):
f(r)
for f, r in zip(closuresy, cache_y):
f(r)
cache_x = []
cache_y = []
closures_x = []
closures_y = []
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
L'exécution de la formation multi-traitement distribuée nécessite: 1) (toutes) rassemblent des représentations sur les appareils et 2) les gradients (All-Reduce) sur les appareils. Les deux étapes se produiront à l'extérieur des fonctions décorées en cache.
Ce dernier est facile à réaliser en emballage les encodeurs, par exemple un bert , dans DistributedDataParallel .
bert = DistributedDataParallel(
bert, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
Le premier nécessite des OPS distribués supplémentaires dans la fonction de perte, ce qui doit être effectué en fonction de la définition de perte d'origine. Par exemple,
from torch import distributed as dist
from grad_cache.functional import cat_input_tensor, gather_input_tensor
@cat_input_tensor
@gather_input_tensor
@autocast()
def contrastive_loss(x, y):
target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
scores = torch.matmul(x, y.transpose(0, 1))
# scale the loss as DistributedDataParallel will do mean reduce
return F.cross_entropy(scores, target=target) * dist.get_world_size()
grad_cache / grad_cache.py - Définissez la classe GradCache. Le code est inférieur à 300 lignes, y compris les commentaires. Pour le développement, nous vous encourageons à le lire.
grad_cache / functional.py - Définissez les décorateurs pour créer une fonction d'ordre supérieur pour la mise en cache de gradient à partir des fonctions d'appel de modèle ordinaires et des fonctions de perte.