Código oficial para el documento "Modelado de lenguaje consciente de contexto para sistemas de diálogo orientados a objetivos"
Sitio del proyecto | arxiv
conda create --name CALM python=3.9.7conda activate CALMpip install -r requirements.txtconda install pytorch==1.9.0 cudatoolkit=11.3 -c pytorch -c conda-forgeexport PYTHONPATH="$PWD/offline_airdialogue"outputs/ carpeta contienen puntos de control para nuestro modelo principal, nuestro modelo de pretrada de tareas y nuestro bot de clientes. (Nota: todas las ejecuciones de entrenamiento usan WandB de forma predeterminada, puede desactivar la sincronización de WandB en la configuración).
cd scripts/train
Para ejecutar la capacitación multi-GPU de datos paralelos, en cualquiera de los comandos a continuación, reemplace python <script_path> con python -m torch.distributed.launch --nproc_per_node <n_GPUs> --use_env <script_path> .
Prueba de calma
(Dos variantes de la función de pérdida auxiliar)
Script: python train_pretrain_table_agent.py
config: config/train_pretrain_table_agent.yaml
Script: python train_pretrain_simplified_aux_gpt2.py
config: config/train_pretrain_simplified_aux_gpt2.yaml
Capacitar al bot del cliente
python train_customer.pyconfig/train_customer_bot.yaml Calma de entrenamiento
(Dos variantes de la función de pérdida auxiliar)
Script: python train_real_table_agent.py
config: config/train_real_table_agent.yaml
Script: python train_simplified_aux_gpt2.py
config: config/train_simplified_aux_agent.yaml
Estándar de entrenamiento LM
python train_basic_agent.pyconfig/train_basic_agent.yamlCapacitación del modelo de recompensa para la planificación del despliegue basada en el modelo
python train_constraint_parser.pyconfig/train_constraint_parser.yaml cd scripts/eval
Evaluación simulada
python selfplay_eval.pyconfig/selfplay_eval.yamlselfplay/outputs_file en la configuración. Para imprimir la tasa de éxito para la ejecución de autoplay: python compute_results.py --results_file <your_eval_outputs_file>CUDA_VISIBLE_DEVICES=<comma_seperated_list_of_gpu_indicies>Evaluación de calidad del idioma
python language_quality_eval.pyconfig/language_eval.yamlpython -m torch.distributed.launch --nproc_per_node <n_GPUs> --use_env language_quality_eval.py