memn2n
1.0.0
Tensorflowを使用したSklearnのようなインターフェイスを使用したエンドツーエンドメモリネットワークの実装。タスクはBABLデータセットからのものです。
git clone [email protected]:domluna/memn2n.git
mkdir ./memn2n/data/
cd ./memn2n/data/
wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz
tar xzvf ./tasks_1-20_v1-2.tar.gz
cd ../
python single.py
単一のbabiタスクを実行します
すべてのバビタスクで共同モデルを実行します
これらのファイルは、使用の良い例でもあります。
パスするタスクの場合、95%以上のテスト精度を満たす必要があります。 1Kデータの単一タスクで測定されます。
パス:1,4,12,15,20
他のいくつかのタスクには、80%以上のテスト精度があります。
確率勾配降下オプティマイザーは、エンドツーエンドメモリネットワークのセクション4.2で指定されているアニール学習率スケジュールで使用されました
次のパラメーションが使用されました。
| タスク | トレーニングの正確性 | 検証の精度 | テストの精度 |
|---|---|---|---|
| 1 | 1.0 | 1.0 | 1.0 |
| 2 | 1.0 | 0.86 | 0.83 |
| 3 | 1.0 | 0.64 | 0.54 |
| 4 | 1.0 | 0.99 | 0.98 |
| 5 | 1.0 | 0.94 | 0.87 |
| 6 | 1.0 | 0.97 | 0.92 |
| 7 | 1.0 | 0.89 | 0.84 |
| 8 | 1.0 | 0.93 | 0.86 |
| 9 | 1.0 | 0.86 | 0.90 |
| 10 | 1.0 | 0.80 | 0.78 |
| 11 | 1.0 | 0.92 | 0.84 |
| 12 | 1.0 | 1.0 | 1.0 |
| 13 | 0.99 | 0.94 | 0.90 |
| 14 | 1.0 | 0.97 | 0.93 |
| 15 | 1.0 | 1.0 | 1.0 |
| 16 | 0.81 | 0.47 | 0.44 |
| 17 | 0.76 | 0.65 | 0.52 |
| 18 | 0.97 | 0.96 | 0.88 |
| 19 | 0.40 | 0.17 | 0.13 |
| 20 | 1.0 | 1.0 | 1.0 |
パス:1,6,9,10,12,13,15,20
再び確率勾配降下オプティマイザーは、エンドツーエンドメモリネットワークのセクション4.2で指定されているアニール学習率スケジュールで使用されました
次のパラメーションが使用されました。
| タスク | トレーニングの正確性 | 検証の精度 | テストの精度 |
|---|---|---|---|
| 1 | 1.0 | 0.99 | 0.999 |
| 2 | 1.0 | 0.84 | 0.849 |
| 3 | 0.99 | 0.72 | 0.715 |
| 4 | 0.96 | 0.86 | 0.851 |
| 5 | 1.0 | 0.92 | 0.865 |
| 6 | 1.0 | 0.97 | 0.964 |
| 7 | 0.96 | 0.87 | 0.851 |
| 8 | 0.99 | 0.89 | 0.898 |
| 9 | 0.99 | 0.96 | 0.96 |
| 10 | 1.0 | 0.96 | 0.928 |
| 11 | 1.0 | 0.98 | 0.93 |
| 12 | 1.0 | 0.98 | 0.982 |
| 13 | 0.99 | 0.98 | 0.976 |
| 14 | 1.0 | 0.81 | 0.877 |
| 15 | 1.0 | 1.0 | 0.983 |
| 16 | 0.64 | 0.45 | 0.44 |
| 17 | 0.77 | 0.64 | 0.547 |
| 18 | 0.85 | 0.71 | 0.586 |
| 19 | 0.24 | 0.07 | 0.104 |
| 20 | 1.0 | 1.0 | 0.996 |
単一のタスクの結果は、単一のタスクモデルの10の繰り返しトレイルからのものです。20のタスクはすべて、異なるランダム初期化を伴います。各タスクの検証精度が最も低いモデルのパフォーマンスは、上の表に示されています。
共同トレーニングの結果は、共同モデルの10の繰り返しトレイルからのすべてのタスクがあります。検証精度がほとんどのタスク(> = 0.95)に合格した単一モデルのパフォーマンスは、上の表(over_scores_run2.csv)に示されています。 10回すべての実行のスコアは、結果/ディレクトリにあります。