亞歷山大·科萊斯尼科夫(Alexander Kolesnikov),盧卡斯·拜爾(Lucas Beyer),Xiaohua Zhai,Joan Puigcerver,Jessica Yung,Sylvain Gelly,Neil Houlsby
更新18/06/2021:我們發布了從BIT-M-R152X2提煉的新的高性能Bit-R50x1型號,請參見本節。我們的論文中的更多細節“知識蒸餾:好的老師是耐心且一致的”。
更新08/02/2021:我們還發布了所有19個VTAB-1K數據集的所有Bit-M模型,請參見下文。
在此存儲庫中,我們從大型傳輸(位)中釋放多個模型:一般視覺表示學習紙,這些模型已在ILSVRC-2012和Imagenet-21K數據集中進行了預訓練。我們提供代碼,以微調主要的深度學習框架Tensorflow 2,Pytorch和Jax/Flax。
我們希望計算機視覺社區將通過使用更強大的Imagenet-21 K審慎模型而受益,而不是在ILSVRC-2012數據集中預先培訓的常規模型。
我們還提供了可用於更具探索性交互式用途的Colabs:Tensorflow 2 Colab,Pytorch Colab和Jax Colab。
確保在計算機上安裝了Python>=3.6 。
要設置TensorFlow 2,Pytorch或Jax,請遵循此處鏈接的相應存儲庫中提供的說明。
此外,通過運行安裝Python依賴項(請在下面的命令中選擇tf2 , pytorch或jax ):
pip install -r bit_{tf2|pytorch|jax}/requirements.txt
首先,下載位模型。我們為5種不同架構的ILSVRC-2012(BIT-S)或ImagEnet-21K(BIT-M)提供了預訓練的模型:Resnet-50x1,Resnet-101x1,Resnet-50x3,Resnet-101x3,Resnet-101x3和Resnet-152x4。
例如,如果您想在Imagenet-21K上下載預訓練的Resnet-50x1,請運行以下命令:
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}
可以通過在上述命令中插入模型(位-S或BIT-M)和體系結構的名稱來相應下載其他模型。請注意,我們提供兩種格式的模型: npz (用於Pytorch和Jax)和h5 (對於TF2)。默認情況下,我們希望模型權重存儲在此存儲庫的根文件夾中。
然後,您可以在三個框架中的任何一個中的任何數據集中對下載的模型進行微調。所有框架共享命令行接口
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
現在。所有框架將自動下載CIFAR-10和CIFAR-100數據集。其他公共或自定義數據集可以很容易地集成:在TF2和JAX中,我們依賴於可擴展的Tensorflow數據集庫。在Pytorch中,我們使用Torchvision的數據輸入管道。
請注意,我們的代碼使用所有可用的GPU進行微調。
我們還支持低數據制度中的培訓: --examples_per_class <K>選項將隨機繪製每個課程的K樣本進行培訓。
要查看所有可用標誌的詳細列表,請運行python3 -m bit_{pytorch|jax|tf2}.train --help 。
為了方便起見,我們提供了在ILSVRC-2012數據集上已經進行了微調的位模型。可以通過添加-ILSVRC2012 Postfix下載這些模型
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz
我們發布了論文中提到的所有架構,以便您可以在準確性或速度之間進行選擇:R50x1,R101x1,R50x3,R101x3,R152x4。在上述模型文件的路徑中,只需通過選擇的體系結構替換R50x1即可。
本文出版後,我們進一步研究了更多的架構,發現R152x2在速度和準確性之間取決於良好的權衡,因此我們還將其包括在版本中,並在下面提供了一些數字。
我們還為VTAB-1K基準中包含的19個任務中的每一個都發布了微調模型。我們運行了每個型號三次,並釋放這些運行中的每一個。這意味著我們總共發布了5x19x3 = 285個模型,並希望這些模型可用於進一步分析轉移學習。
可以通過以下模式下載文件:
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz
我們沒有將這些模型轉換為TF2(因此沒有相應的.h5文件),但是,我們還上傳了可以在TF1和TF2中使用的TFHUB模型。下載一個這樣一個模型的命令序列是:
mkdir BiT-M-R50x1-run0-caltech101.tfhub && cd BiT-M-R50x1-run0-caltech101.tfhub
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}
mkdir variables && cd variables
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}
為了獲得可重複性,我們的培訓腳本使用了原始論文中使用的超參數(鑽頭)。但是請注意,使用雲TPU硬件對BIT模型進行了訓練和填充,因此對於典型的GPU設置,我們默認的超參數可能需要過多的內存或導致進度非常緩慢。此外,Bit-Hyperrule旨在概括許多數據集,因此通常可以設計更有效的特定於應用程序的超參數。因此,我們鼓勵用戶嘗試更多的輕重量設置,因為它們需要更少的資源,並且通常會產生類似的精度。
例如,我們使用CIFAR-10和CIFAR-100數據集上的8xv100 GPU機器測試了代碼,同時將批次大小從512降低到128,並從0.003降低到0.001。儘管計算的要求較低,但這種設置與比特雜誌相比,相比之下,該設置幾乎相同(請參見下面的預期結果)。
下面,我們提供有關如何優化論文設置的更多建議。
默認的位hyperrule是在Cloud TPU上開發的,並且是渴望記憶的。這主要是由於批量較大的大小(512)和圖像分辨率(高達480x480)。如果您用完了記憶,這裡有一些提示:
bit_hyperrule.py中,我們指定輸入分辨率。通過減少它,可以以準確性為代價節省大量內存和計算。--batch_split選項支持批處理技術(“ Micro Batching”)。例如,使用--batch_split 8運行微調8將內存需求減少為8。 我們驗證了當使用鑽頭時,此存儲庫中的代碼會重現論文的結果。
對於這些常見的基準測試,上述對位的更改( --batch 128 --base_lr 0.001 )導致以下結果非常相似的結果。該表顯示了至少五次運行的最小←中值→最大結果。注意:這不是對框架的比較,而是證據表明所有代碼基礎都可以信任以復制結果。
| 數據集 | ex/cls | TF2 | JAX | Pytorch |
|---|---|---|---|---|
| CIFAR10 | 1 | 52.5← 55.8 →60.2 | 48.7← 53.9 →65.0 | 56.4← 56.7 →73.1 |
| CIFAR10 | 5 | 85.3← 87.2 →89.1 | 80.2← 85.8 →88.6 | 84.8← 85.8 →89.6 |
| CIFAR10 | 滿的 | 98.5 | 98.4 | 98.5← 98.6 →98.6 |
| CIFAR100 | 1 | 34.8← 35.7 →37.9 | 32.1← 35.0 →37.1 | 31.6← 33.8 →36.9 |
| CIFAR100 | 5 | 68.8← 70.4 →71.4 | 68.6← 70.8 →71.6 | 70.6← 71.6 →71.7 |
| CIFAR100 | 滿的 | 90.8 | 91.2 | 91.1← 91.2 →91.4 |
| 數據集 | ex/cls | JAX | Pytorch |
|---|---|---|---|
| CIFAR10 | 1 | 44.0← 56.7 →65.0 | 50.9← 55.5 →59.5 |
| CIFAR10 | 5 | 85.3← 87.0 →88.2 | 85.3← 85.8 →88.6 |
| CIFAR10 | 滿的 | 98.5 | 98.5← 98.5 →98.6 |
| CIFAR100 | 1 | 36.4← 37.2 →38.9 | 34.3← 36.8 →39.0 |
| CIFAR100 | 5 | 69.3← 70.5 →72.0 | 70.3← 72.0 →72.3 |
| CIFAR100 | 滿的 | 91.2 | 91.2← 91.3 →91.4 |
(TF2型號尚未可用。)
| 數據集 | ex/cls | TF2 | JAX | Pytorch |
|---|---|---|---|---|
| CIFAR10 | 1 | 49.9← 54.4 →60.2 | 48.4← 54.1 →66.1 | 45.8← 57.9 →65.7 |
| CIFAR10 | 5 | 80.8← 83.3 →85.5 | 76.7← 82.4 →85.4 | 80.3← 82.3 →84.9 |
| CIFAR10 | 滿的 | 97.2 | 97.3 | 97.4 |
| CIFAR100 | 1 | 35.3← 37.1 →38.2 | 32.0← 35.2 →37.8 | 34.6← 35.2 →38.6 |
| CIFAR100 | 5 | 63.8← 65.0 →66.5 | 63.4← 64.8 →66.5 | 64.7← 65.5 →66.0 |
| CIFAR100 | 滿的 | 86.5 | 86.4 | 86.6 |
這些結果是使用鑽頭獲得的。但是,由於這會導致大批量和大分辨率,因此內存可能是一個問題。 pytorch代碼支持分類的分類,因此,我們仍然可以通過添加--batch_split N命令而在不訴諸雲TPU的情況下運行其中的東西,其中N是兩個的冪。例如,以下命令在具有8 V100 GPU的機器上產生80.68的驗證精度:
python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4
使用4 V100 GPU運行時,進一步增加到--batch_split 8 ,等等。
在某些測試中,完全取得的結果是:
| ex/cls | R50x1 | R152x2 | R101x3 |
|---|---|---|---|
| 1 | 18.36 | 24.5 | 25.55 |
| 5 | 50.64 | 64.5 | 64.18 |
| 滿的 | 80.68 | 85.15 | WIP |
這些是重新運行的,而不是確切的紙質模型。兩個模型的預期VTAB分數是:
| 模型 | 滿的 | 自然的 | 結構 | 專門 |
|---|---|---|---|---|
| BIT-M-R152X4 | 73.51 | 80.77 | 61.08 | 85.67 |
| BIT-M-R101X3 | 72.65 | 80.29 | 59.40 | 85.75 |
在我們論文的附錄G中,我們研究了BIT是否改善了固有性的魯棒性。為此,我們創建了一個數據集,該數據集包含前景對象,該對像對應於21 ILSVRC-2012類粘貼到41個其他背景上。
要下載數據集,請運行
wget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip
來自21個類中每個類別的圖像都保存在一個名稱為“類名稱”的目錄中。
我們在紙上蒸餾中發布了“知識蒸餾:優秀的老師是耐心且一致的”中,我們從論文“知識蒸餾:優秀的老師都是耐心和一致的”中釋放出表現最佳的壓縮位模型。特別是,我們將BIT-M-R152X2模型(在Imagenet-21K上預先訓練)提煉為Bit-R50x1模型。結果,我們獲得了具有競爭性能的緊湊型模型。
| 模型 | 下載鏈接 | 解決 | Imagenet Top-1 Acc。 (紙) |
|---|---|---|---|
| 位r50x1 | 關聯 | 224 | 82.8 |
| 位r50x1 | 關聯 | 160 | 80.5 |
為了獲得可重複性,我們還釋放了兩個BIT-M-R152X2教師模型的權重:在第224號決議和第384號決議上進行了預估計。有關如何使用這些教師的詳細信息,請參見本文。
我們沒有發布蒸餾代碼的具體計劃,因為食譜很簡單,我們想像大多數人會將其集成到現有的培訓代碼中。但是,Sayak Paul已獨立地重新實現了Tensorflow中的蒸餾設置,並在幾種設置中幾乎重現了我們的結果。