El caché de gradiente es una técnica simple para un lote de aprendizaje contrastante de escala ilimitada mucho más allá de la restricción de memoria GPU/TPU. Esto significa que el entrenamiento que solía tomar hardware pesado, por ejemplo, 8 GPU V100, se puede hacer en una sola GPU. Además, el caché de gradiente permite a los usuarios reemplazar la GPU/TPU de Big RAM con sistemas de ram bajo de alto flop mucho más rentable.
Este repositorio posee una implementación genérica de la caché de gradiente descrita en nuestro tamaño de lote de aprendizaje de aprendizaje de ampliación de papel en papel bajo la configuración limitada de memoria. Se admiten los marcos Pytorch y Jax.
@inproceedings{gao2021scaling,
title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
author={Luyu Gao, Yunyi Zhang, Jiawei Han, Jamie Callan},
booktitle ={Proceedings of the 6th Workshop on Representation Learning for NLP},
year={2021},
}
NUEVO: ¡Ahora apoyamos a Jax y TPU!
El caché de gradiente también se ha integrado en una densa recuperación de pasos (DPR). Consulte nuestro kit de herramientas GC-DPR.
Primero instale el backend de aprendizaje profundo deseado, ya sea Pytorch o Jax. Para instalar GradCache, clone este repositorio y ejecute PIP.
git clone https://github.com/luyug/GradCache
cd GradCache
pip install .
Para el desarrollo,
pip install --editable .
Las funcionalidades de almacenamiento en caché de gradiente se implementan en la clase GradCache . Si está desarrollando un nuevo proyecto en lugar de parchear uno antiguo, también consulte nuestro enfoque funcional para un enfoque reducido de esfuerzo.
Para el usuario de Jax/Flax, eche un vistazo a una función de tren simple aquí.
El método __init__ de la clase define el caché y tiene varios parámetros funcionales *_fn para un fácil ajuste de los comportamientos del modelo. Alternativamente, también puede subclase GradCache.
grad_cache.GradCache(
models: List[nn.Module],
chunk_sizes: Union[int, List[int]],
loss_fn: Callable[..., Tensor],
split_input_fn: Callable[[Any, int], Any] = None,
get_rep_fn: Callable[..., Tensor] = None,
fp16: bool = False,
scaler: GradScaler = None,
)
Modelos : una lista de modelos de codificadores para actualizarse con el caché de gradiente.
Chunk_sizes : un entero que indica el tamaño del fragmento. O una lista de enteros de tamaño de fragmento para cada modelo. Esto controla para cada modelo el tamaño de la sub-lotes para ejecutar el pase hacia adelante hacia adelante y debe establecerse en función de la memoria de GPU disponible. Un valor demasiado pequeño dejará la GPU bajo utilizada.
Loss_fn : una función de pérdida que toma tensores de representación del número igual al número de modelos en models y números arbitrarios de argumentos de palabras clave. Debe calcular la pérdida según los tensores de entrada, y en ningún caso modifique las relaciones de los tensores de entrada en el gráfico Autograd, que luego se basan para crear el caché de gradiente.
split_input_fn : una función opcional que divide la entrada del modelo genérico en fragmentos basados en Chunk_Sizes definidos. Si no se proporciona, esta clase hará todo lo posible para dividir las entradas de los tipos compatibles. Ver la función split_inputs .
get_rep_fn : una función opcional que toma tensores de representación de salida de modelo genérico y retorno. Si no se proporciona, se supone que la salida genérica es el tensor de representación.
FP16 : si es cierto, ejecute entrenamiento de precisión mixta, que requiere que Scaler también se establezca.
escalador : un objeto graduador para entrenamiento automático de precisión mixta.
Para ejecutar un paso de computación de gradiente en caché, llame a la función cache_step ,
cache_step(
*model_inputs,
no_sync_except_last: bool = False,
**loss_kwargs
)
Ejecute un paso de caché de gradiente único. Tras el retorno de la función model_inputs las actualizaciones se calculan para cada modelo en self.models . Llamar a un objeto GradCache con __call__ también invocará esta función.
Model_Inputs : lista de entradas a cada modelo de codificador. Debe estar en orden similar al self.models .
NO_SYNC_EXCET_LAST -Si es verdadero, en la configuración distribuida, para cada modelo, solo desencadena la reducción del gradiente en los procesos para el pase hacia adelante hacia adelante por la última subpasa. Esto podría ser útil cuando se trata con a) modelo grande, y/o b) número no trivial de subteres.
Loss_kwargs : argumentos de palabras clave adicionales a la función de pérdida loss_fn . Esto está destinado a permitir el cálculo de pérdida flexible (gracias al gráfico dinámico en Pytorch), como la reducción, la ponderación, etc. potencialmente, utilizando loss_kwargs puede incorporar salidas de esos modelos de codificadores que no rastrean el caché.
Retorno - Pérdida, el tensor de escalador de pérdida de pasos actual (separado del gráfico).
model(x)model(*x)model(**x)model(*x[0], **x[1])Otras entradas genéricas no son totalmente compatibles, realizamos llamadas de modelo utilizando las siguientes heurísticas,
model(*x)model(**x)model(*x[0], **x[1]) Para ejecutarse con ellos, split_input_fn debe especificarse durante la inicialización de caché para dividir estas entradas en lotes más pequeños. En algunos casos raros, es posible que también deba anular get_input_tensors cuando su heurística no puede agarrar suficientes tensores que cubren todos los dispositivos CUDA que contienen algunos tensores en la entrada.
Digamos que queremos aprender un espacio de incrustación de etiquetas y texto. Considere los siguientes cuatro pares. (En la práctica, tendrás muchas más y mucho más largas entradas de texto).
labels = ['fruit', 'meat', 'school', 'company']
texts = [
'this is an apple',
'steak should be cooked medium rare',
'cmu is pittsburgh',
'apple sells laptop'
]
Inicializar nuestros modelos de codificadores,
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder1 = AutoModel.from_pretrained("bert-base-uncased").cuda()
encoder2 = AutoModel.from_pretrained("bert-base-uncased").cuda()
Inicializar el objeto GradCache,
from grad_cache import GradCache
from grad_cache.loss import SimpleContrastiveLoss
loss_fn = SimpleContrastiveLoss()
gc = GradCache(
models=[encoder1, encoder2],
chunk_sizes=2,
loss_fn=loss_fn,
get_rep_fn=lambda v: v.pooler_output
)
Aquí usamos el argumento get_rep_fn para especificar una función que toma la salida del modelo genérico de la cara de abrazo y devuelve el tensor de representación real.
Crear entrada de modelo,
xx = tokenizer(tt, return_tensors='pt', padding=True)
yy = tokenizer(tt2, return_tensors='pt', padding=True)
Ejecute un paso de caché,
gc(xx, yy, reduction='mean')
Aquí usamos reduction='mean' como una pérdida_kwargs para controlar el comportamiento de pérdida. Con un optimizer definido, la actualización de gradiente completa se puede hacer como,
optimizer.zero_grad()
gc(xx, yy, reduction='mean')
optimizer.step()
Esto se maneja naturalmente por el gráfico (magia de) dinámico. Pase copias superficiales del mismo modelo de codificador al método de inicio de GradCache.
tied_encoder = AutoModel.from_pretrained("bert-base-uncased").cuda()
gc = GradCache(
models=[tied_encoder , tied_encoder],
chunk_sizes=2,
loss_fn=loss_fn,
get_rep_fn=lambda v: v.pooler_output
)
Debajo del capó, se registrarán ganchos distintos para hacer un cálculo de gradiente correcto.
Esperamos que el proceso cruzado de la comunicación de representaciones sea manejada por el loss_fn .
from grad_cache.loss import DistributedContrastiveLoss
loss_fn_dist = DistributedContrastiveLoss()
Envuelva correctamente los modelos de codificadores para la reducción de gradiente,
encoder1_ddp = DistributedDataParallel(
encoder1, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
encoder2_ddp = DistributedDataParallel(
encoder2, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
Puede inicializar el caché use la pérdida distribuida y los modelos DDP,
gc = GradCache(
models=[encoder1_ddp, encoder2_ddp],
chunk_sizes=2,
loss_fn=loss_fn_dist,
get_rep_fn=lambda v: v.pooler_output
)
Ejecute un paso de caché,
gc(xx, yy, no_sync_except_last=True, reduction='mean')
Establecer no_sync_except_last=True para evitar la reducción innecesaria del gradiente.
Si está desarrollando un nuevo proyecto, le recomendamos que también revise los decoradores que hemos proporcionado para crear funciones de orden superior para el caché.
grad_cache.functional.cached(func: Callable[..., Tensor])
Un decorador que toma una función de llamada modelo en una versión compatible en caché.
FUNC : una función que llama al modelo y retrocede el tensor de representación.
Retorno : una función que devuelve 1) Tensores de hoja de representación para la construcción de caché, 2) una función de cierre para el segundo delantero y el almacenado en caché hacia atrás. Llame 2) con 1) como argumento después de llamar al tensor de pérdida.
grad_cache.functional.cat_input_tensor(func: Callable[..., Tensor])
Un decorador que concatena los argumentos posicionales y de palabras clave de la lista de tipos [tensor] en un solo tensor en la dimensión 0. Esto puede ser útil para tratar los resultados de tensores de representación de múltiples caché hacia adelante.
FUNC : una función de pérdida
Retorno : función de pérdida decorada para resultados en caché.
grad_cache.functional.gather_input_tensor(func: Callable[..., Tensor], axis=0)
Un decorador que todos los argumentos posicionales y de palabras clave del tensor de tipo y los concatenan en el eje. Destinado a ser utilizado para crear pérdida de aprendizaje contrastante distribuido.
FUNC : una función de pérdida
Retorno : función de pérdida decorada para capacitación distribuida.
Los decoradores funcionales son particularmente útiles si su cargador de datos está emitiendo pequeños lotes, desde los cuales puede construir el lote grande. Digamos que también desea hacer una precisión mixta automática, primero definimos la función de llamada del modelo y la función de pérdida,
from grad_cache.functional import cached, cat_input_tensor
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
@cached
@autocast()
def call_model(model, input):
return model(**input).pooler_output
@cat_input_tensor
@autocast()
def contrastive_loss(x, y):
target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
scores = torch.matmul(x, y.transpose(0, 1))
return F.cross_entropy(scores, target=target)
Digamos que tiene un loader de datos que emite pequeños lotes de tuple (xx, yy) de tamaño (m * n) y que desea entrenar agregando 16 pequeños lotes para obtener un lote de (16m * 16n),
cache_x = []
cache_y = []
closures_x = []
closures_y = []
for step, sub_batch in enumerate(loader):
xx, yy = sub_batch
rx, cx = call_model(bert, xx)
ry, cy = call_model(bert, yy)
cache_x.append(rx)
cache_y.append(ry)
closuresx.append(cx)
closuresy.append(cy)
if (step + 1) % 16 == 0:
loss = contrastive_loss(cache_x, cache_y)
scaler.scale(loss).backward()
for f, r in zip(closuresx, cache_x):
f(r)
for f, r in zip(closuresy, cache_y):
f(r)
cache_x = []
cache_y = []
closures_x = []
closures_y = []
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
Ejecutar la capacitación distribuida de procesos múltiples requiere: 1) (All) recopilar representaciones entre dispositivos y 2) (All-Reduce) Gradientes en los dispositivos. Ambos pasos ocurrirán fuera de las funciones decoradas en caché.
Este último es fácil de lograr envolviendo codificadores, por ejemplo, un bert , en DistributedDataParallel .
bert = DistributedDataParallel(
bert, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
El primero requiere OPS distribuidos adicionales en la función de pérdida, que debe hacerse de acuerdo con la definición de pérdida original. Por ejemplo,
from torch import distributed as dist
from grad_cache.functional import cat_input_tensor, gather_input_tensor
@cat_input_tensor
@gather_input_tensor
@autocast()
def contrastive_loss(x, y):
target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
scores = torch.matmul(x, y.transpose(0, 1))
# scale the loss as DistributedDataParallel will do mean reduce
return F.cross_entropy(scores, target=target) * dist.get_world_size()
Grad_cache/grad_cache.py: defina la clase GradCache. El código está debajo de 300 líneas, incluidos los comentarios. Para el desarrollo, lo alentamos a que lo lea.
Grad_cache/funcional.py: defina los decoradores para crear una función de orden superior para el almacenamiento en caché de gradiente a partir de funciones de llamadas de modelos ordinarios y funciones de pérdida.