memn2n
1.0.0
使用TensorFlow實現端到端內存網絡,具有類似Sklearn的接口。任務來自BABL數據集。
git clone [email protected]:domluna/memn2n.git
mkdir ./memn2n/data/
cd ./memn2n/data/
wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz
tar xzvf ./tasks_1-20_v1-2.tar.gz
cd ../
python single.py
執行一個BABI任務
在所有BABI任務上運行聯合模型
這些文件也是用法的一個很好的例子。
為了使任務通過它必須達到95%+測試精度。在1K數據上的單個任務上測量。
通過:1,4,12,15,20
其他幾個任務具有80%+測試精度。
隨機梯度下降優化器與端到端內存網絡第4.2節中指定的退火學習率計劃一起使用
使用了以下參數:
| 任務 | 培訓準確性 | 驗證精度 | 測試準確性 |
|---|---|---|---|
| 1 | 1.0 | 1.0 | 1.0 |
| 2 | 1.0 | 0.86 | 0.83 |
| 3 | 1.0 | 0.64 | 0.54 |
| 4 | 1.0 | 0.99 | 0.98 |
| 5 | 1.0 | 0.94 | 0.87 |
| 6 | 1.0 | 0.97 | 0.92 |
| 7 | 1.0 | 0.89 | 0.84 |
| 8 | 1.0 | 0.93 | 0.86 |
| 9 | 1.0 | 0.86 | 0.90 |
| 10 | 1.0 | 0.80 | 0.78 |
| 11 | 1.0 | 0.92 | 0.84 |
| 12 | 1.0 | 1.0 | 1.0 |
| 13 | 0.99 | 0.94 | 0.90 |
| 14 | 1.0 | 0.97 | 0.93 |
| 15 | 1.0 | 1.0 | 1.0 |
| 16 | 0.81 | 0.47 | 0.44 |
| 17 | 0.76 | 0.65 | 0.52 |
| 18 | 0.97 | 0.96 | 0.88 |
| 19 | 0.40 | 0.17 | 0.13 |
| 20 | 1.0 | 1.0 | 1.0 |
通過:1,6,9,10,12,13,15,20
同樣,隨機梯度下降優化器與端到端內存網絡第4.2節中指定的退火學習率計劃一起使用
使用了以下參數:
| 任務 | 培訓準確性 | 驗證精度 | 測試準確性 |
|---|---|---|---|
| 1 | 1.0 | 0.99 | 0.999 |
| 2 | 1.0 | 0.84 | 0.849 |
| 3 | 0.99 | 0.72 | 0.715 |
| 4 | 0.96 | 0.86 | 0.851 |
| 5 | 1.0 | 0.92 | 0.865 |
| 6 | 1.0 | 0.97 | 0.964 |
| 7 | 0.96 | 0.87 | 0.851 |
| 8 | 0.99 | 0.89 | 0.898 |
| 9 | 0.99 | 0.96 | 0.96 |
| 10 | 1.0 | 0.96 | 0.928 |
| 11 | 1.0 | 0.98 | 0.93 |
| 12 | 1.0 | 0.98 | 0.982 |
| 13 | 0.99 | 0.98 | 0.976 |
| 14 | 1.0 | 0.81 | 0.877 |
| 15 | 1.0 | 1.0 | 0.983 |
| 16 | 0.64 | 0.45 | 0.44 |
| 17 | 0.77 | 0.64 | 0.547 |
| 18 | 0.85 | 0.71 | 0.586 |
| 19 | 0.24 | 0.07 | 0.104 |
| 20 | 1.0 | 1.0 | 0.996 |
單個任務結果來自單個任務模型的10個重複跟踪,所有20個任務都具有不同的隨機初始化。上表顯示了每個任務驗證精度最低的模型的性能。
聯合訓練結果來自所有任務的共同模型的10個重複步道。驗證精度通過最多任務(> = 0.95)的單個模型的性能如上表(inter_scores_run2.csv)。所有10次運行的分數都位於結果/目錄中。