Gradient Cache ist eine einfache Technik für unbegrenzte Skalierung kontrastiver Lernstapel weit über die Einschränkung der GPU/TPU -Speicher hinaus. Dies bedeutet, dass ein Training, das früher schwere Hardware einnahm, z. B. 8 V100 GPU, an einer einzelnen GPU durchgeführt werden kann. Darüber hinaus können Benutzer mit Gradientencache Big RAM -GPU/TPU durch viel kostengünstigere hohe Flop -Systeme ersetzen.
Dieses Repo enthält eine generische Implementierung des Gradienten -Cache, das in unserem Papier -Skalierung tief kontrastive Lernstapelgröße unter Speicherlimited -Setup beschrieben wird. Sowohl Pytorch- als auch JAX -Frameworks werden unterstützt.
@inproceedings{gao2021scaling,
title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
author={Luyu Gao, Yunyi Zhang, Jiawei Han, Jamie Callan},
booktitle ={Proceedings of the 6th Workshop on Representation Learning for NLP},
year={2021},
}
NEU: Wir unterstützen jetzt Jax und TPU!
Gradienten -Cache wurde auch in den dichten Durchgang (DPR) integriert. Checkout unser GC-DPR-Toolkit.
Installieren Sie zuerst Ihr gewünschtes Deep Learning Backend, entweder Pytorch oder Jax. Um Gradcache zu installieren, klonen Sie dieses Repo und führen Sie PIP aus.
git clone https://github.com/luyug/GradCache
cd GradCache
pip install .
Für die Entwicklung,
pip install --editable .
Gradienten -Caching -Funktionen werden in der GradCache -Klasse implementiert. Wenn Sie ein neues Projekt entwickeln, anstatt einen alten zu patchen, sehen Sie auch unseren funktionalen Ansatz für einen Aufwand reduzierten Ansatz.
Schauen Sie sich für JAX/Flax -Benutzer hier eine einfache Zugfunktion an.
Die __init__ -Methode der Klasse definiert den Cache und verfügt über mehrere funktionale Parameter *_fn für die einfache Anpassung des Modellverhaltens. Alternativ können Sie auch Gradcache subklassen.
grad_cache.GradCache(
models: List[nn.Module],
chunk_sizes: Union[int, List[int]],
loss_fn: Callable[..., Tensor],
split_input_fn: Callable[[Any, int], Any] = None,
get_rep_fn: Callable[..., Tensor] = None,
fp16: bool = False,
scaler: GradScaler = None,
)
Modelle - Eine Liste von Encodermodellen, die mit dem Gradientencache aktualisiert werden sollen.
Chunk_Sizes - Eine Ganzzahl, die die Chunk -Größe angibt. Oder eine Liste von Ganzzahlen der Stücke für jedes Modell. Dies steuert für jedes Modell die Sub-Batch-Größe, um den Vorwärts-Rücklauf-Pass zu leiten, und sollte basierend auf dem verfügbaren GPU-Speicher festgelegt werden. Ein zu kleiner Wert lässt die GPU unter Verwendung genutzt.
LUST_FN - Eine Verlustfunktion, die Darstellungszahlen der Anzahl entspricht der Anzahl der Modelle in models und willkürlichen Anzahl von Keyword -Argumenten. Es sollte den Verlust basierend auf den Eingangstensoren berechnen und in keinem Fall die Beziehungen der Eingangstensoren im Autograd -Diagramm modifizieren, auf die sich später zum Erstellen des Gradientencache stützen.
Split_input_fn - Eine optionale Funktion, die die generische Modelleingabe in Stücke basierend auf definiertem chunk_gizes aufteilt. Wenn dies nicht zur Verfügung steht, wird diese Klasse ihr Bestes geben, um die Eingaben unterstützter Typen zu teilen. Siehe split_inputs -Funktion.
get_rep_fn - Eine optionale Funktion, die generische Modellausgabe und Rückgabedarstellung Tensoren übernimmt. Wenn nicht vorhanden ist, wird angenommen, dass der generische Ausgang der Darstellungszensor ist.
FP16 - Führen Sie ein gemischtes Präzisionstraining durch, wodurch der Scaler ebenfalls festgelegt werden muss.
Scaler - Ein Gradscaler -Objekt für automatische Mischpräzisionstraining.
Rufen Sie cache_step -Funktion auf, um einen zwischengespeicherten Gradienten -Computatoin -Schritt auszuführen.
cache_step(
*model_inputs,
no_sync_except_last: bool = False,
**loss_kwargs
)
Führen Sie einen einzelnen Gradienten -Cache -Schritt aus. Bei der Funktionsrendite werden Updates für jedes Modell in self.models berechnet. Modelle mit Gradienten, die auf den Gewichten besiedelt sind, als ob die model_inputs als riesige Einzelcharge auf ausreichend großer Hardware ausgeführt werden. Wenn Sie ein Gradcache -Objekt mit __call__ aufrufen, wird auch diese Funktion aufgerufen.
model_inputs - Liste der Eingänge für jedes Encoder -Modell. Sollte in ähnlicher Reihenfolge wie self.models sein.
NO_SYNC_EXTE_LAST -Wenn True unter Distributed Setup für jedes Modell nur die Gradientenreduzierung über die Prozesse für den Vorwärtsrückspass des letzten Sub-Batchs ausgelöst wird. Dies könnte sich beim Umgang mit A) großem Modell und/oder b) nicht triviale Anzahl von Untergebieten ergeben.
LUST_KWARGS - Zusätzliche Keyword -Argumente für die Verlustfunktion loss_fn . Dies soll eine flexible Verlustberechnung (dank dynamischer Graphen in Pytorch) wie Reduktion, Gewichtung usw. ermöglichen. Möglicherweise können Sie mit loss_kwargs Ausgaben aus diesen Encodermodellen einbeziehen, die nicht vom Cache verfolgt werden.
Rückgabe - Verlust, der aktuelle Tensor für den Verlust des Verlusts des Verlusts (vom Diagramm abgetrennt).
model(x) übergebenmodel(*x) übergeben.model(**x) übergeben.model(*x[0], **x[1])Andere generische Eingaben werden nicht vollständig unterstützt. Wir führen den Modellaufruf mit den folgenden Heuristiken durch.
model(*x) übergeben.model(**x)model(*x[0], **x[1]) Um mit ihnen zu laufen, sollte split_input_fn während der Cache -Initialisierung angegeben werden, um diese Eingänge in kleinere Stapel zu unterteilen. In einigen seltenen Fällen müssen Sie möglicherweise auch get_input_tensors überschreiben, wenn seine Heuristik nicht genügend Tensoren schnappen kann, die alle CUDA -Geräte abdecken, die einige Tensoren im Eingang enthalten.
Sagen Sie, wir möchten einen Einbettungsraum von Etiketten und Text lernen. Betrachten Sie die folgenden vier Paare. (In der Praxis haben Sie noch viel mehr und viel längere Texteinträge.)
labels = ['fruit', 'meat', 'school', 'company']
texts = [
'this is an apple',
'steak should be cooked medium rare',
'cmu is pittsburgh',
'apple sells laptop'
]
Initialisieren Sie unsere Encodermodelle,
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder1 = AutoModel.from_pretrained("bert-base-uncased").cuda()
encoder2 = AutoModel.from_pretrained("bert-base-uncased").cuda()
Initialisieren Sie das Gradcache -Objekt,
from grad_cache import GradCache
from grad_cache.loss import SimpleContrastiveLoss
loss_fn = SimpleContrastiveLoss()
gc = GradCache(
models=[encoder1, encoder2],
chunk_sizes=2,
loss_fn=loss_fn,
get_rep_fn=lambda v: v.pooler_output
)
Hier verwenden wir das Argument Get_Rep_fn , um eine Funktion anzugeben, die generisches Harmgingface -Modellausgang nimmt und den tatsächlichen Darstellungszensor zurückgibt.
Modelleingabe erstellen,
xx = tokenizer(tt, return_tensors='pt', padding=True)
yy = tokenizer(tt2, return_tensors='pt', padding=True)
Einen Cache -Schritt ausführen,
gc(xx, yy, reduction='mean')
Hier verwenden wir reduction='mean' als LUST_KWARGS , um das Verlustverhalten zu kontrollieren. Mit einem definierten optimizer kann das vollständige Gradienten -Update durchgeführt werden wie.
optimizer.zero_grad()
gc(xx, yy, reduction='mean')
optimizer.step()
Dies wird natürlich durch die (Magie des) dynamischen Graphen behandelt. Sie übergeben flache Kopien desselben Encodermodells an die GradCache -Init -Methode.
tied_encoder = AutoModel.from_pretrained("bert-base-uncased").cuda()
gc = GradCache(
models=[tied_encoder , tied_encoder],
chunk_sizes=2,
loss_fn=loss_fn,
get_rep_fn=lambda v: v.pooler_output
)
Unter der Motorhaube werden verschiedene Haken registriert, um die korrekte Gradientenberechnung durchzuführen.
Wir gehen davon aus, dass die Kommunikation der Cross -Prozesse von Repräsentationen vom loss_fn behandelt wird.
from grad_cache.loss import DistributedContrastiveLoss
loss_fn_dist = DistributedContrastiveLoss()
Wickeln Sie die Encodermodelle für die Gradientenreduktion richtig ein,
encoder1_ddp = DistributedDataParallel(
encoder1, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
encoder2_ddp = DistributedDataParallel(
encoder2, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
Sie können den Cache initialisieren. Verwenden Sie den verteilten Verlust und die DDP -Modelle.
gc = GradCache(
models=[encoder1_ddp, encoder2_ddp],
chunk_sizes=2,
loss_fn=loss_fn_dist,
get_rep_fn=lambda v: v.pooler_output
)
Einen Cache -Schritt ausführen,
gc(xx, yy, no_sync_except_last=True, reduction='mean')
Setzen Sie no_sync_except_last=True um eine unnötige Gradientenreduzierung zu vermeiden.
Wenn Sie ein neues Projekt entwickeln, empfehlen wir auch, die Dekorateure zu überprüfen, die wir zur Verfügung gestellt haben, um Funktionen für höhere Ordnung für Cache zu erstellen.
grad_cache.functional.cached(func: Callable[..., Tensor])
Ein Dekorateur, der eine Modellanruffunktion in eine zwischengespeicherte kompatible Version nimmt.
Func - Eine Funktion, die den Tensor des Modells und der Rückgabedarstellung aufruft.
Rückgabe - Eine Funktion, die zurückgibt 1) Darstellungsblatt -Tensoren für die Cache -Konstruktion, 2) eine Schließfunktion für den 2. vorwärts und den zwischengespeicherten Gespeicher. Rufen Sie 2) mit 1) als Argument, nachdem Sie den Verlustzensor rückwärts aufgerufen haben.
grad_cache.functional.cat_input_tensor(func: Callable[..., Tensor])
Ein Dekorateur, der Positions- und Schlüsselwortargumente des Typs "Tensor] in einen einzelnen Tensor in der 0. Dimension verkettet. Dies kann nützlich mit den Ergebnissen von Repräsentations -Tensoren aus mehreren nach vorne geänderten Vorwärts gehen.
Func - Eine Verlustfunktion
Rückkehr - Dekorierte Verlustfunktion für zwischengespeicherte Ergebnisse.
grad_cache.functional.gather_input_tensor(func: Callable[..., Tensor], axis=0)
Ein Dekorateur, der die Position und Schlüsselwort-Argumente des Tensors vom Typ All-Gather auf der Achse verkettet. Verwendet, um einen verteilten kontrastiven Lernverlust zu erzeugen.
Func - Eine Verlustfunktion
Rückkehr - Dekorierte Verlustfunktion für verteiltes Training.
Die Funktionsdekoratoren sind besonders nützlich, wenn Ihr Datenlader kleine Chargen ausgibt, aus denen Sie die große Charge konstruieren können. Angenommen, Sie möchten auch automatische gemischte Präzision durchführen. Wir definieren zunächst die Funktionsfunktion und die Verlustfunktion,
from grad_cache.functional import cached, cat_input_tensor
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
@cached
@autocast()
def call_model(model, input):
return model(**input).pooler_output
@cat_input_tensor
@autocast()
def contrastive_loss(x, y):
target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
scores = torch.matmul(x, y.transpose(0, 1))
return F.cross_entropy(scores, target=target)
Angenommen, Sie haben einen DataLoader loader , der kleine Tupelstapel (xx, yy) mit Größe (m * n) ausgibt und dass Sie trainieren möchten, indem Sie 16 kleine Chargen zusammenfassen, um eine Charge von (16m * 16n) zu erhalten.
cache_x = []
cache_y = []
closures_x = []
closures_y = []
for step, sub_batch in enumerate(loader):
xx, yy = sub_batch
rx, cx = call_model(bert, xx)
ry, cy = call_model(bert, yy)
cache_x.append(rx)
cache_y.append(ry)
closuresx.append(cx)
closuresy.append(cy)
if (step + 1) % 16 == 0:
loss = contrastive_loss(cache_x, cache_y)
scaler.scale(loss).backward()
for f, r in zip(closuresx, cache_x):
f(r)
for f, r in zip(closuresy, cache_y):
f(r)
cache_x = []
cache_y = []
closures_x = []
closures_y = []
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
Durch das Ausführen von verteiltem Multi-Process-Training müssen: 1) (All-) Darstellungen über Geräte und 2) (All-Reduce) -Abgladienten über Geräte hinweg gesammelt werden. Beide Schritte werden außerhalb der zwischengespeicherten dekorierten Funktionen stattfinden.
Letzteres ist leicht zu erreichen, indem sie Encoder, z. B. bert , in DistributedDataParallel einwickeln.
bert = DistributedDataParallel(
bert, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
Ersteres erfordert zusätzliche verteilte OPs in der Verlustfunktion, die gemäß der ursprünglichen Verlustdefinition erfolgen sollte. Zum Beispiel,
from torch import distributed as dist
from grad_cache.functional import cat_input_tensor, gather_input_tensor
@cat_input_tensor
@gather_input_tensor
@autocast()
def contrastive_loss(x, y):
target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
scores = torch.matmul(x, y.transpose(0, 1))
# scale the loss as DistributedDataParallel will do mean reduce
return F.cross_entropy(scores, target=target) * dist.get_world_size()
Grad_cache/grad_cache.py - Definieren Sie die Gradcache -Klasse. Der Code ist unter 300 Zeilen einschließlich Kommentaren. Für die Entwicklung ermutigen wir Sie, es durchzulesen.
Grad_cache/functional.py - Definieren Sie Dekoratoren, um eine Funktion höherer Ordnung für das Gradienten -Caching aus gewöhnlichen Modellaufruffunktionen und Verlustfunktionen zu erstellen.