O cache do gradiente é uma técnica simples para escalar ilimitada em lote de aprendizado contrastante muito além da restrição de memória GPU/TPU. Isso significa que o treinamento que costumava tomar hardware pesado, por exemplo, 8 V100 GPU, pode ser feito em uma única GPU. Além disso, o cache de gradiente permite que os usuários substituam o Big RAM GPU/TPU por sistemas de baixo flop com baixo custo.
Este repositório mantém uma implementação genérica do cache de gradiente descrito em nossa escala de papel, tamanho de lotes de aprendizado de aprendizado profundo sob a configuração limitada da memória. As estruturas Pytorch e Jax são suportadas.
@inproceedings{gao2021scaling,
title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
author={Luyu Gao, Yunyi Zhang, Jiawei Han, Jamie Callan},
booktitle ={Proceedings of the 6th Workshop on Representation Learning for NLP},
year={2021},
}
NOVO: Agora apoiamos Jax e TPU!
O cache do gradiente também foi integrado à recuperação de passagem densa (DPR). Confira nosso kit de ferramentas GC-DPR.
Primeiro, instale o back -end de aprendizado profundo desejado, Pytorch ou Jax. Para instalar o GradCache, clone este repositório e execute o PIP.
git clone https://github.com/luyug/GradCache
cd GradCache
pip install .
Para desenvolvimento,
pip install --editable .
As funcionalidades do cache de gradiente são implementadas na classe GradCache . Se você estiver desenvolvendo um novo projeto em vez de corrigir um antigo, também consulte nossa abordagem funcional para obter um esforço reduzido.
Para o usuário JAX/FLAX, dê uma olhada em uma função de trem simples aqui.
O método __init__ da classe define o cache e possui vários parâmetros funcionais *_fn para facilitar o ajuste dos comportamentos do modelo. Como alternativa, você também pode subclassificar o gradcache.
grad_cache.GradCache(
models: List[nn.Module],
chunk_sizes: Union[int, List[int]],
loss_fn: Callable[..., Tensor],
split_input_fn: Callable[[Any, int], Any] = None,
get_rep_fn: Callable[..., Tensor] = None,
fp16: bool = False,
scaler: GradScaler = None,
)
Modelos - Uma lista de modelos de codificadores a serem atualizados com o cache do gradiente.
chunk_sizes - um número inteiro indicando tamanho de pedaço. Ou uma lista de números inteiros de tamanho de bloco para cada modelo. Isso controla para cada modelo o tamanho do sub-lote para executar o passe para a frente e deve ser definido com base na memória GPU disponível. Um valor muito pequeno deixará a GPU abaixo utilizada.
LEST_FN - Uma função de perda que leva os tensores de representação de número igual ao número de modelos nos models e números arbitrários de argumentos de palavras -chave. Ele deve calcular a perda com base nos tensores de entrada e, em nenhum caso, modificar as relações dos tensores de entrada no gráfico do AutoGRAD, que posteriormente são confiadas para criar o cache do gradiente.
split_input_fn - uma função opcional que dividiu a entrada do modelo genérico em pedaços com base em chunk_sizes definidos. Se não for fornecido, esta classe tentará o melhor para dividir as entradas dos tipos suportados. Consulte a função split_inputs .
get_rep_fn - uma função opcional que leva os tensores de saída do modelo genérico e de retorno de representação. Se não for fornecido, a saída genérica é assumida como o tensor de representação.
FP16 - Se verdadeiro, execute o treinamento de precisão mista, que exige que o Scaler também seja definido.
Scaler - um objeto de gradscaler para treinamento automático de precisão mista.
Para executar uma etapa de computação gradiente em cache, ligue para a função cache_step ,
cache_step(
*model_inputs,
no_sync_except_last: bool = False,
**loss_kwargs
)
Execute uma etapa de cache de gradiente único. Após o retorno da função, as atualizações são calculadas para cada modelo em self.models com gradiente preenchido nos pesos, como se os model_inputs fossem executados como um enorme lote único em hardware suficientemente grande. Chamar um objeto GradCache com __call__ também invocará esta função.
Model_inputs - Lista de entradas para cada modelo de codificador. Deve estar em ordem semelhante à self.models .
NO_SYNC_EXCECT_LAST -Se true, em configuração distribuída, para cada modelo, apenas aciona a redução do gradiente entre os processos para o último passe para a frente do último sub-lote. Isso pode ser útil ao lidar com um modelo grande e/ou b) número não trivial de sub-lotes.
LEST_KWARGS - Argumentos adicionais de palavras -chave para a função de perda loss_fn . Isso se destina a habilitar a computação de perda flexível (graças ao gráfico dinâmico em pytorch), como redução, ponderação, etc. Potencialmente, usando loss_kwargs , você pode incorporar saídas daqueles modelos de codificadores não rastreados pelo cache.
Retorno - Perda, as etapas atuais perdem o tensor Scaler (destacado do gráfico).
model(x)model(*x)model(**x)model(*x[0], **x[1])Outras entradas genéricas não são totalmente suportadas, realizamos chamadas de modelo usando as seguintes heurísticas,
model(*x)model(**x)model(*x[0], **x[1]) Para executar com eles, split_input_fn deve ser especificado durante a inicialização do cache para dividir essas entradas em lotes menores. Em alguns casos raros, você também pode precisar substituir get_input_tensors quando sua heurística não pode pegar tensores suficientes que cobrem todos os dispositivos CUDA que mantêm alguns tensores na entrada.
Digamos que queremos aprender um espaço de incorporação de rótulos e texto. Considere os quatro pares a seguir. (Na prática, você terá muito mais e muito mais longas entradas de texto.)
labels = ['fruit', 'meat', 'school', 'company']
texts = [
'this is an apple',
'steak should be cooked medium rare',
'cmu is pittsburgh',
'apple sells laptop'
]
Inicialize nossos modelos de codificadores,
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder1 = AutoModel.from_pretrained("bert-base-uncased").cuda()
encoder2 = AutoModel.from_pretrained("bert-base-uncased").cuda()
Inicialize o objeto GradCache,
from grad_cache import GradCache
from grad_cache.loss import SimpleContrastiveLoss
loss_fn = SimpleContrastiveLoss()
gc = GradCache(
models=[encoder1, encoder2],
chunk_sizes=2,
loss_fn=loss_fn,
get_rep_fn=lambda v: v.pooler_output
)
Aqui, usamos o argumento get_rep_fn para especificar uma função que leva a saída genérica do modelo Huggingface e retorne o tensor de representação real.
Criar entrada de modelo,
xx = tokenizer(tt, return_tensors='pt', padding=True)
yy = tokenizer(tt2, return_tensors='pt', padding=True)
Execute uma etapa de cache,
gc(xx, yy, reduction='mean')
Aqui, usamos reduction='mean' como um perda_kwargs para controlar o comportamento da perda. Com um optimizer definido, a atualização completa do gradiente pode ser feita como,
optimizer.zero_grad()
gc(xx, yy, reduction='mean')
optimizer.step()
Isso é naturalmente tratado pelo (magia do) gráfico dinâmico. Você passa cópias rasas do mesmo modelo de codificador para o método GradCache Init.
tied_encoder = AutoModel.from_pretrained("bert-base-uncased").cuda()
gc = GradCache(
models=[tied_encoder , tied_encoder],
chunk_sizes=2,
loss_fn=loss_fn,
get_rep_fn=lambda v: v.pooler_output
)
Sob o capô, ganchos distintos serão registrados para fazer computação de gradiente correta.
Esperamos que a comunicação entre processos de representações seja tratada pelo loss_fn .
from grad_cache.loss import DistributedContrastiveLoss
loss_fn_dist = DistributedContrastiveLoss()
Enrole adequadamente os modelos do codificador para redução de gradiente,
encoder1_ddp = DistributedDataParallel(
encoder1, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
encoder2_ddp = DistributedDataParallel(
encoder2, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
Você pode inicializar o cache usar a perda distribuída e os modelos DDP,
gc = GradCache(
models=[encoder1_ddp, encoder2_ddp],
chunk_sizes=2,
loss_fn=loss_fn_dist,
get_rep_fn=lambda v: v.pooler_output
)
Execute uma etapa de cache,
gc(xx, yy, no_sync_except_last=True, reduction='mean')
Defina no_sync_except_last=True para evitar redução de gradiente desnecessário.
Se você estiver desenvolvendo um novo projeto, recomendamos também verificar os decoradores que fornecemos para criar funções de ordem superior para o cache.
grad_cache.functional.cached(func: Callable[..., Tensor])
Um decorador que leva uma função de chamada de modelo em uma versão compatível em cache.
FUNC - Uma função que chama o modelo e retorna o tensor de representação.
Retorno - uma função que retorna 1) Tensores de folhas de representação para construção de cache, 2) uma função de fechamento para o 2º dianteiro e o cache para trás. Ligue para 2) com 1) como argumento depois de ligar para o tensor de perda.
grad_cache.functional.cat_input_tensor(func: Callable[..., Tensor])
Um decorador que concatena argumentos posicionais e de palavras -chave da lista de tipos [tensor] em um único tensor na 0ª dimensão. Isso pode ser útil lidar com os resultados dos tensores de representação de vários cache para a frente.
functão - uma função de perda
Retorno - Função de perda decorada para resultados em cache.
grad_cache.functional.gather_input_tensor(func: Callable[..., Tensor], axis=0)
Um decorador que todo o versátil e a palavra-chave argumentos do tipo tensor e os concatenam no eixo. Destinado a ser usado para criar perda de aprendizado contrastivo distribuído.
functão - uma função de perda
Retorno - Função de perda decorada para treinamento distribuído.
Os decoradores funcionais são particulares se o seu carregador de dados estiver emitindo pequenos lotes, dos quais você pode construir o grande lote. Digamos que você também deseja fazer precisão mista automática, primeiro definimos a função de chamada e a função de perda de modelos,
from grad_cache.functional import cached, cat_input_tensor
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
@cached
@autocast()
def call_model(model, input):
return model(**input).pooler_output
@cat_input_tensor
@autocast()
def contrastive_loss(x, y):
target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
scores = torch.matmul(x, y.transpose(0, 1))
return F.cross_entropy(scores, target=target)
Digamos que você tenha um loader de dataloader emitindo pequenos lotes de tupla (xx, yy) de tamanho (m * n) e que deseja treinar agregando 16 pequenos lotes para obter um lote de (16m * 16n),
cache_x = []
cache_y = []
closures_x = []
closures_y = []
for step, sub_batch in enumerate(loader):
xx, yy = sub_batch
rx, cx = call_model(bert, xx)
ry, cy = call_model(bert, yy)
cache_x.append(rx)
cache_y.append(ry)
closuresx.append(cx)
closuresy.append(cy)
if (step + 1) % 16 == 0:
loss = contrastive_loss(cache_x, cache_y)
scaler.scale(loss).backward()
for f, r in zip(closuresx, cache_x):
f(r)
for f, r in zip(closuresy, cache_y):
f(r)
cache_x = []
cache_y = []
closures_x = []
closures_y = []
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
O treinamento em vários processos distribuídos exige: 1) (All-) Reúna representações entre os dispositivos e 2) gradientes (tudo de redução) entre os dispositivos. Ambas as etapas acontecerão fora das funções decoradas em cache.
Este último é fácil de alcançar, envolvendo os codificadores, por exemplo, um bert , em DistributedDataParallel .
bert = DistributedDataParallel(
bert, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
O primeiro requer operações extras distribuídas na função de perda, que deve ser feita de acordo com a definição de perda original. Por exemplo,
from torch import distributed as dist
from grad_cache.functional import cat_input_tensor, gather_input_tensor
@cat_input_tensor
@gather_input_tensor
@autocast()
def contrastive_loss(x, y):
target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
scores = torch.matmul(x, y.transpose(0, 1))
# scale the loss as DistributedDataParallel will do mean reduce
return F.cross_entropy(scores, target=target) * dist.get_world_size()
grad_cache/grad_cache.py - defina a classe GradCache. O código está abaixo de 300 linhas, incluindo comentários. Para o desenvolvimento, incentivamos você a ler isso.
grad_cache/funcional.py - Defina os decoradores para criar função de ordem superior para o cache de gradiente a partir de funções e funções de chamadas comuns do modelo.