CALM Dialogue
1.0.0
該論文的官方代碼“面向目標對話系統的上下文意識語言建模”
項目網站| arxiv
conda create --name CALM python=3.9.7conda activate CALMpip install -r requirements.txtconda install pytorch==1.9.0 cudatoolkit=11.3 -c pytorch -c conda-forgeexport PYTHONPATH="$PWD/offline_airdialogue"outputs/文件夾包含我們的主要模型,我們的任務預估計模型和客戶機器人的檢查點。 (注意:默認情況下,所有訓練運行都使用WANDB,您可以在配置中關閉WANDB同步。)
cd scripts/train
要運行數據並行多GPU培訓,請在下面的任何命令上替換python <script_path>用python -m torch.distributed.launch --nproc_per_node <n_GPUs> --use_env <script_path> 。
預處理平靜
(輔助損耗函數的兩個變體)
腳本: python train_pretrain_table_agent.py
config: config/train_pretrain_table_agent.yaml
腳本: python train_pretrain_simplified_aux_gpt2.py
config: config/train_pretrain_simplified_aux_gpt2.yaml
培訓客戶機器人
python train_customer.pyconfig/train_customer_bot.yaml訓練平靜
(輔助損耗函數的兩個變體)
腳本: python train_real_table_agent.py
config: config/train_real_table_agent.yaml
腳本: python train_simplified_aux_gpt2.py
config: config/train_simplified_aux_agent.yaml
培訓標準LM
python train_basic_agent.pyconfig/train_basic_agent.yaml培訓獎勵模型為基於模型的推出計劃
python train_constraint_parser.pyconfig/train_constraint_parser.yaml cd scripts/eval
模擬評估
python selfplay_eval.pyconfig/selfplay_eval.yamlselfplay/outputs_file指定的位置。要打印自我播放的成功率: python compute_results.py --results_file <your_eval_outputs_file>CUDA_VISIBLE_DEVICES=<comma_seperated_list_of_gpu_indicies>前綴。語言質量評估
python language_quality_eval.pyconfig/language_eval.yamlpython -m torch.distributed.launch --nproc_per_node <n_GPUs> --use_env language_quality_eval.py