Dynamic memory networks plus Pytorch
1.0.0
DMN+在Pytorch中實現了BABI 10K數據集上的問題。
| 文件 | 描述 |
|---|---|
babi_loader.py | Babi Pytorch數據集的聲明 |
babi_main.py | 包含DMN+模型和培訓代碼 |
fetch_data.sh | 殼腳本以獲取BABI任務(來自Theano的DMN) |
安裝Pytorch V0.1.12和Python 3.6.x(用於字面字符串插值)
運行隨附的外殼腳本以獲取數據
chmod +x fetch_data.sh
./fetch_data.sh
運行主要的Python代碼
python babi_main.py
與Xiong等人相比,低精度可能是由於重量衰減設置不同或模型的不穩定性所致。
在某些任務上,精度在多個運行中不穩定。這在QA3,QA17和QA18上尤其有問題。為了解決這個問題,我們使用隨機初始化重複了10次訓練,並評估了達到最低驗證集損失的模型。
您可以在這裡找到據位的模型
| 任務ID | 這個存儲庫 | Xiong等 |
|---|---|---|
| 1 | 100% | 100% |
| 2 | 96.8% | 99.7% |
| 3 | 89.2% | 98.9% |
| 4 | 100% | 100% |
| 5 | 99.5% | 99.5% |
| 6 | 100% | 100% |
| 7 | 97.8% | 97.6% |
| 8 | 100% | 100% |
| 9 | 100% | 100% |
| 10 | 100% | 100% |
| 11 | 100% | 100% |
| 12 | 100% | 100% |
| 13 | 100% | 100% |
| 14 | 99% | 99.8% |
| 15 | 100% | 100% |
| 16 | 51.6% | 54.7% |
| 17 | 86.4% | 95.8% |
| 18 | 97.9% | 97.9% |
| 19 | 99.7% | 100% |
| 20 | 100% | 100% |