memn2n
1.0.0
使用TensorFlow实现端到端内存网络,具有类似Sklearn的接口。任务来自BABL数据集。
git clone [email protected]:domluna/memn2n.git
mkdir ./memn2n/data/
cd ./memn2n/data/
wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz
tar xzvf ./tasks_1-20_v1-2.tar.gz
cd ../
python single.py
执行一个BABI任务
在所有BABI任务上运行联合模型
这些文件也是用法的一个很好的例子。
为了使任务通过它必须达到95%+测试精度。在1K数据上的单个任务上测量。
通过:1,4,12,15,20
其他几个任务具有80%+测试精度。
随机梯度下降优化器与端到端内存网络第4.2节中指定的退火学习率计划一起使用
使用了以下参数:
| 任务 | 培训准确性 | 验证精度 | 测试准确性 |
|---|---|---|---|
| 1 | 1.0 | 1.0 | 1.0 |
| 2 | 1.0 | 0.86 | 0.83 |
| 3 | 1.0 | 0.64 | 0.54 |
| 4 | 1.0 | 0.99 | 0.98 |
| 5 | 1.0 | 0.94 | 0.87 |
| 6 | 1.0 | 0.97 | 0.92 |
| 7 | 1.0 | 0.89 | 0.84 |
| 8 | 1.0 | 0.93 | 0.86 |
| 9 | 1.0 | 0.86 | 0.90 |
| 10 | 1.0 | 0.80 | 0.78 |
| 11 | 1.0 | 0.92 | 0.84 |
| 12 | 1.0 | 1.0 | 1.0 |
| 13 | 0.99 | 0.94 | 0.90 |
| 14 | 1.0 | 0.97 | 0.93 |
| 15 | 1.0 | 1.0 | 1.0 |
| 16 | 0.81 | 0.47 | 0.44 |
| 17 | 0.76 | 0.65 | 0.52 |
| 18 | 0.97 | 0.96 | 0.88 |
| 19 | 0.40 | 0.17 | 0.13 |
| 20 | 1.0 | 1.0 | 1.0 |
通过:1,6,9,10,12,13,15,20
同样,随机梯度下降优化器与端到端内存网络第4.2节中指定的退火学习率计划一起使用
使用了以下参数:
| 任务 | 培训准确性 | 验证精度 | 测试准确性 |
|---|---|---|---|
| 1 | 1.0 | 0.99 | 0.999 |
| 2 | 1.0 | 0.84 | 0.849 |
| 3 | 0.99 | 0.72 | 0.715 |
| 4 | 0.96 | 0.86 | 0.851 |
| 5 | 1.0 | 0.92 | 0.865 |
| 6 | 1.0 | 0.97 | 0.964 |
| 7 | 0.96 | 0.87 | 0.851 |
| 8 | 0.99 | 0.89 | 0.898 |
| 9 | 0.99 | 0.96 | 0.96 |
| 10 | 1.0 | 0.96 | 0.928 |
| 11 | 1.0 | 0.98 | 0.93 |
| 12 | 1.0 | 0.98 | 0.982 |
| 13 | 0.99 | 0.98 | 0.976 |
| 14 | 1.0 | 0.81 | 0.877 |
| 15 | 1.0 | 1.0 | 0.983 |
| 16 | 0.64 | 0.45 | 0.44 |
| 17 | 0.77 | 0.64 | 0.547 |
| 18 | 0.85 | 0.71 | 0.586 |
| 19 | 0.24 | 0.07 | 0.104 |
| 20 | 1.0 | 1.0 | 0.996 |
单个任务结果来自单个任务模型的10个重复跟踪,所有20个任务都具有不同的随机初始化。上表显示了每个任务验证精度最低的模型的性能。
联合训练结果来自所有任务的共同模型的10个重复步道。验证精度通过最多任务(> = 0.95)的单个模型的性能如上表(inter_scores_run2.csv)。所有10次运行的分数都位于结果/目录中。