DMN+ Implementierung in Pytorch zur Beantwortung der Beantwortung des BABI 10K -Datensatzes.
| Datei | Beschreibung |
|---|---|
babi_loader.py | Erklärung der Babi Pytorch -Datensatzklasse |
babi_main.py | Enthält DMN+ Modell- und Trainingscode |
fetch_data.sh | Shell -Skript zum Abrufen von Babi -Aufgaben (von DMNs in Theano) |
Installieren Sie Pytorch v0.1.12 und Python 3.6.x (für die wörtliche Saiten -Interpolation)
Führen Sie das mitgelieferte Shell -Skript aus, um die Daten abzurufen
chmod +x fetch_data.sh
./fetch_data.sh
Führen Sie den Haupt -Python -Code aus
python babi_main.py
Niedrige Genauigkeiten im Vergleich zu Xiong et al. Es kann auf eine unterschiedliche Gewichtsabfalleinstellung oder die Instabilität des Modells zurückzuführen sein.
Bei einigen Aufgaben war die Genauigkeit in mehreren Läufen nicht stabil. Dies war besonders problematisch für QA3, QA17 und QA18. Um dies zu lösen, wiederholten wir das Training 10 -fach mit zufälligen Initialisierungen und bewerteten das Modell, das den niedrigsten Verlust des Validierungssatzes erreichte.
Hier finden Sie vorgelegte Modelle
| Aufgaben -ID | Dieses Repo | Xiong et al |
|---|---|---|
| 1 | 100% | 100% |
| 2 | 96,8% | 99,7% |
| 3 | 89,2% | 98,9% |
| 4 | 100% | 100% |
| 5 | 99,5% | 99,5% |
| 6 | 100% | 100% |
| 7 | 97,8% | 97,6% |
| 8 | 100% | 100% |
| 9 | 100% | 100% |
| 10 | 100% | 100% |
| 11 | 100% | 100% |
| 12 | 100% | 100% |
| 13 | 100% | 100% |
| 14 | 99% | 99,8% |
| 15 | 100% | 100% |
| 16 | 51,6% | 54,7% |
| 17 | 86,4% | 95,8% |
| 18 | 97,9% | 97,9% |
| 19 | 99,7% | 100% |
| 20 | 100% | 100% |