Par Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly, Neil Houlsby
Mise à jour 18/06/2021: Nous publions de nouveaux modèles Bit-R50x1 haut performantes, qui ont été distillés à partir de Bit-M-R152X2, voir cette section. Plus de détails dans notre article "Distillation des connaissances: un bon enseignant est patient et cohérent".
Mise à jour 08/02/2021: Nous publions également tous les modèles Bit-M affinés sur les 19 ensembles de données VTAB-1k, voir ci-dessous.
Dans ce référentiel, nous publions plusieurs modèles à partir du Big Transfer (BIT): document d'apprentissage de la représentation visuelle générale qui ont été formés sur les ensembles de données ILSVRC-2012 et ImageNet-21k. Nous fournissons le code pour affiner les modèles publiés dans les principaux cadres d'apprentissage en profondeur TensorFlow 2, Pytorch et Jax / Flax.
Nous espérons que la communauté de la vision par ordinateur bénéficiera en utilisant des modèles ImageNet-21k plus puissants, par opposition aux modèles conventionnels pré-formés sur l'ensemble de données ILSVRC-2012.
Nous fournissons également des colabs pour une utilisation interactive plus exploratoire: un colab Tensorflow 2, un colab pytorch et un jax colab.
Assurez-vous que Python>=3.6 installé sur votre machine.
Pour configurer TensorFlow 2, Pytorch ou Jax, suivez les instructions fournies dans le référentiel correspondant lié ici.
De plus, installez les dépendances Python en fonctionnant (veuillez sélectionner tf2 , pytorch ou jax dans la commande ci-dessous):
pip install -r bit_{tf2|pytorch|jax}/requirements.txt
Tout d'abord, téléchargez le modèle Bit. Nous fournissons des modèles pré-formés sur ILSVRC-2012 (BIT-S) ou IMAMENET-21K (BIT-M) pour 5 architectures différentes: RESNET-50X1, RESNET-101X1, RESNET-50X3, RESNET-101X3 et RESNET-152X4.
Par exemple, si vous souhaitez télécharger le Resnet-50x1 pré-formé sur ImageNet-21k, exécutez la commande suivante:
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}
D'autres modèles peuvent être téléchargés en conséquence en branchant le nom du modèle (Bit-S ou Bit-M) et de l'architecture dans la commande ci-dessus. Notez que nous fournissons des modèles dans deux formats: npz (pour Pytorch et Jax) et h5 (pour TF2). Par défaut, nous nous attendons à ce que les poids du modèle soient stockés dans le dossier racine de ce référentiel.
Ensuite, vous pouvez exécuter le réglage fin du modèle téléchargé sur votre ensemble de données d'intérêt dans l'un des trois cadres. Tous les cadres partagent l'interface de ligne de commande
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
Actuellement. Tous les cadres téléchargeront automatiquement les ensembles de données CIFAR-10 et CIFAR-100. D'autres ensembles de données publics ou personnalisés peuvent être facilement intégrés: dans TF2 et JAX, nous nous appuyons sur la bibliothèque de jeux de données TensorFlow extensible. Dans Pytorch, nous utilisons le pipeline d'entrée de données de TorchVision.
Notez que notre code utilise tous les GPU disponibles pour le réglage fin.
Nous soutenons également la formation dans le régime de faible données: l'option --examples_per_class <K> dessinera au hasard des échantillons k par classe pour la formation.
Pour voir une liste détaillée de tous les drapeaux disponibles, exécutez python3 -m bit_{pytorch|jax|tf2}.train --help .
Pour plus de commodité, nous fournissons des modèles Bit-M qui étaient déjà affinés sur l'ensemble de données ILSVRC-2012. Les modèles peuvent être téléchargés en ajoutant le postfix -ILSVRC2012 , par exemple
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz
Nous libérons toutes les architectures mentionnées dans le document, de sorte que vous pouvez choisir entre la précision ou la vitesse: R50X1, R101X1, R50X3, R101X3, R152X4. Dans le chemin ci-dessus vers le fichier du modèle, remplacez simplement R50x1 par votre architecture de choix.
Nous avons en outre enquêté sur plus d'architectures après la publication du journal et avons constaté que R152X2 avait un bon compromis entre la vitesse et la précision, nous incluons donc cela dans le communiqué et fournissons quelques chiffres ci-dessous.
Nous publions également les modèles affinés pour chacune des 19 tâches incluses dans la référence VTAB-1k. Nous avons exécuté chaque modèle trois fois et libérons chacune de ces courses. Cela signifie que nous libérons un total de 5x19x3 = 285 modèles, et espérons que ceux-ci pourront être utiles dans une analyse plus approfondie de l'apprentissage du transfert.
Les fichiers peuvent être téléchargés via le modèle suivant:
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz
Nous n'avons pas converti ces modèles en TF2 (par conséquent, il n'y a pas de fichier .h5 correspondant), cependant, nous avons également téléchargé des modèles TFHUB qui peuvent être utilisés dans TF1 et TF2. Un exemple de séquence de commandes pour télécharger un de ces modèles est:
mkdir BiT-M-R50x1-run0-caltech101.tfhub && cd BiT-M-R50x1-run0-caltech101.tfhub
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}
mkdir variables && cd variables
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}
Pour la reproductibilité, notre script d'entraînement utilise des hyper-paramètres (bit-hyperrule) qui ont été utilisés dans l'article d'origine. Remarque, cependant, que les modèles de bits ont été formés et finetunés à l'aide du matériel TPU Cloud, donc pour une configuration GPU typique, nos hyper-paramètres par défaut pourraient nécessiter trop de mémoire ou entraîner une progression très lente. De plus, Bit-Hyperrule est conçu pour généraliser dans de nombreux ensembles de données, il est donc généralement possible de concevoir des hyper-paramètres spécifiques à l'application plus efficaces. Ainsi, nous encourageons l'utilisateur à essayer des paramètres plus légers, car ils nécessitent beaucoup moins de ressources et entraînent souvent une précision similaire.
Par exemple, nous avons testé notre code à l'aide d'une machine GPU 8xv100 sur les ensembles de données CIFAR-10 et CIFAR-100, tout en réduisant la taille du lot de 512 à 128 et le taux d'apprentissage de 0,003 à 0,001. Cette configuration a entraîné des performances presque identiques (voir les résultats attendus ci-dessous) par rapport à l'hyperrule bit, bien qu'il soit moins exigeant par calcul.
Ci-dessous, nous fournissons plus de suggestions sur la façon d'optimiser la configuration de notre article.
Le bit-hyperrule par défaut a été développé sur les TPU cloud et est assez avalé. Cela est principalement dû à la grande taille par lots (512) et à la résolution d'image (jusqu'à 480x480). Voici quelques conseils si vous manquez de mémoire:
bit_hyperrule.py nous spécifions la résolution d'entrée. En le réduisant, on peut économiser beaucoup de mémoire et de calcul, au détriment de la précision.--batch_split . Par exemple, l'exécution du réglage fin avec --batch_split 8 réduit les besoins en mémoire d'un facteur 8. Nous avons vérifié que lors de l'utilisation du bit-hyperrule, le code de ce référentiel reproduit les résultats du papier.
Pour ces repères communs, les modifications susmentionnées dans le bit-hyperrule ( --batch 128 --base_lr 0.001 ) conduisent aux résultats très similaires suivants. Le tableau montre la médiane min ← → Max Résultat d'au moins cinq courses. Remarque : Ce n'est pas une comparaison des cadres, il suffit de prouver que tous les bases de code peuvent être fiables pour reproduire les résultats.
| Ensemble de données | Ex / CLS | Tf2 | Jax | Pytorch |
|---|---|---|---|---|
| Cifar10 | 1 | 52,5 ← 55,8 → 60,2 | 48,7 ← 53,9 → 65.0 | 56,4 ← 56,7 → 73.1 |
| Cifar10 | 5 | 85,3 ← 87.2 → 89.1 | 80.2 ← 85.8 → 88.6 | 84,8 ← 85,8 → 89.6 |
| Cifar10 | complet | 98.5 | 98.4 | 98,5 ← 98,6 → 98,6 |
| CIFAR100 | 1 | 34,8 ← 35,7 → 37.9 | 32.1 ← 35.0 → 37.1 | 31,6 ← 33,8 → 36,9 |
| CIFAR100 | 5 | 68,8 ← 70,4 → 71.4 | 68,6 ← 70,8 → 71.6 | 70,6 ← 71,6 → 71.7 |
| CIFAR100 | complet | 90.8 | 91.2 | 91.1 ← 91.2 → 91.4 |
| Ensemble de données | Ex / CLS | Jax | Pytorch |
|---|---|---|---|
| Cifar10 | 1 | 44.0 ← 56,7 → 65.0 | 50,9 ← 55,5 → 59,5 |
| Cifar10 | 5 | 85,3 ← 87.0 → 88.2 | 85,3 ← 85,8 → 88.6 |
| Cifar10 | complet | 98.5 | 98,5 ← 98,5 → 98,6 |
| CIFAR100 | 1 | 36,4 ← 37,2 → 38.9 | 34.3 ← 36,8 → 39.0 |
| CIFAR100 | 5 | 69.3 ← 70,5 → 72.0 | 70,3 ← 72.0 → 72.3 |
| CIFAR100 | complet | 91.2 | 91.2 ← 91.3 → 91.4 |
(Modèles TF2 pas encore disponibles.)
| Ensemble de données | Ex / CLS | Tf2 | Jax | Pytorch |
|---|---|---|---|---|
| Cifar10 | 1 | 49,9 ← 54,4 → 60.2 | 48,4 ← 54.1 → 66.1 | 45,8 ← 57,9 → 65.7 |
| Cifar10 | 5 | 80,8 ← 83,3 → 85,5 | 76,7 ← 82,4 → 85,4 | 80,3 ← 82.3 → 84.9 |
| Cifar10 | complet | 97.2 | 97.3 | 97.4 |
| CIFAR100 | 1 | 35.3 ← 37.1 → 38.2 | 32.0 ← 35.2 → 37.8 | 34,6 ← 35.2 → 38.6 |
| CIFAR100 | 5 | 63,8 ← 65.0 → 66.5 | 63.4 ← 64.8 → 66.5 | 64,7 ← 65,5 → 66.0 |
| CIFAR100 | complet | 86.5 | 86.4 | 86.6 |
Ces résultats ont été obtenus en utilisant un bit-hyperrule. Cependant, comme cela se traduit par une grande taille par lots et une grande résolution, la mémoire peut être un problème. Le code Pytorch prend en charge le coup de pouce, et nous pouvons donc y exécuter des choses sans recourir aux TPU cloud en ajoutant la commande --batch_split N où N est une puissance de deux. Par exemple, la commande suivante produit une précision de validation de 80.68 sur une machine avec 8 GPU V100:
python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4
Augmentez-vous à --batch_split 8 lors de l'exécution avec 4 GPU V100, etc.
Les résultats complets obtenus de cette façon dans certains essais étaient:
| Ex / CLS | R50X1 | R152X2 | R101x3 |
|---|---|---|---|
| 1 | 18.36 | 24.5 | 25,55 |
| 5 | 50,64 | 64.5 | 64.18 |
| complet | 80.68 | 85.15 | Vider |
Ce sont des rediffusions et non les modèles papier exacts. Les scores VTAB attendus pour deux des modèles sont:
| Modèle | Complet | Naturel | Structuré | Spécialisé |
|---|---|---|---|---|
| Bit-m-r152x4 | 73.51 | 80.77 | 61.08 | 85,67 |
| Bit-m-r101x3 | 72.65 | 80.29 | 59.40 | 85,75 |
Dans l'annexe G de notre article, nous examinons si BIT améliore la robustesse hors contexte. Pour ce faire, nous avons créé un ensemble de données comprenant des objets de premier plan correspondant à 21 classes ILSVRC-2012 collées sur 41 arrière-plans divers.
Pour télécharger l'ensemble de données, exécutez
wget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip
Les images de chacune des 21 classes sont conservées dans un répertoire avec le nom de la classe.
Nous publions des modèles de bits compressés les plus performants à partir de notre document "Distillation de connaissances: un bon enseignant est patient et cohérent" sur la distillation de Knowelldge. En particulier, nous distillons le modèle BIT-M-R152X2 (qui a été pré-formé sur ImageNet-21K) aux modèles Bit-R50x1. En conséquence, nous obtenons des modèles compacts avec des performances très compétitives.
| Modèle | Lien de téléchargement | Résolution | ImageNet Top-1 ACC. (papier) |
|---|---|---|---|
| Bit-r50x1 | lien | 224 | 82.8 |
| Bit-r50x1 | lien | 160 | 80.5 |
Pour la reproductibilité, nous libérons également des poids de deux modèles d'enseignants Bit-M-R152X2: pré-entraîné à la résolution 224 et la résolution 384. Voir le document pour plus de détails sur la façon dont ces enseignants ont été utilisés.
Nous n'avons pas de plans concrets pour publier le code de distillation, car la recette est simple et nous imaginons que la plupart des gens l'intégreraient dans leur code de formation existant. Cependant, Sayak Paul a réimplémenté indépendamment la configuration de la distillation dans TensorFlow et a presque reproduit nos résultats dans plusieurs contextes.