Code officiel pour l'article "Modélisation du langage contextuel pour les systèmes de dialogue orientés vers les objectifs"
site du projet | arxiv
conda create --name CALM python=3.9.7conda activate CALMpip install -r requirements.txtconda install pytorch==1.9.0 cudatoolkit=11.3 -c pytorch -c conda-forgeexport PYTHONPATH="$PWD/offline_airdialogue"outputs/ dossier contient des points de contrôle pour notre modèle principal, notre modèle de tâche pré-entraîné et notre bot client. (Remarque: toutes les exécutions de formation utilisent WandB par défaut, vous pouvez désactiver la synchronisation Wandb dans la configuration.)
cd scripts/train
Pour exécuter une formation multi-GPU de données parallèle, sur l'une des commandes ci-dessous, remplacez python <script_path> par python -m torch.distributed.launch --nproc_per_node <n_GPUs> --use_env <script_path> .
Pré-entraîneur calme
(Deux variantes de la fonction de perte auxiliaire)
Script: python train_pretrain_table_agent.py
Config: config/train_pretrain_table_agent.yaml
Script: python train_pretrain_simplified_aux_gpt2.py
Config: config/train_pretrain_simplified_aux_gpt2.yaml
Formation du bot client
python train_customer.pyconfig/train_customer_bot.yaml Entraînement de formation calme
(Deux variantes de la fonction de perte auxiliaire)
Script: python train_real_table_agent.py
Config: config/train_real_table_agent.yaml
Script: python train_simplified_aux_gpt2.py
Config: config/train_simplified_aux_agent.yaml
Formation standard LM
python train_basic_agent.pyconfig/train_basic_agent.yamlFormation du modèle de récompense pour la planification de déploiement basée sur le modèle
python train_constraint_parser.pyconfig/train_constraint_parser.yaml cd scripts/eval
Évaluation simulée
python selfplay_eval.pyconfig/selfplay_eval.yamlselfplay/outputs_file dans la configuration. Pour imprimer le taux de réussite de la course d'auto-jeu: python compute_results.py --results_file <your_eval_outputs_file>CUDA_VISIBLE_DEVICES=<comma_seperated_list_of_gpu_indicies>Évaluation de la qualité du langage
python language_quality_eval.pyconfig/language_eval.yamlpython -m torch.distributed.launch --nproc_per_node <n_GPUs> --use_env language_quality_eval.py