Dynamic memory networks plus Pytorch
1.0.0
BABI 10K 데이터 세트에 대한 질문에 대한 질문에 대한 Pytorch의 DMN+ 구현.
| 파일 | 설명 |
|---|---|
babi_loader.py | Babi Pytorch 데이터 세트 클래스 선언 |
babi_main.py | DMN+ 모델 및 교육 코드가 포함되어 있습니다 |
fetch_data.sh | Babi 작업을 가져 오는 쉘 스크립트 (Theano의 DMNS에서) |
Pytorch v0.1.12 및 Python 3.6.x 설치 (문자 문자열 보간 용)
포함 된 쉘 스크립트를 실행하여 데이터를 가져 오십시오
chmod +x fetch_data.sh
./fetch_data.sh
기본 파이썬 코드를 실행하십시오
python babi_main.py
Xiong 등과 비교하여 낮은 정확도는 다른 중량 붕괴 설정 또는 모델의 불안정성으로 인해 발생할 수 있습니다.
일부 작업에서는 여러 실행에 걸쳐 정확도가 안정적이지 않았습니다. 이는 QA3, QA17 및 QA18에서 특히 문제가되었습니다. 이를 해결하기 위해 무작위 초기화를 사용하여 10 번 훈련을 반복하고 가장 낮은 검증 세트 손실을 달성 한 모델을 평가했습니다.
여기에서 사전에 사전 된 모델을 찾을 수 있습니다
| 작업 ID | 이 repo | Xiong et al |
|---|---|---|
| 1 | 100% | 100% |
| 2 | 96.8% | 99.7% |
| 3 | 89.2% | 98.9% |
| 4 | 100% | 100% |
| 5 | 99.5% | 99.5% |
| 6 | 100% | 100% |
| 7 | 97.8% | 97.6% |
| 8 | 100% | 100% |
| 9 | 100% | 100% |
| 10 | 100% | 100% |
| 11 | 100% | 100% |
| 12 | 100% | 100% |
| 13 | 100% | 100% |
| 14 | 99% | 99.8% |
| 15 | 100% | 100% |
| 16 | 51.6% | 54.7% |
| 17 | 86.4% | 95.8% |
| 18 | 97.9% | 97.9% |
| 19 | 99.7% | 100% |
| 20 | 100% | 100% |