Implémentation DMN + dans Pytorch pour répondre à des questions sur l'ensemble de données BABI 10K.
| déposer | description |
|---|---|
babi_loader.py | Déclaration de classe de données Babi Pytorch |
babi_main.py | Contient le modèle DMN + et le code de formation |
fetch_data.sh | script shell pour récupérer les tâches Babi (de DMNS dans Theano) |
Installez Pytorch V0.1.12 et Python 3.6.x (pour l'interpolation de la chaîne littérale)
Exécutez le script shell inclus pour récupérer les données
chmod +x fetch_data.sh
./fetch_data.sh
Exécutez le code Python principal
python babi_main.py
De faibles précisions par rapport aux Xiong et al sont dues à un réglage différent de désintégration du poids ou à l'instabilité du modèle.
Sur certaines tâches, la précision n'était pas stable sur plusieurs courses. Cela était particulièrement problématique sur QA3, QA17 et QA18. Pour résoudre ce problème, nous avons répété la formation 10 fois en utilisant des initialisations aléatoires et évalué le modèle qui a réalisé la perte d'ensemble de validation la plus faible.
Vous pouvez trouver des modèles pré-entraînés ici
| ID de tâche | Ce repo | Xiong et al |
|---|---|---|
| 1 | 100% | 100% |
| 2 | 96,8% | 99,7% |
| 3 | 89,2% | 98,9% |
| 4 | 100% | 100% |
| 5 | 99,5% | 99,5% |
| 6 | 100% | 100% |
| 7 | 97,8% | 97,6% |
| 8 | 100% | 100% |
| 9 | 100% | 100% |
| 10 | 100% | 100% |
| 11 | 100% | 100% |
| 12 | 100% | 100% |
| 13 | 100% | 100% |
| 14 | 99% | 99,8% |
| 15 | 100% | 100% |
| 16 | 51,6% | 54,7% |
| 17 | 86,4% | 95,8% |
| 18 | 97,9% | 97,9% |
| 19 | 99,7% | 100% |
| 20 | 100% | 100% |