von Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly, Neil Houlsby
UPDATE 18/06/2021: Wir veröffentlichen neue leistungsstarke Bit-R50x1-Modelle, die von Bit-M-R152X2 destilliert wurden, siehe Abschnitt. Weitere Details in unserem Artikel "Wissensdestillation: Ein guter Lehrer ist geduldig und konsequent".
Update 08/02/2021: Wir veröffentlichen auch alle Bit-M-Modelle, die auf allen 19 VTAB-1K-Datensätzen fein abgestimmt sind, siehe unten.
In diesem Repository veröffentlichen wir mehrere Modelle aus der Big Transfer (BIT): Allgemeines Lernpapier für visuelle Darstellungen, die auf den Datensätzen ILSVRC-2012 und ImageNet-21K vorgebracht wurden. Wir bieten den Code zur Feinabstimmung der freigegebenen Modelle im Major Deep Learning Frameworks TensorFlow 2, Pytorch und Jax/Flachs.
Wir hoffen, dass die Computer-Vision-Community davon profitieren wird, indem sie leistungsfähigere imageNet-21k-vorbereitete Modelle im Gegensatz zu herkömmlichen Modellen verwenden, die im Datensatz ILSVRC-2012 vorgeschrieben sind.
Wir bieten auch Colabs für eine explorativere interaktivere Verwendung an: einen Tensorflow 2 Colab, einen Pytorch Colab und einen Jax Colab.
Stellen Sie sicher, dass Sie Python>=3.6 auf Ihrem Computer installiert haben.
Um TensorFlow 2, Pytorch oder JAX einzurichten, befolgen Sie die Anweisungen, die im hier verknüpften entsprechenden Repository angegeben sind.
Installieren Sie außerdem Python -Abhängigkeiten durch Ausführen (Wählen Sie im folgenden Befehl bitte tf2 , pytorch oder jax aus):
pip install -r bit_{tf2|pytorch|jax}/requirements.txt
Laden Sie zunächst das Bit -Modell herunter. Wir stellen Modelle zur Verfügung, die auf ILSVRC-2012 (BIT-S) oder ImageNet-21K (BIT-M) für 5 verschiedene Architekturen vorhanden sind: Resnet-50x1, Resnet-101x1, Resnet-50x3, Resnet-101x3 und Resnet-152x4.
Wenn Sie beispielsweise den auf ImageNet-21K vorgebrachten ResNet-50x1-Vorausgeblieben herunterladen möchten, führen Sie den folgenden Befehl aus:
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}
Andere Modelle können entsprechend heruntergeladen werden, indem der Name des Modells (Bit-S oder BIT-M) und die Architektur im obigen Befehl angeschlossen wird. Beachten Sie, dass wir Modelle in zwei Formaten bereitstellen: npz (für Pytorch und Jax) und h5 (für TF2). Standardmäßig erwarten wir, dass Modellgewichte im Stammordner dieses Repositorys gespeichert werden.
Anschließend können Sie das heruntergeladene Modell in Ihrem Datensatz in einem der drei Frameworks ausführen. Alle Frameworks teilen die Befehlszeilenschnittstelle
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
Momentan. Alle Frameworks laden automatisch CIFAR-10- und CIFAR-100-Datensätze herunter. Andere öffentliche oder benutzerdefinierte Datensätze können leicht integriert werden: In TF2 und JAX verlassen wir uns auf die erweiterbare TensorFlow -Datensätze -Bibliothek. In Pytorch verwenden wir die Dateneingangspipeline von TorChvision.
Beachten Sie, dass unser Code alle verfügbaren GPUs zur Feinabstimmung verwendet.
Wir unterstützen auch das Training im Regime mit niedrigem Daten: Die Option- --examples_per_class <K> zeichnet zufällig K-Proben pro Klasse für das Training.
Um eine detaillierte Liste aller verfügbaren Flags anzuzeigen, führen Sie python3 -m bit_{pytorch|jax|tf2}.train --help aus.
Für den Einfachheit halber stellen wir BIT-M-Modelle bereit, die bereits im ILSVRC-2012-Datensatz fein abgestimmt wurden. Die Modelle können heruntergeladen werden, indem das postfix -ILSVRC2012 -Postfix hinzugefügt wird, z. B.
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz
Wir geben alle im Papier genannten Architekturen frei, sodass Sie zwischen Genauigkeit oder Geschwindigkeit wählen können: R50x1, R101x1, R50x3, R101x3, R152x4. Ersetzen Sie im obigen Pfad zur Modelldatei einfach R50x1 durch Ihre Auswahlarchitektur.
Nach der Veröffentlichung des Papiers untersuchten wir weitere Architekturen und stellten fest, dass R152x2 einen schönen Kompromiss zwischen Geschwindigkeit und Genauigkeit erzielt. Daher fügen wir dies auch in die Veröffentlichung ein und geben einige unten einige Zahlen an.
Wir veröffentlichen auch die fein abgestimmten Modelle für jede der 19 in der VTAB-1K-Benchmark enthaltenen Aufgaben. Wir haben jedes Modell dreimal ausgeführt und jede dieser Läufe veröffentlicht. Dies bedeutet, dass wir insgesamt 5x19x3 = 285 Modelle veröffentlichen und hoffen, dass diese bei der weiteren Analyse des Transferlernens nützlich sein können.
Die Dateien können über das folgende Muster heruntergeladen werden:
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz
Wir haben diese Modelle nicht in TF2 konvertiert (daher gibt es keine entsprechende .h5 -Datei), aber wir haben auch TFHUB -Modelle hochgeladen, die in TF1 und TF2 verwendet werden können. Eine Beispielsequenz von Befehlen zum Herunterladen eines solchen Modells ist:
mkdir BiT-M-R50x1-run0-caltech101.tfhub && cd BiT-M-R50x1-run0-caltech101.tfhub
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}
mkdir variables && cd variables
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}
Für die Reproduzierbarkeit verwendet unser Trainingsskript Hyper-Parameter (Bit-Hyperrule), die im Originalpapier verwendet wurden. Beachten Sie jedoch, dass Bit-Modelle mithilfe von Cloud-TPU-Hardware trainiert und finetuniert wurden. Für eine typische GPU-Einrichtung können unsere Standard-Hyperparameter zu viel Speicher erfordern oder zu einem sehr langsamen Fortschritt führen. Darüber hinaus ist Bit-HyperRule so konzipiert, dass sie über viele Datensätze hinweg verallgemeinert werden. Daher ist es in der Regel möglich, effizientere anwendungsspezifische Hyperparameter zu entwickeln. Daher ermutigen wir den Benutzer, mehr leichte Einstellungen auszuprobieren, da er viel weniger Ressourcen erfordern und häufig zu einer ähnlichen Genauigkeit führen.
Zum Beispiel haben wir unseren Code mit einer 8xv100-GPU-Maschine auf den Datensätzen CIFAR-10 und CIFAR-100 getestet, während wir die Stapelgröße von 512 auf 128 und die Lernrate von 0,003 auf 0,001 reduzierten. Dieses Setup führte zu einer nahezu identischen Leistung (siehe erwartete Ergebnisse unten) im Vergleich zu Bithyperrule, obwohl sie weniger rechnerisch anspruchsvoll waren.
Im Folgenden geben wir weitere Vorschläge zur Optimierung des Setups unserer Arbeit.
Die Standard-Bithyperrule wurde auf Cloud-TPUs entwickelt und ist ziemlich speicherhungry. Dies ist hauptsächlich auf die große Chargengröße (512) und die Bildauflösung (bis zu 480 x 480) zurückzuführen. Hier sind einige Tipps, wenn Ihnen der Speicher ausgeht:
bit_hyperrule.py geben wir die Eingangsauflösung an. Durch die Reduzierung kann man auf Kosten der Genauigkeit viel Speicher und Berechnung sparen.--batch_split Option. Wenn Sie beispielsweise die Feinabstimmung mit --batch_split 8 durchführen, reduziert die Speicheranforderung um den Faktor 8. Wir haben überprüft, dass der Code in diesem Repository die Ergebnisse des Papiers reproduziert, wenn der Bithyperrule verwendet wird.
Für diese gängigen Benchmarks führen die oben genannten Änderungen am Bithyperrule ( --batch 128 --base_lr 0.001 ) zu den folgenden, sehr ähnlichen Ergebnissen. Die Tabelle zeigt das min ← Median → Max -Ergebnis von mindestens fünf Läufen. Hinweis : Dies ist kein Vergleich von Frameworks, sondern nur Beweise dafür, dass alle Codebasen vertrauen können, um die Ergebnisse zu reproduzieren.
| Datensatz | Ex/cls | Tf2 | Jax | Pytorch |
|---|---|---|---|---|
| CIFAR10 | 1 | 52,5 ← 55.8 → 60.2 | 48.7 ← 53.9 → 65.0 | 56.4 ← 56.7 → 73.1 |
| CIFAR10 | 5 | 85.3 ← 87.2 → 89.1 | 80.2 ← 85.8 → 88.6 | 84.8 ← 85.8 → 89.6 |
| CIFAR10 | voll | 98,5 | 98.4 | 98.5 ← 98.6 → 98.6 |
| CIFAR100 | 1 | 34.8 ← 35.7 → 37.9 | 32.1 ← 35.0 → 37.1 | 31.6 ← 33.8 → 36.9 |
| CIFAR100 | 5 | 68.8 ← 70.4 → 71.4 | 68.6 ← 70.8 → 71.6 | 70.6 ← 71.6 → 71.7 |
| CIFAR100 | voll | 90,8 | 91.2 | 91.1 ← 91.2 → 91.4 |
| Datensatz | Ex/cls | Jax | Pytorch |
|---|---|---|---|
| CIFAR10 | 1 | 44.0 ← 56.7 → 65.0 | 50.9 ← 55.5 → 59.5 |
| CIFAR10 | 5 | 85.3 ← 87.0 → 88.2 | 85.3 ← 85.8 → 88.6 |
| CIFAR10 | voll | 98,5 | 98.5 ← 98.5 → 98.6 |
| CIFAR100 | 1 | 36.4 ← 37.2 → 38.9 | 34.3 ← 36.8 → 39.0 |
| CIFAR100 | 5 | 69.3 ← 70.5 → 72.0 | 70.3 ← 72.0 → 72.3 |
| CIFAR100 | voll | 91.2 | 91.2 ← 91.3 → 91.4 |
(TF2 -Modelle noch nicht verfügbar.)
| Datensatz | Ex/cls | Tf2 | Jax | Pytorch |
|---|---|---|---|---|
| CIFAR10 | 1 | 49,9 ← 54.4 → 60.2 | 48.4 ← 54.1 → 66.1 | 45.8 ← 57.9 → 65.7 |
| CIFAR10 | 5 | 80.8 ← 83.3 → 85.5 | 76.7 ← 82.4 → 85.4 | 80.3 ← 82.3 → 84.9 |
| CIFAR10 | voll | 97,2 | 97.3 | 97,4 |
| CIFAR100 | 1 | 35.3 ← 37.1 → 38.2 | 32.0 ← 35.2 → 37.8 | 34.6 ← 35.2 → 38.6 |
| CIFAR100 | 5 | 63.8 ← 65.0 → 66.5 | 63.4 ← 64.8 → 66.5 | 64.7 ← 65.5 → 66.0 |
| CIFAR100 | voll | 86,5 | 86,4 | 86.6 |
Diese Ergebnisse wurden unter Verwendung von Bithyperrule erhalten. Da dies jedoch zu einer großen Chargengröße und einer großen Auflösung führt, kann das Speicher ein Problem sein. Der Pytorch-Code unterstützt die Stapelspaltung. Daher können wir dort immer noch Dinge ausführen, ohne auf Cloud-TPUs zurückzugreifen, indem wir den Befehl --batch_split N hinzufügen, bei dem N eine Leistung von zwei ist. Zum Beispiel erzeugt der folgende Befehl eine Validierungsgenauigkeit von 80.68 auf einer Maschine mit 8 V100 -GPUs:
python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4
Erhöhen Sie weiter auf --batch_split 8 beim Ausführen mit 4 V100 GPUs usw.
Auf diese Weise in einigen Testläufen erzielte die vollständigen Ergebnisse:
| Ex/cls | R50x1 | R152X2 | R101X3 |
|---|---|---|---|
| 1 | 18.36 | 24.5 | 25.55 |
| 5 | 50.64 | 64,5 | 64.18 |
| voll | 80.68 | 85.15 | Wip |
Dies sind Wiederholungen und nicht die genauen Papiermodelle. Die erwarteten VTAB -Ergebnisse für zwei der Modelle sind:
| Modell | Voll | Natürlich | Strukturiert | Spezialisiert |
|---|---|---|---|---|
| BIT-M-R152X4 | 73,51 | 80.77 | 61.08 | 85.67 |
| Bit-M-R101X3 | 72,65 | 80.29 | 59.40 | 85.75 |
In Anhang G unseres Papiers untersuchen wir, ob Bit die Robustheit außerhalb der Kontext verbessert. Zu diesem Zweck haben wir einen Datensatz erstellt, der Vordergrundobjekte umfasst, die 21 ILSVRC-2012-Klassen entsprechen, die auf 41 verschiedene Hintergründe eingefügt wurden.
Um den Datensatz herunterzuladen, laufen Sie aus
wget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip
Bilder aus jedem der 21 Klassen werden in einem Verzeichnis mit dem Namen der Klasse aufbewahrt.
Wir geben mit der Destillation "Wissensdestillation: Ein guter Lehrer" -Papiermodelle "Knowledge Destillation: Ein guter Lehrer" bei der Knoweldge-Destillation freigesetzt. Insbesondere destillieren wir das BIT-M-R152X2-Modell (das auf ImageNet-21K vorgebracht wurde) auf Bit-R50x1-Modelle. Infolgedessen erhalten wir kompakte Modelle mit sehr wettbewerbsfähiger Leistung.
| Modell | Link herunterladen | Auflösung | ImageNet Top-1 ACC. (Papier) |
|---|---|---|---|
| Bit-R50x1 | Link | 224 | 82.8 |
| Bit-R50x1 | Link | 160 | 80.5 |
Zur Reproduzierbarkeit füllen wir auch Gewichte von zwei BIT-M-R152X2-Lehrermodellen frei: vorab bei der Auflösung 224 und der Auflösung 384. Weitere Informationen zur Verwendung dieser Lehrer.
Wir haben keine konkreten Pläne für die Veröffentlichung des Destillationscode, da das Rezept einfach ist und wir uns vorstellen, dass die meisten Menschen ihn in ihren vorhandenen Trainingscode integrieren würden. Sayak Paul hat jedoch das Destillationsaufbau in Tensorflow unabhängig voneinander implementiert und unsere Ergebnisse in mehreren Einstellungen nahezu reproduziert.