แคชไล่ระดับสีเป็นเทคนิคที่ง่ายสำหรับการปรับขนาดการเรียนรู้แบบไม่ จำกัด ที่เกินกว่าข้อ จำกัด หน่วยความจำ GPU/TPU ซึ่งหมายถึงการฝึกอบรมที่ใช้ในการใช้ฮาร์ดแวร์หนักเช่น 8 V100 GPU สามารถทำได้ใน GPU เดียว นอกจากนี้แคชไล่ระดับสีช่วยให้ผู้ใช้สามารถแทนที่ Big Ram GPU/TPU ด้วยระบบ RAM ต่ำที่มีประสิทธิภาพสูงมาก
repo นี้มีการใช้งานทั่วไปของแคชการไล่ระดับสีที่อธิบายไว้ในการปรับขนาดกระดาษของเราขนาดการเรียนรู้ที่แตกต่างกันอย่างลึกซึ้งภายใต้การตั้งค่าหน่วยความจำ จำกัด รองรับทั้งเฟรมเวิร์ก Pytorch และ Jax
@inproceedings{gao2021scaling,
title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
author={Luyu Gao, Yunyi Zhang, Jiawei Han, Jamie Callan},
booktitle ={Proceedings of the 6th Workshop on Representation Learning for NLP},
year={2021},
}
ใหม่: ตอนนี้เราสนับสนุน Jax และ TPU!
แคชไล่ระดับสีได้ถูกรวมเข้ากับการดึงข้อมูลที่หนาแน่น (DPR) ชำระเงินชุดเครื่องมือ GC-DPR ของเรา
ก่อนอื่นติดตั้งแบ็กเอนด์การเรียนรู้ลึกที่คุณต้องการไม่ว่าจะเป็น Pytorch หรือ Jax ในการติดตั้ง gradcache ให้โคลน repo นี้และเรียกใช้ pip
git clone https://github.com/luyug/GradCache
cd GradCache
pip install .
เพื่อการพัฒนา
pip install --editable .
ฟังก์ชั่นการแคชไล่ระดับสีถูกนำมาใช้ในคลาส GradCache หากคุณกำลังพัฒนา โครงการใหม่ แทนการแก้ไขโครงการเก่าให้ชำระเงินวิธีการทำงานของเราเพื่อลดความพยายาม
สำหรับผู้ใช้ JAX/FLAX ลองดูฟังก์ชั่นรถไฟง่าย ๆ ที่นี่
วิธี __init__ ของคลาสกำหนดแคชและมีพารามิเตอร์การทำงานหลายอย่าง *_fn เพื่อการปรับพฤติกรรมแบบจำลองได้ง่าย หรือคุณยังสามารถ GradCache คลาสย่อยได้
grad_cache.GradCache(
models: List[nn.Module],
chunk_sizes: Union[int, List[int]],
loss_fn: Callable[..., Tensor],
split_input_fn: Callable[[Any, int], Any] = None,
get_rep_fn: Callable[..., Tensor] = None,
fp16: bool = False,
scaler: GradScaler = None,
)
โมเดล - รายการรุ่นเข้ารหัสที่จะอัปเดตด้วยแคชไล่ระดับสี
chunk_sizes - จำนวนเต็มที่ระบุขนาดก้อน หรือรายการจำนวนเต็มขนาดก้อนสำหรับแต่ละรุ่น การควบคุมนี้สำหรับแต่ละรุ่นขนาดชุดย่อยเพื่อเรียกใช้บัตรผ่านไปข้างหน้ากลับไปข้างหน้าและควรตั้งค่าตามหน่วยความจำ GPU ที่มีอยู่ ค่าเล็กเกินไปจะออกจาก GPU ภายใต้การใช้งาน
Loss_fn - ฟังก์ชั่นการสูญเสียที่ใช้เทนเซอร์ที่เป็นตัวแทนของจำนวนเท่ากับจำนวนโมเดลใน models และจำนวนอาร์กิวเมนต์คำหลักโดยพลการ มันควรคำนวณการสูญเสียตามเทนเซอร์อินพุตและในกรณีที่ไม่มีการแก้ไขความสัมพันธ์ของเทนเซอร์อินพุตในกราฟ Autograd ซึ่งต่อมาอาศัยการสร้างแคชไล่ระดับสี
STIFL_INPUT_FN - ฟังก์ชั่นเสริมที่แยกอินพุตโมเดลทั่วไปเป็นชิ้น ๆ ตาม chunk_sizes ที่กำหนดไว้ หากไม่ได้จัดเตรียมคลาสนี้จะพยายามอย่างเต็มที่เพื่อแยกอินพุตของประเภทที่รองรับ ดูฟังก์ชัน split_inputs
GET_REP_FN - ฟังก์ชั่นเสริมที่ใช้เอาต์พุตแบบจำลองทั่วไปและเทนเซอร์การเป็นตัวแทนกลับ หากไม่ได้ให้มาแล้วเอาต์พุตทั่วไปจะถือว่าเป็นเทนเซอร์ที่เป็นตัวแทน
FP16 - ถ้าเป็นจริงให้ใช้การฝึกอบรมแบบผสมผสานแบบผสมซึ่งต้องการให้ Scaler ถูกตั้งค่าด้วย
Scaler - วัตถุ gratscaler สำหรับการฝึกอบรมความแม่นยำแบบผสมอัตโนมัติ
ในการเรียกใช้ขั้นตอนการไล่ระดับสีแบบแคชให้เรียกใช้ฟังก์ชัน cache_step
cache_step(
*model_inputs,
no_sync_except_last: bool = False,
**loss_kwargs
)
เรียกใช้ขั้นตอนแคชการไล่ระดับสีเดียว เมื่อฟังก์ชั่นส่งคืนการอัปเดตจะถูกคำนวณสำหรับแต่ละรุ่นใน self.models ที่มีการไล่ระดับสีที่มีน้ำหนักมากขึ้นราวกับว่า model_inputs ทำงานเป็นชุดเดี่ยวขนาดใหญ่บนฮาร์ดแวร์ขนาดใหญ่เพียงพอ การเรียกวัตถุ GradCache ด้วย __call__ จะเรียกใช้ฟังก์ชันนี้ด้วย
model_inputs - รายการอินพุตไปยังแต่ละรุ่นตัวเข้ารหัส ควรอยู่ในลำดับที่คล้ายกันกับ self.models
no_sync_except_last- ถ้าเป็นจริงภายใต้การตั้งค่าแบบกระจายสำหรับแต่ละรุ่นให้ลดการไล่ระดับสีให้กับกระบวนการสำหรับการส่งต่อไปข้างหน้าของชุดย่อยล่าสุด สิ่งนี้อาจมีประโยชน์เมื่อต้องรับมือกับ a) โมเดลขนาดใหญ่และ/หรือ B) จำนวนชุดย่อยที่ไม่สำคัญ
LOSCE_KWARGS - อาร์กิวเมนต์คำหลักเพิ่มเติมสำหรับฟังก์ชั่นการสูญเสีย loss_fn สิ่งนี้มีวัตถุประสงค์เพื่อเปิดใช้งานการคำนวณการสูญเสียที่ยืดหยุ่น (ขอบคุณกราฟแบบไดนามิกใน pytorch) เช่นการลดน้ำหนัก ฯลฯ อาจใช้ loss_kwargs คุณสามารถรวมเอาต์พุตจากรุ่นเข้ารหัสที่ไม่ได้ติดตามโดยแคช
ผลตอบแทน - การสูญเสียขั้นตอนปัจจุบันการสูญเสียเทนเซอร์ (แยกออกจากกราฟ)
model(x)model(*x)model(**x)model(*x[0], **x[1])อินพุตทั่วไปอื่น ๆ ไม่ได้รับการสนับสนุนอย่างเต็มที่เราทำการโทรแบบจำลองโดยใช้ฮิวริสติกต่อไปนี้
model(*x)model(**x)model(*x[0], **x[1]) ในการทำงานกับพวกเขาควรระบุ split_input_fn ระหว่างการเริ่มต้นแคชเพื่อแบ่งอินพุตเหล่านี้ออกเป็นแบทช์ขนาดเล็ก ในบางกรณีที่หายากคุณอาจต้องแทนที่ get_input_tensors เมื่อฮิวริสติกของมันไม่สามารถคว้าเทนเซอร์ได้เพียงพอที่ครอบคลุมอุปกรณ์ CUDA ทั้งหมดที่เก็บเทนเซอร์บางตัวในอินพุต
สมมติว่าเราต้องการเรียนรู้พื้นที่ฝังตัวของป้ายกำกับและข้อความ พิจารณาสี่คู่ต่อไปนี้ (ในทางปฏิบัติคุณจะมีรายการข้อความที่ยาวขึ้นเรื่อย ๆ )
labels = ['fruit', 'meat', 'school', 'company']
texts = [
'this is an apple',
'steak should be cooked medium rare',
'cmu is pittsburgh',
'apple sells laptop'
]
เริ่มต้นรุ่นเข้ารหัสของเรา
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder1 = AutoModel.from_pretrained("bert-base-uncased").cuda()
encoder2 = AutoModel.from_pretrained("bert-base-uncased").cuda()
เริ่มต้นวัตถุ GradCache
from grad_cache import GradCache
from grad_cache.loss import SimpleContrastiveLoss
loss_fn = SimpleContrastiveLoss()
gc = GradCache(
models=[encoder1, encoder2],
chunk_sizes=2,
loss_fn=loss_fn,
get_rep_fn=lambda v: v.pooler_output
)
ที่นี่เราใช้อาร์กิวเมนต์ get_rep_fn เพื่อระบุฟังก์ชั่นที่ใช้เอาต์พุตโมเดล HuggingFace ทั่วไปและส่งคืนเทนเซอร์การเป็นตัวแทนจริง
สร้างอินพุตโมเดล
xx = tokenizer(tt, return_tensors='pt', padding=True)
yy = tokenizer(tt2, return_tensors='pt', padding=True)
เรียกใช้ขั้นตอนแคช
gc(xx, yy, reduction='mean')
ที่นี่เราใช้ reduction='mean' เป็น loss_kwargs เพื่อควบคุมพฤติกรรมการสูญเสีย ด้วย optimizer ที่กำหนดไว้การอัปเดตแบบเต็มรูปแบบสามารถทำได้เป็น
optimizer.zero_grad()
gc(xx, yy, reduction='mean')
optimizer.step()
นี่คือการจัดการตามธรรมชาติโดยกราฟ (เวทมนตร์) แบบไดนามิก คุณผ่านสำเนาตื้นของรุ่นเข้ารหัสเดียวกันไปยังวิธีการเริ่มต้นของ GradCache
tied_encoder = AutoModel.from_pretrained("bert-base-uncased").cuda()
gc = GradCache(
models=[tied_encoder , tied_encoder],
chunk_sizes=2,
loss_fn=loss_fn,
get_rep_fn=lambda v: v.pooler_output
)
ภายใต้ประทุนตะขอที่แตกต่างจะได้รับการลงทะเบียนเพื่อทำการคำนวณการไล่ระดับสีที่ถูกต้อง
เราคาดว่าการสื่อสารข้ามกระบวนการของการเป็นตัวแทนจะได้รับการจัดการโดย loss_fn
from grad_cache.loss import DistributedContrastiveLoss
loss_fn_dist = DistributedContrastiveLoss()
ห่อโมเดล ENCODER อย่างถูกต้องเพื่อลดการไล่ระดับสี
encoder1_ddp = DistributedDataParallel(
encoder1, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
encoder2_ddp = DistributedDataParallel(
encoder2, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
คุณสามารถเริ่มต้นแคชใช้การสูญเสียแบบกระจายและรุ่น DDP
gc = GradCache(
models=[encoder1_ddp, encoder2_ddp],
chunk_sizes=2,
loss_fn=loss_fn_dist,
get_rep_fn=lambda v: v.pooler_output
)
เรียกใช้ขั้นตอนแคช
gc(xx, yy, no_sync_except_last=True, reduction='mean')
ตั้งค่า no_sync_except_last=True เพื่อหลีกเลี่ยงการลดการไล่ระดับสีที่ไม่จำเป็น
หากคุณกำลังพัฒนาโครงการใหม่เราขอแนะนำให้ตรวจสอบนักตกแต่งที่เราได้เตรียมไว้เพื่อสร้างฟังก์ชั่นการสั่งซื้อที่สูงขึ้นสำหรับแคช
grad_cache.functional.cached(func: Callable[..., Tensor])
มัณฑนากรที่ใช้ฟังก์ชั่นการโทรแบบจำลองในรุ่นที่เข้ากันได้แคช
func - ฟังก์ชั่นที่เรียกโมเดลและส่งคืนเทนเซอร์การเป็นตัวแทน
Return - ฟังก์ชั่นที่ส่งคืน 1) Tensors Leaf Tensors สำหรับการสร้างแคช 2) ฟังก์ชั่นการปิดสำหรับการส่งต่อที่ 2 และแคชไปข้างหลัง โทร 2) ด้วย 1) เป็นอาร์กิวเมนต์หลังจากโทรไปข้างหลังในการสูญเสียเทนเซอร์
grad_cache.functional.cat_input_tensor(func: Callable[..., Tensor])
มัณฑนากรที่เชื่อมต่อตำแหน่งและคำหลักอาร์กิวเมนต์ของรายการประเภท [เทนเซอร์] เป็นเทนเซอร์เดียวในมิติที่ 0 สิ่งนี้สามารถเกิดขึ้นได้อย่างมีประโยชน์กับผลลัพธ์ของเทนเซอร์ที่เป็นตัวแทนจากหลายแคชไปข้างหน้า
func - ฟังก์ชั่นการสูญเสีย
ผลตอบแทน - ฟังก์ชั่นการสูญเสียตกแต่งสำหรับผลลัพธ์ที่แคช
grad_cache.functional.gather_input_tensor(func: Callable[..., Tensor], axis=0)
มัณฑนากรที่มีการรวบรวมตำแหน่งและคำหลักอาร์กิวเมนต์ของ Type Tensor และต่อกันในแกน ตั้งใจที่จะใช้ในการสร้างการสูญเสียการเรียนรู้แบบกระจายแบบกระจาย
func - ฟังก์ชั่นการสูญเสีย
ผลตอบแทน - ฟังก์ชั่นการสูญเสียตกแต่งสำหรับการฝึกอบรมแบบกระจาย
นักตกแต่งที่ใช้งานได้นั้นมีประโยชน์เป็นพิเศษหากตัวโหลดข้อมูลของคุณกำลังเปล่งแบทช์ขนาดเล็กซึ่งคุณสามารถสร้างแบทช์ขนาดใหญ่ได้ สมมติว่าคุณต้องการทำความแม่นยำแบบผสมอัตโนมัติก่อนอื่นเราจะกำหนดฟังก์ชั่นการโทรแบบจำลองและฟังก์ชั่นการสูญเสีย
from grad_cache.functional import cached, cat_input_tensor
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
@cached
@autocast()
def call_model(model, input):
return model(**input).pooler_output
@cat_input_tensor
@autocast()
def contrastive_loss(x, y):
target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
scores = torch.matmul(x, y.transpose(0, 1))
return F.cross_entropy(scores, target=target)
สมมติว่าคุณมี loader dataloader ที่เปล่งขนาดเล็กของ tuple (xx, yy) ของขนาด (m * n) และคุณต้องการฝึกอบรมโดยการรวม 16 ชุดเล็ก ๆ 16 ชุดเพื่อให้ได้ชุดของ (16m * 16n)
cache_x = []
cache_y = []
closures_x = []
closures_y = []
for step, sub_batch in enumerate(loader):
xx, yy = sub_batch
rx, cx = call_model(bert, xx)
ry, cy = call_model(bert, yy)
cache_x.append(rx)
cache_y.append(ry)
closuresx.append(cx)
closuresy.append(cy)
if (step + 1) % 16 == 0:
loss = contrastive_loss(cache_x, cache_y)
scaler.scale(loss).backward()
for f, r in zip(closuresx, cache_x):
f(r)
for f, r in zip(closuresy, cache_y):
f(r)
cache_x = []
cache_y = []
closures_x = []
closures_y = []
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
การฝึกอบรมแบบหลายกระบวนการแบบกระจายต้อง: 1) (All-) รวบรวมการเป็นตัวแทนในอุปกรณ์และ 2) การไล่ระดับสี (ลดทั้งหมด) ข้ามอุปกรณ์ ทั้งสองขั้นตอนจะเกิดขึ้น นอก การตกแต่งที่ได้รับการตกแต่ง
หลังเป็นเรื่องง่ายที่จะบรรลุโดยการห่อหุ้มตัวเข้ารหัสเช่น bert ใน DistributedDataParallel
bert = DistributedDataParallel(
bert, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
อดีตต้องการ OPS แบบกระจายพิเศษในฟังก์ชั่นการสูญเสียซึ่งควรทำตามคำจำกัดความการสูญเสียเดิม ตัวอย่างเช่น,
from torch import distributed as dist
from grad_cache.functional import cat_input_tensor, gather_input_tensor
@cat_input_tensor
@gather_input_tensor
@autocast()
def contrastive_loss(x, y):
target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
scores = torch.matmul(x, y.transpose(0, 1))
# scale the loss as DistributedDataParallel will do mean reduce
return F.cross_entropy(scores, target=target) * dist.get_world_size()
grad_cache/grad_cache.py - กำหนดคลาส gradcache รหัสอยู่ภายใต้ 300 บรรทัดรวมถึงความคิดเห็น เพื่อการพัฒนาเราขอแนะนำให้คุณอ่านผ่าน
grad_cache/functional.py - กำหนดนักตกแต่งเพื่อสร้างฟังก์ชั่นการสั่งซื้อที่สูงขึ้นสำหรับการแคชไล่ระดับสีจากฟังก์ชั่นการโทรแบบจำลองทั่วไปและฟังก์ชั่นการสูญเสีย