Alexander Kolesnikov、Lucas Beyer、Xiaohua Zhai、Joan Puigcerver、Jessica Yung、Sylvain Gelly、Neil Houlsby
更新18/06/2021: BIT-M-R152X2から蒸留された新しい高性能BIT-R50x1モデルをリリースします。このセクションを参照してください。私たちの論文の詳細「知識の蒸留:良い教師は忍耐強く一貫性がある」。
更新08/02/2021: 19のVTAB-1Kデータセットすべてで微調整されたすべてのBIT-Mモデルもリリースします。以下を参照してください。
このリポジトリでは、Big Transfer(BIT)から複数のモデルをリリースします。ILSVRC-2012およびImagenet-21Kデータセットで事前に訓練された一般的な視覚表現学習ペーパーです。主要なディープラーニングフレームワークTensorflow 2、PytorchおよびJax/Flaxでリリースされたモデルを微調整するコードを提供します。
ILSVRC-2012データセットで事前に訓練された従来のモデルとは対照的に、より強力なImagENET-21Kの事前処理モデルを採用することにより、コンピュータービジョンコミュニティが利益を得ることを願っています。
また、より探索的なインタラクティブな使用のためのコラブを提供します:Tensorflow 2 Colab、Pytorch Colab、およびJax Colab。
マシンにPython>=3.6がインストールされていることを確認してください。
Tensorflow 2、PytorchまたはJaxをセットアップするには、ここにリンクされている対応するリポジトリに記載されている指示に従ってください。
さらに、実行してPython依存関係をインストールします(以下のコマンドでtf2 、 pytorch 、またはjaxを選択してください):
pip install -r bit_{tf2|pytorch|jax}/requirements.txt
まず、ビットモデルをダウンロードします。 5つの異なるアーキテクチャのILSVRC-2012(BIT-S)またはImagENET-21K(BIT-M)で事前トレーニングされたモデルを提供します:ResNet-50x1、ResNet-101x1、ResNet-50x3、ResNet-101X3、およびResNet-152X4。
たとえば、ImagENET-21Kで事前に訓練されたResNet-50x1をダウンロードしたい場合は、次のコマンドを実行します。
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}
他のモデルは、モデルの名前(BIT-SまたはBIT-M)と上記のコマンドのアーキテクチャを差し込むことでダウンロードできます。 npz (PytorchおよびJax用)とh5 (TF2用)の2つの形式でモデルを提供することに注意してください。デフォルトでは、モデルの重みがこのリポジトリのルートフォルダーに保存されると予想されます。
次に、3つのフレームワークのいずれかで興味のあるデータセットでダウンロードされたモデルの微調整を実行できます。すべてのフレームワークは、コマンドラインインターフェイスを共有します
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
現在。すべてのフレームワークは、CIFAR-10およびCIFAR-100データセットを自動的にダウンロードします。他のパブリックまたはカスタムデータセットを簡単に統合できます。TF2およびJAXでは、拡張可能なTensorFlow Datasetsライブラリに依存しています。 Pytorchでは、Torchvisionのデータ入力パイプラインを使用しています。
当社のコードは、利用可能なすべてのGPUを微調整に使用していることに注意してください。
また、低データレジームでのトレーニングをサポートします。 --examples_per_class <K>オプションは、トレーニングのためにクラスごとにKサンプルをランダムに描画します。
利用可能なすべてのフラグの詳細なリストを表示するには、 python3 -m bit_{pytorch|jax|tf2}.train --helpを実行します。
便利なため、ILSVRC-2012データセットですでに微調整されているBIT-Mモデルを提供します。モデルは、 -ILSVRC2012 POSTFIXを追加することでダウンロードできます。
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz
論文に記載されているすべてのアーキテクチャをリリースするため、精度または速度を選択できます:R50x1、R101x1、R50x3、R101x3、R152x4。モデルファイルへの上記のパスでは、選択したアーキテクチャにR50x1を置き換えるだけです。
さらに、論文の出版後にさらに多くのアーキテクチャを調査し、R152x2が速度と精度の間に素晴らしいトレードオフがあることを発見したため、これをリリースに含め、以下のいくつかの数字を提供します。
また、VTAB-1Kベンチマークに含まれる19のタスクのそれぞれについて、微調整されたモデルをリリースします。各モデルを3回実行し、これらの各実行をリリースします。これは、合計5x19x3 = 285モデルをリリースすることを意味します。これらが、転送学習のさらなる分析に役立つことを願っています。
ファイルは、次のパターンでダウンロードできます。
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz
これらのモデルをTF2に変換しませんでした(したがって、対応する.h5ファイルはありません)が、TF1およびTF2で使用できるTFHUBモデルもアップロードしました。そのようなモデルをダウンロードするためのコマンドの例の例は、次のとおりです。
mkdir BiT-M-R50x1-run0-caltech101.tfhub && cd BiT-M-R50x1-run0-caltech101.tfhub
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}
mkdir variables && cd variables
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}
再現性のために、当社のトレーニングスクリプトでは、元の論文で使用されたハイパーパラメータ(ビットハイパーール)を使用します。ただし、BITモデルはクラウドTPUハードウェアを使用してトレーニングおよび微調整されているため、典型的なGPUセットアップの場合、デフォルトのハイパーパラメーターは、メモリが多すぎるか、非常に遅い進行をもたらす可能性があります。さらに、Bit-Hyperruleは多くのデータセットで一般化するように設計されているため、通常、より効率的なアプリケーション固有のハイパーパラメーターを考案することが可能です。したがって、ユーザーはより少ないリソースを必要とし、多くの場合同様の精度をもたらすため、より軽量の設定を試すことをお勧めします。
たとえば、CIFAR-10およびCIFAR-100データセットで8xv100 GPUマシンを使用してコードをテストし、512から128にバッチサイズを減らし、学習率は0.003から0.001に減少しました。このセットアップにより、計算上の要求が少ないにもかかわらず、ビットハイパールと比較して、ほぼ同一のパフォーマンス(以下の予想結果を参照)になりました。
以下に、紙のセットアップを最適化する方法についてさらに提案します。
デフォルトのビットハイパールは、クラウドTPUで開発され、非常に記憶に飢えています。これは主に、大きなバッチサイズ(512)と画像解像度(最大480x480)によるものです。メモリが不足している場合は、ここにいくつかのヒントがあります。
bit_hyperrule.pyでは、入力解決を指定します。それを減らすことにより、精度を犠牲にして、多くのメモリを保存して計算できます。--batch_splitオプションを介してバッチ分解技術(「マイクロバッチ」)をサポートしています。たとえば、 --batch_split 8で微調整を実行すると、メモリ要件が8倍になります。 Bit-Hyperruleを使用すると、このリポジトリのコードが論文の結果を再現することを確認しました。
これらの一般的なベンチマークでは、ビットハイパールの前述の変更( --batch 128 --base_lr 0.001 )が次のような結果をもたらします。表は、少なくとも5回の実行のMin←中央値→最大結果を示しています。注:これはフレームワークの比較ではなく、すべてのコードベースが結果を再現するために信頼できるという証拠だけです。
| データセット | ex/cls | TF2 | ジャックス | Pytorch |
|---|---|---|---|---|
| CIFAR10 | 1 | 52.5← 55.8 →60.2 | 48.7← 53.9 →65.0 | 56.4← 56.7 →73.1 |
| CIFAR10 | 5 | 85.3← 87.2 →89.1 | 80.2← 85.8 →88.6 | 84.8← 85.8 →89.6 |
| CIFAR10 | 満杯 | 98.5 | 98.4 | 98.5← 98.6 →98.6 |
| CIFAR100 | 1 | 34.8← 35.7 →37.9 | 32.1← 35.0 →37.1 | 31.6← 33.8 →36.9 |
| CIFAR100 | 5 | 68.8← 70.4 →71.4 | 68.6← 70.8 →71.6 | 70.6← 71.6 →71.7 |
| CIFAR100 | 満杯 | 90.8 | 91.2 | 91.1← 91.2 →91.4 |
| データセット | ex/cls | ジャックス | Pytorch |
|---|---|---|---|
| CIFAR10 | 1 | 44.0← 56.7 →65.0 | 50.9← 55.5 →59.5 |
| CIFAR10 | 5 | 85.3← 87.0 →88.2 | 85.3← 85.8 →88.6 |
| CIFAR10 | 満杯 | 98.5 | 98.5← 98.5 →98.6 |
| CIFAR100 | 1 | 36.4← 37.2 →38.9 | 34.3← 36.8 →39.0 |
| CIFAR100 | 5 | 69.3← 70.5 →72.0 | 70.3← 72.0 →72.3 |
| CIFAR100 | 満杯 | 91.2 | 91.2← 91.3 →91.4 |
(TF2モデルはまだ利用できません。)
| データセット | ex/cls | TF2 | ジャックス | Pytorch |
|---|---|---|---|---|
| CIFAR10 | 1 | 49.9← 54.4 →60.2 | 48.4← 54.1 →66.1 | 45.8← 57.9 →65.7 |
| CIFAR10 | 5 | 80.8← 83.3 →85.5 | 76.7← 82.4 →85.4 | 80.3← 82.3 →84.9 |
| CIFAR10 | 満杯 | 97.2 | 97.3 | 97.4 |
| CIFAR100 | 1 | 35.3← 37.1 →38.2 | 32.0← 35.2 →37.8 | 34.6← 35.2 →38.6 |
| CIFAR100 | 5 | 63.8← 65.0 →66.5 | 63.4← 64.8 →66.5 | 64.7← 65.5 →66.0 |
| CIFAR100 | 満杯 | 86.5 | 86.4 | 86.6 |
これらの結果は、Bit-Hyperruleを使用して得られました。ただし、これにより大きなバッチサイズと大規模な解像度が発生するため、メモリが問題になる可能性があります。 Pytorchコードはバッチスプリッティングをサポートするため、 --batch_split Nコマンドを追加することで、クラウドTPUに頼らずに物事を実行できます。ここで、 Nは2のパワーです。たとえば、次のコマンドは、8 V100 GPUを備えたマシンで80.68の検証精度を生成します。
python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4
4 V100 GPUなどで実行すると、 --batch_split 8にさらに増加します。
いくつかのテスト実行でそのように達成された完全な結果は次のとおりです。
| ex/cls | R50x1 | R152x2 | R101x3 |
|---|---|---|---|
| 1 | 18.36 | 24.5 | 25.55 |
| 5 | 50.64 | 64.5 | 64.18 |
| 満杯 | 80.68 | 85.15 | wip |
これらは再実行であり、正確な紙モデルではありません。 2つのモデルで予想されるVTABスコアは次のとおりです。
| モデル | 満杯 | 自然 | 構造化 | 専門 |
|---|---|---|---|---|
| BIT-M-R152X4 | 73.51 | 80.77 | 61.08 | 85.67 |
| BIT-M-R101X3 | 72.65 | 80.29 | 59.40 | 85.75 |
論文の付録Gでは、BITがコンテキスト外の堅牢性を改善するかどうかを調査します。これを行うために、41のその他の背景に貼り付けられた21のILSVRC-2012クラスに対応する前景オブジェクトを含むデータセットを作成しました。
データセットをダウンロードするには、実行します
wget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip
21のクラスのそれぞれからの画像は、クラスの名前のディレクトリに保持されます。
私たちは、「知識の蒸留:良い教師は忍耐強く一貫性がある」という論文から、最高のパフォーマンスの圧縮ビットモデルをknoweldgeの蒸留についてリリースします。特に、BIT-M-R152X2モデル(ImagENET-21Kで事前に訓練された)をBIT-R50x1モデルに蒸留します。その結果、非常に競争力のあるパフォーマンスを備えたコンパクトモデルを取得します。
| モデル | ダウンロードリンク | 解決 | Imagenet TOP-1 ACC。 (紙) |
|---|---|---|---|
| BIT-R50X1 | リンク | 224 | 82.8 |
| BIT-R50X1 | リンク | 160 | 80.5 |
再現性のために、2つのBIT-M-R152X2教師モデルの重みをリリースします。解像度224と決議384で前提条件。これらの教師の使用方法の詳細については、論文を参照してください。
レシピは簡単であり、ほとんどの人が既存のトレーニングコードに統合すると想像するため、蒸留コードを公開するための具体的な計画はありません。しかし、Sayak PaulはTensorflowの蒸留セットアップを独立して再実装し、いくつかの設定で結果をほぼ再現しました。