Pytorch -Implementierung des Reptilienalgorithmus von OpenAI für überwachtes Lernen.
Derzeit läuft es auf Omniglot, aber noch nicht auf Miniimagenet.
Der Code wurde nicht ausführlich getestet. Beiträge und Feedback sind mehr als willkommen!
Es gibt bereits eine Omniglot-Datensatzklasse in Torchvision, es scheint jedoch mehr für das Lernen des überwachten Lernens angepasst zu werden als nur wenige Lernen.
Der omniglot.py bietet eine Möglichkeit, K-Shot-N-Wege-Basis-Tasks aus Omniglot und verschiedene Dienstprogramme für die Spaltung von Meta-Trainingssätzen sowie auf Basisaufgaben zu probieren.
Laden Sie die beiden Teile des Omniglot -Datensatzes herunter:
Erstellen Sie einen omniglot/ Ordner im Repo, entpacken Sie die beiden Dateien, um die folgende Ordnerstruktur zu erhalten:
./train_omniglot.py
...
./omniglot/Alphabet_of_the_Magi/
./omniglot/Angelic/
./omniglot/Anglo-Saxon_Futhorc/
...
./omniglot/ULOG/
Beginnen Sie jetzt mit dem Training mit
python train_omniglot.py log --cuda 0 $HYPERPARAMETERS # with CPU
python train_omniglot.py log $HYPERPARAMETERS # with CUDA
Wo $ hyperparameter von Ihrer Aufgabe und Ihren Hyperparametern abhängt.
Verhalten:
log/ keine Checkpoints gefunden werden, wird ein log/ Ordner zum Speichern von Tensorboard -Informationen und -Kontrechnungen erstellt.log/ gefunden werden, wird dies vom letzten Kontrollpunkt fortgesetzt. Das Training kann jederzeit mit ^C unterbrochen und aus dem letzten Kontrollpunkt wieder aufgenommen werden, indem derselbe Befehl wieder auftritt.
Der folgende Satz von Hyperparametern funktioniert anständig. Sie stammen aus der OpenAI-Implementierung, sind jedoch für meta-batch=1 leicht angepasst.


Für 5-Wege 5-Shot (rote Kurve):
python train_omniglot.py log/o55 --classes 5 --shots 5 --train-shots 10 --meta-iterations 100000 --iterations 5 --test-iterations 50 --batch 10 --meta-lr 0.2 --lr 0.001Für 5-Wege 1-Shot (blaue Kurve):
python train_omniglot.py log/o51 --classes 5 --shots 1 --train-shots 12 --meta-iterations 200000 --iterations 12 --test-iterations 86 --batch 10 --meta-lr 0.33 --lr 0.00044