Implementação DMN+ em Pytorch para resposta a perguntas no conjunto de dados Babi 10K.
| arquivo | descrição |
|---|---|
babi_loader.py | Declaração de Babi Pytorch DataSet Class |
babi_main.py | Contém o modelo DMN+ e o código de treinamento |
fetch_data.sh | script de shell para buscar tarefas de babi (de DMNs em Theano) |
Instale o Pytorch v0.1.12 e o Python 3.6.x (para interpolação literal de cordas)
Execute o script de shell incluído para buscar os dados
chmod +x fetch_data.sh
./fetch_data.sh
Execute o código Python principal
python babi_main.py
As baixas precisões em comparação com Xiong et al são podem ser devido à configuração de decaimento de peso diferentes ou à instabilidade do modelo.
Em algumas tarefas, a precisão não foi estável em várias execuções. Isso foi particularmente problemático no QA3, QA17 e QA18. Para resolver isso, repetimos o treinamento 10 vezes usando inicializações aleatórias e avaliamos o modelo que alcançou a menor perda de conjunto de validação.
Você pode encontrar modelos pré -traidos aqui
| ID da tarefa | Este repo | Xiong et al |
|---|---|---|
| 1 | 100% | 100% |
| 2 | 96,8% | 99,7% |
| 3 | 89,2% | 98,9% |
| 4 | 100% | 100% |
| 5 | 99,5% | 99,5% |
| 6 | 100% | 100% |
| 7 | 97,8% | 97,6% |
| 8 | 100% | 100% |
| 9 | 100% | 100% |
| 10 | 100% | 100% |
| 11 | 100% | 100% |
| 12 | 100% | 100% |
| 13 | 100% | 100% |
| 14 | 99% | 99,8% |
| 15 | 100% | 100% |
| 16 | 51,6% | 54,7% |
| 17 | 86,4% | 95,8% |
| 18 | 97,9% | 97,9% |
| 19 | 99,7% | 100% |
| 20 | 100% | 100% |