Pytorch Implémentation de l'algorithme de reptile d'Openai pour l'apprentissage supervisé.
Actuellement, il fonctionne sur omniglot mais pas encore sur MiniImagenet.
Le code n'a pas été beaucoup testé. Les contributions et les commentaires sont plus que les bienvenus!
Il existe déjà une classe de jeu de données Omniglot dans TorchVision, mais il semble être plus adapté pour l'apprentissage supervisé que l'apprentissage à quelques coups.
L' omniglot.py fournit un moyen d'échantillonner les tâches de base N-Way K-Way à partir d'Omniglot, et divers utilitaires pour diviser les ensembles de formation de méta ainsi que les tâches de base.
Téléchargez les deux parties de l'ensemble de données Omniglot:
Créez un dossier omniglot/ dans le dépôt, dézip et fusionnez les deux fichiers pour avoir la structure du dossier suivant:
./train_omniglot.py
...
./omniglot/Alphabet_of_the_Magi/
./omniglot/Angelic/
./omniglot/Anglo-Saxon_Futhorc/
...
./omniglot/ULOG/
Maintenant, commencez à s'entraîner avec
python train_omniglot.py log --cuda 0 $HYPERPARAMETERS # with CPU
python train_omniglot.py log $HYPERPARAMETERS # with CUDA
où $ hyperparamètres dépend de votre tâche et de vos hyperparamètres.
Comportement:
log/ , cela créera un log/ dossier pour stocker les informations et les points de contrôle Tensorboard.log/ , cela reprendra à partir du dernier point de contrôle. La formation peut être interrompue à tout moment avec ^C et repris du dernier point de contrôle en réduisant la même commande.
L'ensemble suivant d'hyperparamètres fonctionne décemment. Ils sont tirés de l'implémentation OpenAI mais sont légèrement adaptés pour meta-batch=1 .


Pour 5 voies 5-Shot (courbe rouge):
python train_omniglot.py log/o55 --classes 5 --shots 5 --train-shots 10 --meta-iterations 100000 --iterations 5 --test-iterations 50 --batch 10 --meta-lr 0.2 --lr 0.001Pour 5 voies 1-Shot (courbe bleue):
python train_omniglot.py log/o51 --classes 5 --shots 1 --train-shots 12 --meta-iterations 200000 --iterations 12 --test-iterations 86 --batch 10 --meta-lr 0.33 --lr 0.00044