Código oficial para o artigo "Modelagem de idiomas com conhecimento de contexto para sistemas de diálogo orientados a objetivos"
Site do projeto | arxiv
conda create --name CALM python=3.9.7conda activate CALMpip install -r requirements.txtconda install pytorch==1.9.0 cudatoolkit=11.3 -c pytorch -c conda-forgeexport PYTHONPATH="$PWD/offline_airdialogue"outputs/ pasta contém pontos de verificação para o nosso modelo principal, nosso modelo pré -traçado de tarefas e nosso bot de cliente. (Nota: todas as execuções de treinamento usam Wandb por padrão, você pode desativar a sincronização do Wandb na configuração.)
cd scripts/train
Para executar o treinamento multi-gpu paralelo de dados, em qualquer um dos comandos abaixo, substitua python <script_path> por python -m torch.distributed.launch --nproc_per_node <n_GPUs> --use_env <script_path> .
Calmo pré -treinamento
(duas variantes da função de perda auxiliar)
Script: python train_pretrain_table_agent.py
Config: config/train_pretrain_table_agent.yaml
Script: python train_pretrain_simplified_aux_gpt2.py
Config: config/train_pretrain_simplified_aux_gpt2.yaml
Treinando o bot do cliente
python train_customer.pyconfig/train_customer_bot.yaml Treinar calma
(duas variantes da função de perda auxiliar)
Script: python train_real_table_agent.py
Config: config/train_real_table_agent.yaml
Script: python train_simplified_aux_gpt2.py
Config: config/train_simplified_aux_agent.yaml
Treinamento padrão LM
python train_basic_agent.pyconfig/train_basic_agent.yamlTreinando o modelo de recompensa para planejamento de lançamento baseado em modelo
python train_constraint_parser.pyconfig/train_constraint_parser.yaml cd scripts/eval
Avaliação simulada
python selfplay_eval.pyconfig/selfplay_eval.yamlselfplay/outputs_file na configuração. Para imprimir a taxa de sucesso para a execução do auto -jogo: python compute_results.py --results_file <your_eval_outputs_file>CUDA_VISIBLE_DEVICES=<comma_seperated_list_of_gpu_indicies>Avaliação da qualidade do idioma
python language_quality_eval.pyconfig/language_eval.yamlpython -m torch.distributed.launch --nproc_per_node <n_GPUs> --use_env language_quality_eval.py