El objetivo de este repositorio es contener código limpio, legible y probado para reproducir la investigación de aprendizaje de pocos disparos.
Este proyecto está escrito en Python 3.6 y Pytorch y asume que tiene una GPU.
Consulte estos artículos medios para obtener más información.
Listado en requirements.txt . Instale con pip install -r requirements.txt preferiblemente en VirtualEnv.
Edite la variable DATA_PATH en config.py en la ubicación donde almacena los conjuntos de datos OmnigLot y MiniImagenet.
Después de adquirir los datos y ejecutar los scripts de configuración, su estructura de carpeta debería verse
DATA_PATH/
Omniglot/
images_background/
images_evaluation/
miniImageNet/
images_background/
images_evaluation/
Omniglot DataSet. Descargar desde https://github.com/brendenlake/omniglot/tree/master/python, coloque los archivos extraídos en DATA_PATH/Omniglot_Raw y ejecute scripts/prepare_omniglot.py
Miniimagenet DataSet. Descargar archivos de https://drive.google.com/file/d/0b3irx3uqnobmq1flnxjszudywee/view, colocar en data/miniImageNet/images y ejecutar scripts/prepare_mini_imagenet.py
Después de agregar los conjuntos de datos, ejecuten pytest en el directorio raíz para ejecutar todas las pruebas.
El archivo experiments/experiments.txt contiene los hiperparámetros que utilicé para obtener los resultados que se dan a continuación.

Ejecute experiments/proto_nets.py para reproducir los resultados de las redes prototíticas para el aprendizaje de pocos disparos (Snell et al).
Argumentos
| Omniglot | ||||
|---|---|---|---|---|
| K-way | 5 | 5 | 20 | 20 |
| n-shot | 1 | 5 | 1 | 5 |
| Publicado | 98.8 | 99.7 | 96.0 | 98.9 |
| Este repositorio | 98.2 | 99.4 | 95.8 | 98.6 |
| miniimagenet | ||
|---|---|---|
| K-way | 5 | 5 |
| n-shot | 1 | 5 |
| Publicado | 49.4 | 68.2 |
| Este repositorio | 48.0 | 66.2 |
Un clasificador de vecinos más cercanos diferenciables.

Ejecute experiments/matching_nets.py para reproducir los resultados de las redes coincidentes para el aprendizaje de un disparo (vinyals et al).
Argumentos
Tuve problemas para reproducir los resultados de este documento utilizando la métrica de distancia del coseno, ya que encontré que la converja es lenta y el rendimiento final depende de la inicialización aleatoria. Sin embargo, pude reproducir (y exceder ligeramente) los resultados de este documento utilizando la métrica de distancia L2.
| Omniglot | ||||
|---|---|---|---|---|
| K-way | 5 | 5 | 20 | 20 |
| n-shot | 1 | 5 | 1 | 5 |
| Publicado (coseno) | 98.1 | 98.9 | 93.8 | 98.5 |
| Este repositorio (coseno) | 92.0 | 93.2 | 75.6 | 77.8 |
| Este repositorio (L2) | 98.3 | 99.8 | 92.8 | 97.8 |
| miniimagenet | ||
|---|---|---|
| K-way | 5 | 5 |
| n-shot | 1 | 5 |
| Publicado (Cosine, FCE) | 44.2 | 57.0 |
| Este repositorio (coseno, fce) | 42.8 | 53.6 |
| Este repositorio (L2) | 46.0 | 58.4 |

Utilicé la agrupación máxima en lugar de las convoluciones estridadas para ser consistente con los otros documentos. Los experimentos de MiniImagenet con el segundo orden Maml me llevaron más de un día para correr.
Ejecutar experiments/maml.py para reproducir los resultados del meta-aprendizaje del modelo agnóstico (Finn et al).
Argumentos
NB: para Maml N, K y Q se fijan entre el tren y la prueba. Es posible que deba ajustar el tamaño de meta-lote para que se ajuste a su GPU. Segundo orden Maml usa mucha más memoria.
| Omniglot | ||||
|---|---|---|---|---|
| K-way | 5 | 5 | 20 | 20 |
| n-shot | 1 | 5 | 1 | 5 |
| Publicado | 98.7 | 99.9 | 95.8 | 98.9 |
| Este repositorio (1) | 95.5 | 99.5 | 92.2 | 97.7 |
| Este repositorio (2) | 98.1 | 99.8 | 91.6 | 95.9 |
| miniimagenet | ||
|---|---|---|
| K-way | 5 | 5 |
| n-shot | 1 | 5 |
| Publicado | 48.1 | 63.2 |
| Este repositorio (1) | 46.4 | 63.3 |
| Este repositorio (2) | 47.5 | 64.7 |
El número en los soportes indica el primer o segundo orden MAML.