論文の公式コード「目標指向のダイアログシステムのためのコンテキスト対応言語モデリング」
プロジェクトサイト| arxiv
conda create --name CALM python=3.9.7conda activate CALMpip install -r requirements.txtconda install pytorch==1.9.0 cudatoolkit=11.3 -c pytorch -c conda-forgeexport PYTHONPATH="$PWD/offline_airdialogue"outputs/フォルダーには、メインモデルのチェックポイント、タスクの前のモデル、および顧客ボットが含まれています。 (注:すべてのトレーニングはデフォルトでWandBを使用します。WandB同期を設定でオフにすることができます。)
cd scripts/train
Data-Parallel Multi-GPUトレーニングを実行するには、以下のコマンドのいずれかで、 python <script_path> python -m torch.distributed.launch --nproc_per_node <n_GPUs> --use_env <script_path>に置き換えます。
穏やかな事前に
(補助損失関数の2つのバリアント)
スクリプト: python train_pretrain_table_agent.py
config: config/train_pretrain_table_agent.yaml
スクリプト: python train_pretrain_simplified_aux_gpt2.py
config: config/train_pretrain_simplified_aux_gpt2.yaml
顧客ボットのトレーニング
python train_customer.pyconfig/train_customer_bot.yaml落ち着いたトレーニング
(補助損失関数の2つのバリアント)
スクリプト: python train_real_table_agent.py
config: config/train_real_table_agent.yaml
スクリプト: python train_simplified_aux_gpt2.py
config: config/train_simplified_aux_agent.yaml
トレーニング標準LM
python train_basic_agent.pyconfig/train_basic_agent.yamlモデルベースのロールアウト計画の報酬モデルのトレーニング
python train_constraint_parser.pyconfig/train_constraint_parser.yaml cd scripts/eval
シミュレートされた評価
python selfplay_eval.pyconfig/selfplay_eval.yamlselfplay/outputs_fileによって指定された場所に保存されます。セルフプレイの成功率を印刷するには: python compute_results.py --results_file <your_eval_outputs_file>CUDA_VISIBLE_DEVICES=<comma_seperated_list_of_gpu_indicies>でコマンドをプレフィックスします言語品質評価
python language_quality_eval.pyconfig/language_eval.yamlpython -m torch.distributed.launch --nproc_per_node <n_GPUs> --use_env language_quality_eval.py