Este é o código para reproduzir os experimentos do artigo EMNLP 2021 "A potência da escala para ajuste rápido com eficiência de parâmetro" (Lester et al., 2021).
Esses modelos são construídos no T5X, que define o modelo e o loop de treinamento; Flaxformer, que define o cálculo do modelo real; Linho, que define as camadas de modelo de baixo nível; e Jax, que fornece a execução real. Detalhes de nossa implementação podem ser encontrados aqui.
gs://{bucket-name}/path/to/item/in/bucket . É aqui que armazenaremos conjuntos de dados em cache, bem como pontos de verificação e resultados do modelo. Para facilitar a referência, alguns dos comandos mais comuns em nuvem para interagir com as VMs de TPU são # Create a Cloud TPU VM
$ gcloud alpha compute tpus tpu-vm create ${TPU_NAME}
--zone ${ZONE}
--accelerator-type v3-8
--version v2-alpha
# SSH into a Cloud TPU VM
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE}
# Delete a Cloud TPU VM
$ gcloud alpha compute tpus tpu-vm delete ${TPU_NAME} --zone ${ZONE}git clone --branch=main https://github.com/google-research/prompt-tuning
cd prompt-tuningpython3 -m pip install .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html Se você enfrentar um erro em que o PIP tentar instalar as versões anteriores e preciosas das dependências (TensorFlow, por exemplo) até que ele tente instalar a versão 0.0.0 e depois falhar, tente adicionar --use-deprecated=legacy-resolver ao comando de instalação. Este erro está relacionado às versões necessárias entre as dependências e o comportamento é frequentemente chamado de trilha de volta. Se você usar o sinalizador, é possível que as versões incompatíveis das bibliotecas possam ser instaladas e você deve procurar avisos sobre incompatibilidades na saída do comando de instalação.
Nota: Se você planeja invadir os internos do ajuste rápido e precisar de uma instalação editável (para que as alterações no código clonado sejam usadas ao executar o treinamento), execute pip com o sinalizador -e e pode ser necessário excluir o arquivo pyproject.toml se estiver recebendo erros durante a instalação.
Para executar os testes, instale o pacote com a opção [test] (Python3 python3 -m pytest python3 -m pip install .[test] ...
Treinar um prompt é semelhante a ajustar um modelo com T5X; A principal diferença é que temos nosso próprio conjunto de arquivos de configuração de ajuste imediato a serem usados.
Fornecemos um script de demonstração ( prompt_tuning/scripts/sst2-demo.sh ) que possui todas as peças necessárias para o treinamento de um prompt. Você pode usá -lo como ponto de partida ou definir variáveis de ambiente MODEL_DIR e TFDS_DATA_DIR com caminhos para o seu balde de armazenamento em nuvem do Google para executar esse script diretamente.
./prompt-tuning/prompt_tuning/scripts/sst2-demo.shPara ajudar na velocidade da iteração, tendemos a especificar muito mais opções, a linha de comando, em vez de agrupar toda a configuração em um único arquivo Gin. Algumas opções de nota:
--gin_search_paths :: Uma lista separada por vírgula de diretórios para usar como prefixos de caminho para arquivos de gin. Podemos usar prompt_tuning.scripts.find_module ${module} para encontrar o local de instalação de bibliotecas que agrupam as configurações com elas.--gin_file :: o arquivo gin para carregar. Tendemos a usar os caminhos relativos começando com a biblioteca com os quais estão instalados, ou seja, prompt_tuning/configs/models/t5_1_1_base_prompt.gin sobre models/t5_1_1_base_prompt.gin para evitar qualquer confusão. O uso do sinalizador vários tempos pode ser usado para especificar vários arquivos de gin que serão mesclados. Quaisquer opções de configurações definidas em vários arquivos usarão o valor do último arquivo na lista.--gin.{PARAM}={VALUE} :: Este sinalizador de substituição geral definirá PARAM como VALUE . Isso pode ser usado para definir facilmente opções de configuração sem exigir que sejam argumentos reais da linha de comando. Por exemplo. --gin.utils.SaveCheckpointConfig.keep=20 salvará os últimos 20 pontos de verificação.À medida que os modelos aumentam, XL e XXL, por exemplo, eles não se encaixam nos 8 TPUs que vêm com uma única VM TPU. Nesses casos, precisaremos de uma fatia de um POD TPU (mais informações sobre a arquitetura da TPU e as configurações disponíveis podem ser encontradas aqui). A principal diferença entre o treinamento de um prompt em uma única TPU VM e em uma fatia de pod é que agora temos várias VMs de TPU e executaremos o mesmo SPMD Jax cada VM, esta página possui mais informações sobre os programas Jax multi-host. Este guia oferece uma rápida introdução à execução de programas JAX em uma fatia de POD TPU, mas atingiremos os principais pontos aqui.
$ gcloud alpha compute tpus tpu-vm create ${TPU_NAME}
--zone ${ZONE}
--accelerator-type v3-32
--version v2-alpha--command= e que ele deve ser executado em todas as nossas VMs (chamadas trabalhadores) com --worker=all . $ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--zone ${ZONE}
--worker=all
--command= " git clone --branch=main https://github.com/google-research/prompt-tuning && cd prompt-tuning && "
python3 -m pip install . -f https://storage.googleapis.com/jax-releases/libtpu_releases.html Escreva o script para treinar seu prompt. Incluímos um script de demonstração ( /prompt_tuning/scripts/sst2-xxl-demo.sh ) o treina um prompt para resolver o conjunto de dados SST2 usando o T5 1.1 LM100K XXL. Você pode usá -lo como ponto de partida ou apenas preencher os caminhos no seu balde de armazenamento em nuvem do Google para especificar onde deseja salvar seus resultados ( MODEL_DIR ) e onde cache os dados do TFDS ( TFDS_DATA_DIR ) ou defini -los como variáveis de ambiente.
Copie seu script de treinamento cada trabalhador. Se esta é a sua primeira vez em execução scp Você pode receber erro, execute o comando ssh-add /.../.ssh/google_compute_engine da mensagem de erro e tente novamente.
$ gcloud alpha compute tpus tpu-vm scp sst2-xxl-demo.sh ${TPU_NAME} :
--zone= ${ZONE}
--worker=all$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--zone ${ZONE}
--worker=all
--command= " ./sst2-xxl-demo.sh " Se um dos trabalhadores tiver um erro durante o treinamento, você ficará com processos que estão usando as TPUs nos outros trabalhadores. Isso o impedirá de reiniciar seu trabalho até que esses processos tivessem terminado e liberar a TPU. O comando a seguir deve encerrar todos esses processos. Você pode ver a página do kill Command Man voltando do trabalhador que teve o erro inicial.
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--zone ${ZONE}
--worker=all
--command= " sudo lsof -t /dev/accel0 | xargs kill -9 "Para treinar instruções usando peças personalizadas, como seu próprio conjunto de dados, siga as instruções T5X sobre componentes personalizados
Se você empacotar seu código como um pacote Python instalável por Pip, não estará vinculado a um único diretório e poderá usar python3 -m prompt_tuning.scripts.find_module {your_module} para ajudar a definir o gin_search_paths para que as configas de gin -bundas em sua bordada sejam encontradas. NOTA: Se você planeja agrupar as configurações de gin em um pacote instalável, verifique se os diretórios que contêm os arquivos de configuração têm um __init__.py como o gin exige que os arquivos estejam em um pacote python.
Se partes de seus componentes personalizados forem configuráveis em gin, eles precisam ser importados explicitamente em seus arquivos de gin; Se eles acabarem sendo importados após a análise dos arquivos de gin, eles causarão um erro. Se nenhuma de suas dependências contiver configurações de gin, você pode evitar a gravação de um arquivo de gin, passando --gin.MIXTURE_OR_TASK_MODULE="'path.to.your.module' . Isso importará automaticamente seu módulo e for conveniente para quando tudo o que você estiver fazendo for trocar os dados.
Nossa maneira sugerida de fazer inferência com um prompt é carregar o ponto de verificação original usado para inicializar o modelo e o prompt de um arquivo. Conforme explicado nesta seção, sobre o carregamento parcial T5X suporta carregar alguns parâmetros do modelo enquanto inicializa outras pessoas do zero. Usamos isso em conjunto com o inicializador do prompt from_array para recarregar os parâmetros congelados do ponto de verificação original e do arquivo prompt de um arquivo. O configs/runs/prompt_eval.gin configura esta configuração para você; Você só precisa fornecer um PROMPT_FILE . Se o seu modelo foi treinado com qualquer um dos arquivos de prompts/ configurações, poderá removê -los dos argumentos para o script de avaliação.
O script incluído sst2-demo-eval.sh mostra um exemplo de avaliação dessa maneira. Tudo o que é necessário é definir variáveis de ambiente EVAL_DIR e TFDS_DATA_DIR para os caminhos para armazenar a saída da avaliação e os conjuntos de dados do tensorflow respectivamente.
No T5X, o script de avaliação pressupõe que seu conjunto de dados possui rótulos e resultados dos resultados finais das funções métricas do seu conjunto de dados. O script de inferência não requer rótulos e, em vez disso, produz a previsão do seu modelo. Incluímos um arquivo análogo prompt_infer.gin para usar com o script de inferência.
Se você deseja fazer inferência ou avaliação com o ponto de verificação T5X que é produzido a partir de uma execução de treinamento de ajuste imediato, você pode usar a configuração (eval|infer).gin diretamente do T5X. Você precisará atualizar o utils.RestoreChekcpointConfig . Você deve definir path para o novo ponto de verificação, assignment_map=() e fallback_to_scratch=False .
Todo modelo, treinamento, avaliação, economia, restauração, etc. A configuração é feita via gim. Veja o repositório Gin-Config para uma introdução geral ao gin e a este primer
Seguimos o layout de configuração T5X:
runs/ :: contém configurações para o treinamento real do modelo. É aqui que coisas como conjunto de dados e configuração de avaliação vão.architectures/ :: contém configurações de como o modelo funciona. É aqui que são configuradas coisas como codificador-decodificador versus decodificador e compartilhamento de incorporação.models/ :: contém configurações que definem parâmetros específicos do modelo, como o número de camadas ou o tamanho da tabela de incorporação. Ele também configura coisas como o modelo de modelo T5X usado.models/decoding/ :: contém configurações fáceis de usar para trocar como o modelo gera texto durante a inferência, inclui configurações para pesquisa de feixe e amostragem de núcleos.models/sizes/ :: contém as várias configurações para criar modelos de tamanhos diferentes, eles são combinados com as versões padrão para criar uma versão de tamanho, por exemplo, t5_1_1_prompt.gin + sizes/large.gin cria um modelo grande T5 1.1. Algumas combinações comuns já disponíveis como arquivos de gin com a direita incluem ( t5_1_1_large_prompt.gin para o nosso exemplo acima). NOTA: Esses arquivos de tamanho precisam vir após o arquivo de modelo principal.prompts/ :: Nosso diretório extra contém configurações que definem a variável de gin PROMPT , permitindo fácil troca da inicialização do prompt baseado em qual arquivo prompt é adicionado como um argumento --gin_file (ele precisa vir após o arquivo models/ gin), Ao especificar --gin_file argumentos na linha de comando, o pedido é importante. A ordem geral em que os arquivos de gin devem ser especificados é:
models/*.ginprompts/*.ginmodels/sizes/*.gin*models/decoding/*.ginruns/*.gin O T5X possui alguns campos necessários como MIXTURE_OR_TASK_NAME ou TASK_FEATURE_LENGTHS . Adicionamos mais dois:
PROMPT_LENGTH :: O comprimento do prompt que estamos usando, isso é usado em alguns lugares diferentes para exigirmos como uma macro de gin que podemos referir em vários locais e garantir que os valores estejam sincronizados.PROMPT :: Esta é a configuração do módulo de prompt real que será usado nas subclasses do Floxformer PromptX . Nota: Atualmente, o ajuste rápido não suporta a embalagem de exemplos. Isso significa que nosso comprimento máximo do alvo só precisa ser longo o suficiente para se ajustar ao alvo para cada exemplo. Isso significa que nossa chave targets no mapeamento TASK_FEATURE_LENGTHS pode ser muito mais curta, por exemplo, em torno de 4 para muitas tarefas de superclue (Wang et al., 2019), em comparação com 62 e é o que é o padrão do P5X.
Existem várias opções para a inicialização do parâmetro prompt. Apoiamos os vários métodos na Seção 3.2 Nosso artigo, bem como a inicialização de um arquivo. O último permite fazer coisas como trem no boolq a partir de um prompt de aprendizado no MNLI.
Todos os inicializadores seguem a API inicializadora de linho de ser uma função parametrizada que retorna um fechamento sobre a função de inicialização. A função de inicialização real sempre tem a assinatura de
def initializer ( rng : Array , shape : Sequence [ int ]) -> Array :
... Fornecemos cada esquema de inicialização como um arquivo de configuração do GIN no diretório configs/prompts . Eles podem ser usados incluindo o arquivo gin com o --gin_file=path/to/configs/prompts/scheme.gin . Esse arquivo precisa vir após o arquivo do modelo principal, caso contrário, o método padrão (uniforme aleatório) substituirá o que você selecionou. Alguns desses métodos de inicialização exigirão que você defina valores extras de gin, embora seja um sinalizador de substituição em um de seus arquivos de gin.
Uniforme aleatório
Uma inicialização padrão e aleatória semelhante ao que as pessoas usaram para incorporar a inicialização. Este é o padrão e nenhum arquivo de gin é necessário. A escala dos valores aleatórios pode ser ajustada substituindo prompt_init/linen.initializers.uniform.scale=N .
Vocabulário amostrado
Amostra a incorporação de token para usar como inicialização para cada posição de prompt com o inicializador from_sample_of_embeddings . Você pode limitar a amostragem aos primeiros n incorporação com o prompt_init/prompts.from_samples_of_embeddings.population_size parâmetro.
Isso pode ser usado com --gin_file=prompt_tuning/configs/prompts/from_sampled_vocab.gin . Este método usa a tabela de incorporação extraída do ponto de verificação inicial do modelo. Você também pode fornecer seu próprio arquivo de incorporação com --gin_file=prompt_tuning/configs/prompts/from_sampled_vocab_numpy.gin . Este método requer que você forneça um valor para EMBEDDING_FILE que seja uma matriz numpy da tabela de incorporação do modelo. Isso pode ser extraído do ponto de verificação do modelo usando o Prompt_Tuning.scripts.extract_variable.
Rótulo da classe
Apoiamos a inicialização do Timesteps prompts com a incorporação de rótulos de classe (também conhecidos como verbalizadores ) através do inicializador from_embedded_list . Usuários que fornecem uma lista de palavras (etiquetas de classe) para usar. Cada palavras são tokenizadas por um vocabulário fornecido; incorporado com uma tabela de vocabulário fornecida; agregado, se necessário, entre sub-tocros; e usado para inicializar uma etapa de tempo imediata. Se os tokens fornecidos não cobrirem o comprimento completo completo, os tokens ausentes serão inicializados usando o Inicializador de Fall Back fornecido.
Podemos corresponder ao papel, onde os tokens prompts não preenchidos são preenchidos pela amostragem da tabela de incorporação, compondo essa inicialização com a acima. Ele pode ser usado com --gin_file=prompt_tuning/configs/prompts/from_class_labels.gin . Isso requer configuração de CLASS_LABELS , que é uma lista das palavras que você deseja incorporar como inicialização rápida. Você também pode fornecer seu próprio arquivo de incorporação (que é o mesmo que acima) com --gin_file=prompt_tuning/configs/prompts/from_class_labels_numpy.gin . Além disso, isso requer a configuração EMBEDDING_FILE .
De string
Também apoiamos a inicialização de um prompt com a incorporação de alguma string, geralmente usada para começar a partir de um prompt discreto ou uma descrição da tarefa. Isso usa o inicializador from_embedded_string . A string é tokenizada pelo vocabulário fornecido, cada token é procurado na tabela de incorporação fornecida e a representação incorporada resultante da string é usada como uma inicialização rápida. Se os tokens fornecidos não cobrirem o comprimento completo completo, os tokens ausentes serão inicializados usando o Inicializador de Fall Back fornecido.
Nota: O vocabulário apenas converte a string em uma sequência de IDs, você precisará garantir que a string corresponda ao resultado de qualquer formatação de texto (espaços em torno da pontuação, etc.) que sua tarefa seqio faz.
Do arquivo
Você também pode carregar um prompt de um arquivo com o inicializador from_array para ativar a transferência entre tarefas. Isso é feito com --gin_file=prompt_tuning/configs/prompts/from_file.gin . Isso requer definir PROMPT_FILE com um caminho para o arquivo Numpy com o prompt para carregar. As versões Numpy do prompt são emitidas por padrão ao treinar, mas o prompt também pode ser extraído com o script mencionado acima.
Lançamos os pontos de verificação nativos do T5X dos pontos de verificação T5 1.1 que tiveram 100 mil etapas de adaptação para modelos de idiomas.
Estes são convertidos a partir dos pontos de verificação do TensorFlow de malha pública.
Lançamos instruções pré -tenhadas em uma variedade de tarefas e planejamos adicioná -las ao longo do tempo.
Os avisos podem ser encontrados no diretório pretrained_prompts . A partir daí, cada grupo de subdiretores solicita o modelo pelo qual foram treinados. A maneira mais fácil de fazer referência a esses avisos que são incluídos na biblioteca é:
--PROMPT_FILE= ` python3 -m prompt_tuning.scripts.find_module prompt_tuning ` /pretrained_prompts/{MODEL_SIZE}/{PROMPT}.npy Devido à aleatoriedade inerente à computação paralela, existem algumas configurações que precisam corresponder entre treinamento e avaliação para obter exatamente os mesmos números. Cada subdiretório do modelo possui um README.md Especifica quais devem ser essas configurações. As configurações mais importantes a serem combinadas são o tamanho do lote, a topologia da TPU e o partição do paralelismo do modelo. As tabelas incluem as pontuações que você deve esperar se você usa esses avisos em t5x.eval
Esta é uma coleção de recursos adicionais sobre ajuste imediato.
Se você usar este trabalho como ponto de partida, cite
@inproceedings { lester-etal-2021-power ,
title = " The Power of Scale for Parameter-Efficient Prompt Tuning " ,
author = " Lester, Brian and
Al-Rfou, Rami and
Constant, Noah " ,
booktitle = " Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing " ,
month = nov,
year = " 2021 " ,
address = " Online and Punta Cana, Dominican Republic " ,
publisher = " Association for Computational Linguistics " ,
url = " https://aclanthology.org/2021.emnlp-main.243 " ,
doi = " 10.18653/v1/2021.emnlp-main.243 " ,
pages = " 3045--3059 " ,
}Este não é um produto do Google oficialmente suportado.