DMN+ Implementación en Pytorch para la respuesta de preguntas en el conjunto de datos BABI 10K.
| archivo | descripción |
|---|---|
babi_loader.py | Declaración de la clase del conjunto de datos Babi Pytorch |
babi_main.py | Contiene modelo DMN+ y código de entrenamiento |
fetch_data.sh | Script de shell para buscar tareas de Babi (de DMNS en theo) |
Instale Pytorch V0.1.12 y Python 3.6.x (para interpolación literal de cadenas)
Ejecute el script de shell incluido para obtener los datos
chmod +x fetch_data.sh
./fetch_data.sh
Ejecute el código principal de Python
python babi_main.py
Las bajas precisiones en comparación con Xiong et al, pueden debido a la configuración de descomposición de peso diferente o la inestabilidad del modelo.
En algunas tareas, la precisión no fue estable en múltiples ejecuciones. Esto fue particularmente problemático para QA3, QA17 y QA18. Para resolver esto, repetimos el entrenamiento 10 veces utilizando inicializaciones aleatorias y evaluamos el modelo que logró la pérdida del conjunto de validación más baja.
Puedes encontrar modelos previos a la petróleo aquí
| ID de tarea | Este repositorio | Xiong et al |
|---|---|---|
| 1 | 100% | 100% |
| 2 | 96.8% | 99.7% |
| 3 | 89.2% | 98.9% |
| 4 | 100% | 100% |
| 5 | 99.5% | 99.5% |
| 6 | 100% | 100% |
| 7 | 97.8% | 97.6% |
| 8 | 100% | 100% |
| 9 | 100% | 100% |
| 10 | 100% | 100% |
| 11 | 100% | 100% |
| 12 | 100% | 100% |
| 13 | 100% | 100% |
| 14 | 99% | 99.8% |
| 15 | 100% | 100% |
| 16 | 51.6% | 54.7% |
| 17 | 86.4% | 95.8% |
| 18 | 97.9% | 97.9% |
| 19 | 99.7% | 100% |
| 20 | 100% | 100% |