Dynamic memory networks plus Pytorch
1.0.0
DMN+在Pytorch中实现了BABI 10K数据集上的问题。
| 文件 | 描述 |
|---|---|
babi_loader.py | Babi Pytorch数据集的声明 |
babi_main.py | 包含DMN+模型和培训代码 |
fetch_data.sh | 壳脚本以获取BABI任务(来自Theano的DMN) |
安装Pytorch V0.1.12和Python 3.6.x(用于字面字符串插值)
运行随附的外壳脚本以获取数据
chmod +x fetch_data.sh
./fetch_data.sh
运行主要的Python代码
python babi_main.py
与Xiong等人相比,低精度可能是由于重量衰减设置不同或模型的不稳定性所致。
在某些任务上,精度在多个运行中不稳定。这在QA3,QA17和QA18上尤其有问题。为了解决这个问题,我们使用随机初始化重复了10次训练,并评估了达到最低验证集损失的模型。
您可以在这里找到据位的模型
| 任务ID | 这个存储库 | Xiong等 |
|---|---|---|
| 1 | 100% | 100% |
| 2 | 96.8% | 99.7% |
| 3 | 89.2% | 98.9% |
| 4 | 100% | 100% |
| 5 | 99.5% | 99.5% |
| 6 | 100% | 100% |
| 7 | 97.8% | 97.6% |
| 8 | 100% | 100% |
| 9 | 100% | 100% |
| 10 | 100% | 100% |
| 11 | 100% | 100% |
| 12 | 100% | 100% |
| 13 | 100% | 100% |
| 14 | 99% | 99.8% |
| 15 | 100% | 100% |
| 16 | 51.6% | 54.7% |
| 17 | 86.4% | 95.8% |
| 18 | 97.9% | 97.9% |
| 19 | 99.7% | 100% |
| 20 | 100% | 100% |