Dynamic memory networks plus Pytorch
1.0.0
PytorchでのDMN+実装は、BABI 10Kデータセットでの質問に答えます。
| ファイル | 説明 |
|---|---|
babi_loader.py | Babi Pytorch Datasetクラスの宣言 |
babi_main.py | DMN+モデルとトレーニングコードが含まれています |
fetch_data.sh | バビのタスクを取得するためのシェルスクリプト(TheanoのDMNSから) |
Pytorch V0.1.12とPython 3.6.x(リテラルストリング補間用)をインストールする
含まれているシェルスクリプトを実行して、データを取得します
chmod +x fetch_data.sh
./fetch_data.sh
メインのPythonコードを実行します
python babi_main.py
Xiong et alと比較した低精度は、異なる重量減衰設定またはモデルの不安定性による可能性があります。
一部のタスクでは、複数の実行にわたって精度は安定していませんでした。これは、QA3、QA17、およびQA18で特に問題がありました。これを解決するために、ランダムな初期化を使用して10回トレーニングを繰り返し、最低検証セットの損失を達成したモデルを評価しました。
ここで前払いされたモデルを見つけることができます
| タスクID | このレポ | Xiong et al |
|---|---|---|
| 1 | 100% | 100% |
| 2 | 96.8% | 99.7% |
| 3 | 89.2% | 98.9% |
| 4 | 100% | 100% |
| 5 | 99.5% | 99.5% |
| 6 | 100% | 100% |
| 7 | 97.8% | 97.6% |
| 8 | 100% | 100% |
| 9 | 100% | 100% |
| 10 | 100% | 100% |
| 11 | 100% | 100% |
| 12 | 100% | 100% |
| 13 | 100% | 100% |
| 14 | 99% | 99.8% |
| 15 | 100% | 100% |
| 16 | 51.6% | 54.7% |
| 17 | 86.4% | 95.8% |
| 18 | 97.9% | 97.9% |
| 19 | 99.7% | 100% |
| 20 | 100% | 100% |