Cache Gradient adalah teknik sederhana untuk penskalaan batch pembelajaran kontras yang tidak terbatas jauh melampaui batasan memori GPU/TPU. Ini berarti pelatihan yang digunakan untuk mengambil perangkat keras yang berat, misalnya GPU 8 V100, dapat dilakukan pada satu GPU. Selain itu, cache gradien memungkinkan pengguna untuk mengganti RAM besar GPU/TPU dengan sistem RAM rendah flop rendah yang jauh lebih efisien.
Repo ini memiliki implementasi generik dari cache gradien yang dijelaskan dalam kertas kami penskalaan ukuran pembelajaran kontras dalam di bawah pengaturan terbatas memori. Kerangka kerja Pytorch dan Jax didukung.
@inproceedings{gao2021scaling,
title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
author={Luyu Gao, Yunyi Zhang, Jiawei Han, Jamie Callan},
booktitle ={Proceedings of the 6th Workshop on Representation Learning for NLP},
year={2021},
}
Baru: Kami sekarang mendukung Jax dan TPU!
Cache gradien juga telah diintegrasikan ke dalam pengambilan lorong padat (DPR). Periksa toolkit GC-DPR kami.
Pertama -tama pasang backend pembelajaran mendalam yang Anda inginkan, baik Pytorch atau Jax. Untuk menginstal GradCache, klon repo ini dan jalankan PIP.
git clone https://github.com/luyug/GradCache
cd GradCache
pip install .
Untuk pengembangan,
pip install --editable .
Fungsi caching gradien diimplementasikan di kelas GradCache . Jika Anda mengembangkan proyek baru alih -alih menambal yang lama, lihat juga pendekatan fungsional kami untuk upaya mengurangi pendekatan.
Untuk pengguna Jax/Flax, lihatlah fungsi kereta sederhana di sini.
Metode __init__ kelas mendefinisikan cache dan memiliki beberapa parameter fungsional *_fn agar mudah menyesuaikan perilaku model. Atau Anda juga dapat Sub-Kelas Gradcache.
grad_cache.GradCache(
models: List[nn.Module],
chunk_sizes: Union[int, List[int]],
loss_fn: Callable[..., Tensor],
split_input_fn: Callable[[Any, int], Any] = None,
get_rep_fn: Callable[..., Tensor] = None,
fp16: bool = False,
scaler: GradScaler = None,
)
Model - Daftar model encoder yang akan diperbarui dengan cache gradien.
chunk_sizes - bilangan bulat yang menunjukkan ukuran chunk. Atau daftar bilangan bulat ukuran chunk untuk setiap model. Ini mengontrol untuk setiap model ukuran sub-batch untuk menjalankan Pass-Backward Forward dan harus ditetapkan berdasarkan memori GPU yang tersedia. Nilai yang terlalu kecil akan meninggalkan GPU yang digunakan.
Loss_fn - Fungsi kerugian yang mengambil tensor representasi angka yang sama dengan jumlah model dalam models dan jumlah argumen argumen kata kunci yang sewenang -wenang. Ini harus menghitung kerugian berdasarkan pada tensor input, dan dalam hal apa pun tidak mengubah hubungan input tensor dalam grafik autograd, yang kemudian diandalkan untuk membuat cache gradien.
split_input_fn - Fungsi opsional yang membagi input model generik menjadi potongan -potongan berdasarkan chunk_sizes yang ditentukan. Jika tidak disediakan, kelas ini akan mencoba yang terbaik untuk membagi input jenis yang didukung. Lihat fungsi split_inputs .
get_rep_fn - Fungsi opsional yang mengambil output model generik dan pengembalian tensor representasi. Jika tidak disediakan, output generik diasumsikan sebagai tensor representasi.
FP16 - Jika benar, jalankan pelatihan presisi campuran, yang mengharuskan scaler juga diatur.
Scaler - Objek GradScaler untuk pelatihan presisi campuran otomatis.
Untuk menjalankan langkah komputoin gradien yang di -cache, hubungi fungsi cache_step ,
cache_step(
*model_inputs,
no_sync_except_last: bool = False,
**loss_kwargs
)
Jalankan langkah cache gradien tunggal. Setelah fungsi kembali, pembaruan dihitung untuk setiap model dalam self.models Model dengan gradien yang dihuni pada bobot, seolah -olah model_inputs dijalankan sebagai batch tunggal besar pada perangkat keras yang cukup besar. Memanggil objek lulusan dengan __call__ juga akan memohon fungsi ini.
Model_inputs - Daftar input untuk setiap model encoder. Harus dalam urutan yang sama dengan self.models .
no_sync_except_last -jika benar, di bawah pengaturan terdistribusi, untuk setiap model, hanya memicu pengurangan gradien di seluruh proses untuk pass-backward forward-backward terakhir. Ini bisa berguna ketika berhadapan dengan a) model besar, dan/atau b) jumlah sub-batch yang tidak sepele.
Loss_kwargs - Argumen kata kunci tambahan untuk fungsi kerugian loss_fn . Ini dimaksudkan untuk mengaktifkan perhitungan kerugian yang fleksibel (berkat grafik dinamis di pytorch) seperti reduksi, pembobotan, dll. Potensi, menggunakan loss_kwargs Anda dapat memasukkan output dari model encoder yang tidak dilacak oleh cache.
Return - Loss, Tensor Scaler Loss Loss Langkah Saat Ini (terlepas dari grafik).
model(x)model(*x)model(**x)model(*x[0], **x[1])Input generik lainnya tidak sepenuhnya didukung, kami melakukan panggilan model menggunakan heuristik berikut,
model(*x)model(**x)model(*x[0], **x[1]) Untuk menjalankannya, split_input_fn harus ditentukan selama inisialisasi cache untuk memecah input ini menjadi batch yang lebih kecil. Dalam beberapa kasus yang jarang, Anda mungkin juga perlu mengganti get_input_tensors ketika heuristiknya tidak dapat mengambil cukup tensor yang mencakup semua perangkat CUDA yang menyimpan beberapa tensor dalam input.
Katakanlah kami ingin mempelajari ruang embedding label dan teks. Pertimbangkan empat pasangan berikut. (Dalam praktiknya, Anda akan memiliki lebih banyak entri teks dan lebih lama.)
labels = ['fruit', 'meat', 'school', 'company']
texts = [
'this is an apple',
'steak should be cooked medium rare',
'cmu is pittsburgh',
'apple sells laptop'
]
Inisialisasi model encoder kami,
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder1 = AutoModel.from_pretrained("bert-base-uncased").cuda()
encoder2 = AutoModel.from_pretrained("bert-base-uncased").cuda()
Inisialisasi objek lulusan,
from grad_cache import GradCache
from grad_cache.loss import SimpleContrastiveLoss
loss_fn = SimpleContrastiveLoss()
gc = GradCache(
models=[encoder1, encoder2],
chunk_sizes=2,
loss_fn=loss_fn,
get_rep_fn=lambda v: v.pooler_output
)
Di sini kita menggunakan argumen get_rep_fn untuk menentukan fungsi yang mengambil output model pelukan generik dan mengembalikan tensor representasi aktual.
Buat input model,
xx = tokenizer(tt, return_tensors='pt', padding=True)
yy = tokenizer(tt2, return_tensors='pt', padding=True)
Jalankan langkah cache,
gc(xx, yy, reduction='mean')
Di sini kita menggunakan reduction='mean' sebagai loss_kwargs untuk mengendalikan perilaku kehilangan. Dengan optimizer yang ditentukan, pembaruan gradien penuh dapat dilakukan sebagai,
optimizer.zero_grad()
gc(xx, yy, reduction='mean')
optimizer.step()
Ini ditangani secara alami oleh grafik dinamis (keajaiban). Anda meneruskan salinan dangkal dari model enkoder yang sama ke metode lulusan lulusan.
tied_encoder = AutoModel.from_pretrained("bert-base-uncased").cuda()
gc = GradCache(
models=[tied_encoder , tied_encoder],
chunk_sizes=2,
loss_fn=loss_fn,
get_rep_fn=lambda v: v.pooler_output
)
Di bawah kap, kait yang berbeda akan terdaftar untuk membuat perhitungan gradien yang benar.
Kami mengharapkan komunikasi lintas proses representasi ditangani oleh loss_fn .
from grad_cache.loss import DistributedContrastiveLoss
loss_fn_dist = DistributedContrastiveLoss()
Bungkus dengan benar model enkoder untuk pengurangan gradien,
encoder1_ddp = DistributedDataParallel(
encoder1, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
encoder2_ddp = DistributedDataParallel(
encoder2, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
Anda dapat menginisialisasi cache menggunakan kerugian terdistribusi dan model DDP,
gc = GradCache(
models=[encoder1_ddp, encoder2_ddp],
chunk_sizes=2,
loss_fn=loss_fn_dist,
get_rep_fn=lambda v: v.pooler_output
)
Jalankan langkah cache,
gc(xx, yy, no_sync_except_last=True, reduction='mean')
Setel no_sync_except_last=True untuk menghindari pengurangan gradien yang tidak perlu.
Jika Anda mengembangkan proyek baru, kami sarankan juga memeriksa dekorator yang telah kami berikan untuk membuat fungsi pesanan lebih tinggi untuk cache.
grad_cache.functional.cached(func: Callable[..., Tensor])
Dekorator yang mengambil fungsi panggilan model ke dalam versi yang kompatibel dengan cache.
FUNC - Fungsi yang memanggil model dan mengembalikan Tensor Representasi.
Return - Fungsi yang mengembalikan 1) Representasi tensor daun untuk konstruksi cache, 2) Fungsi penutupan untuk ke depan ke -2 dan yang di -cache ke belakang. Hubungi 2) dengan 1) sebagai argumen setelah menelepon ke belakang pada Tensor kerugian.
grad_cache.functional.cat_input_tensor(func: Callable[..., Tensor])
Dekorator yang menggabungkan argumen posisi dan kata kunci dari daftar tipe [tensor] menjadi satu tensor pada dimensi ke -0. Ini bisa berguna berurusan dengan hasil tensor representasi dari beberapa cache ke depan.
func - fungsi kerugian
Return - Fungsi kerugian yang dihiasi untuk hasil yang di -cache.
grad_cache.functional.gather_input_tensor(func: Callable[..., Tensor], axis=0)
Dekorator yang merupakan argumen posisi dan kata kunci dari semua tuan dan menggabungkannya pada poros. Dimaksudkan untuk digunakan untuk menciptakan kehilangan pembelajaran kontras yang terdistribusi.
func - fungsi kerugian
Return - Fungsi kerugian yang dihiasi untuk pelatihan terdistribusi.
Dekorator fungsional sangat berguna jika pemuat data Anda memancarkan batch kecil, dari mana Anda dapat membangun batch besar. Katakanlah Anda juga ingin melakukan presisi campuran otomatis, pertama -tama kami mendefinisikan fungsi panggilan model dan fungsi kerugian,
from grad_cache.functional import cached, cat_input_tensor
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
@cached
@autocast()
def call_model(model, input):
return model(**input).pooler_output
@cat_input_tensor
@autocast()
def contrastive_loss(x, y):
target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
scores = torch.matmul(x, y.transpose(0, 1))
return F.cross_entropy(scores, target=target)
Katakanlah Anda memiliki loader dataloader yang memancarkan kumpulan kecil tuple (xx, yy) dengan ukuran (m * n) dan Anda ingin berlatih dengan menggabungkan 16 batch kecil untuk mendapatkan batch (16m * 16n),
cache_x = []
cache_y = []
closures_x = []
closures_y = []
for step, sub_batch in enumerate(loader):
xx, yy = sub_batch
rx, cx = call_model(bert, xx)
ry, cy = call_model(bert, yy)
cache_x.append(rx)
cache_y.append(ry)
closuresx.append(cx)
closuresy.append(cy)
if (step + 1) % 16 == 0:
loss = contrastive_loss(cache_x, cache_y)
scaler.scale(loss).backward()
for f, r in zip(closuresx, cache_x):
f(r)
for f, r in zip(closuresy, cache_y):
f(r)
cache_x = []
cache_y = []
closures_x = []
closures_y = []
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
Menjalankan pelatihan multi-proses yang didistribusikan membutuhkan: 1) (semua-) Mengumpulkan representasi di seluruh perangkat dan 2) (all-reduce) gradien di seluruh perangkat. Kedua langkah akan terjadi di luar funtions yang dihiasi yang di -cache.
Yang terakhir mudah dicapai dengan membungkus encoder, misalnya bert , di DistributedDataParallel .
bert = DistributedDataParallel(
bert, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
Yang pertama membutuhkan OP yang didistribusikan ekstra dalam fungsi kerugian, yang harus dilakukan sesuai dengan definisi kerugian asli. Misalnya,
from torch import distributed as dist
from grad_cache.functional import cat_input_tensor, gather_input_tensor
@cat_input_tensor
@gather_input_tensor
@autocast()
def contrastive_loss(x, y):
target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
scores = torch.matmul(x, y.transpose(0, 1))
# scale the loss as DistributedDataParallel will do mean reduce
return F.cross_entropy(scores, target=target) * dist.get_world_size()
grad_cache/grad_cache.py - Tentukan kelas lulusan. Kode di bawah 300 baris termasuk komentar. Untuk pengembangan, kami mendorong Anda untuk membacanya.
grad_cache/functional.py - Tentukan dekorator untuk membuat fungsi urutan yang lebih tinggi untuk caching gradien dari fungsi panggilan model biasa dan fungsi kerugian.