亚历山大·科莱斯尼科夫(Alexander Kolesnikov),卢卡斯·拜尔(Lucas Beyer),Xiaohua Zhai,Joan Puigcerver,Jessica Yung,Sylvain Gelly,Neil Houlsby
更新18/06/2021:我们发布了从BIT-M-R152X2提炼的新的高性能Bit-R50x1型号,请参见本节。我们的论文中的更多细节“知识蒸馏:好的老师是耐心且一致的”。
更新08/02/2021:我们还发布了所有19个VTAB-1K数据集的所有Bit-M模型,请参见下文。
在此存储库中,我们从大型传输(位)中释放多个模型:一般视觉表示学习纸,这些模型已在ILSVRC-2012和Imagenet-21K数据集中进行了预训练。我们提供代码,以微调主要的深度学习框架Tensorflow 2,Pytorch和Jax/Flax。
我们希望计算机视觉社区将通过使用更强大的Imagenet-21 K审慎模型而受益,而不是在ILSVRC-2012数据集中预先培训的常规模型。
我们还提供了可用于更具探索性交互式用途的Colabs:Tensorflow 2 Colab,Pytorch Colab和Jax Colab。
确保在计算机上安装了Python>=3.6 。
要设置TensorFlow 2,Pytorch或Jax,请遵循此处链接的相应存储库中提供的说明。
此外,通过运行安装Python依赖项(请在下面的命令中选择tf2 , pytorch或jax ):
pip install -r bit_{tf2|pytorch|jax}/requirements.txt
首先,下载位模型。我们为5种不同架构的ILSVRC-2012(BIT-S)或ImagEnet-21K(BIT-M)提供了预训练的模型:Resnet-50x1,Resnet-101x1,Resnet-50x3,Resnet-101x3,Resnet-101x3和Resnet-152x4。
例如,如果您想在Imagenet-21K上下载预训练的Resnet-50x1,请运行以下命令:
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}
可以通过在上述命令中插入模型(位-S或BIT-M)和体系结构的名称来相应下载其他模型。请注意,我们提供两种格式的模型: npz (用于Pytorch和Jax)和h5 (对于TF2)。默认情况下,我们希望模型权重存储在此存储库的根文件夹中。
然后,您可以在三个框架中的任何一个中的任何数据集中对下载的模型进行微调。所有框架共享命令行接口
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
现在。所有框架将自动下载CIFAR-10和CIFAR-100数据集。其他公共或自定义数据集可以很容易地集成:在TF2和JAX中,我们依赖于可扩展的Tensorflow数据集库。在Pytorch中,我们使用Torchvision的数据输入管道。
请注意,我们的代码使用所有可用的GPU进行微调。
我们还支持低数据制度中的培训: --examples_per_class <K>选项将随机绘制每个课程的K样本进行培训。
要查看所有可用标志的详细列表,请运行python3 -m bit_{pytorch|jax|tf2}.train --help 。
为了方便起见,我们提供了在ILSVRC-2012数据集上已经进行了微调的位模型。可以通过添加-ILSVRC2012 Postfix下载这些模型
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz
我们发布了论文中提到的所有架构,以便您可以在准确性或速度之间进行选择:R50x1,R101x1,R50x3,R101x3,R152x4。在上述模型文件的路径中,只需通过选择的体系结构替换R50x1即可。
本文出版后,我们进一步研究了更多的架构,发现R152x2在速度和准确性之间取决于良好的权衡,因此我们还将其包括在版本中,并在下面提供了一些数字。
我们还为VTAB-1K基准中包含的19个任务中的每一个都发布了微调模型。我们运行了每个型号三次,并释放这些运行中的每一个。这意味着我们总共发布了5x19x3 = 285个模型,并希望这些模型可用于进一步分析转移学习。
可以通过以下模式下载文件:
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz
我们没有将这些模型转换为TF2(因此没有相应的.h5文件),但是,我们还上传了可以在TF1和TF2中使用的TFHUB模型。下载一个这样一个模型的命令序列是:
mkdir BiT-M-R50x1-run0-caltech101.tfhub && cd BiT-M-R50x1-run0-caltech101.tfhub
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}
mkdir variables && cd variables
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}
为了获得可重复性,我们的培训脚本使用了原始论文中使用的超参数(钻头)。但是请注意,使用云TPU硬件对BIT模型进行了训练和填充,因此对于典型的GPU设置,我们默认的超参数可能需要过多的内存或导致进度非常缓慢。此外,Bit-Hyperrule旨在概括许多数据集,因此通常可以设计更有效的特定于应用程序的超参数。因此,我们鼓励用户尝试更多的轻重量设置,因为它们需要更少的资源,并且通常会产生类似的精度。
例如,我们使用CIFAR-10和CIFAR-100数据集上的8xv100 GPU机器测试了代码,同时将批次大小从512降低到128,并从0.003降低到0.001。尽管计算的要求较低,但这种设置与比特杂志相比,相比之下,该设置几乎相同(请参见下面的预期结果)。
下面,我们提供有关如何优化论文设置的更多建议。
默认的位hyperrule是在Cloud TPU上开发的,并且是渴望记忆的。这主要是由于批量较大的大小(512)和图像分辨率(高达480x480)。如果您用完了记忆,这里有一些提示:
bit_hyperrule.py中,我们指定输入分辨率。通过减少它,可以以准确性为代价节省大量内存和计算。--batch_split选项支持批处理技术(“ Micro Batching”)。例如,使用--batch_split 8运行微调8将内存需求减少为8。 我们验证了当使用钻头时,此存储库中的代码会重现论文的结果。
对于这些常见的基准测试,上述对位的更改( --batch 128 --base_lr 0.001 )导致以下结果非常相似的结果。该表显示了至少五次运行的最小←中值→最大结果。注意:这不是对框架的比较,而是证据表明所有代码基础都可以信任以复制结果。
| 数据集 | ex/cls | TF2 | JAX | Pytorch |
|---|---|---|---|---|
| CIFAR10 | 1 | 52.5← 55.8 →60.2 | 48.7← 53.9 →65.0 | 56.4← 56.7 →73.1 |
| CIFAR10 | 5 | 85.3← 87.2 →89.1 | 80.2← 85.8 →88.6 | 84.8← 85.8 →89.6 |
| CIFAR10 | 满的 | 98.5 | 98.4 | 98.5← 98.6 →98.6 |
| CIFAR100 | 1 | 34.8← 35.7 →37.9 | 32.1← 35.0 →37.1 | 31.6← 33.8 →36.9 |
| CIFAR100 | 5 | 68.8← 70.4 →71.4 | 68.6← 70.8 →71.6 | 70.6← 71.6 →71.7 |
| CIFAR100 | 满的 | 90.8 | 91.2 | 91.1← 91.2 →91.4 |
| 数据集 | ex/cls | JAX | Pytorch |
|---|---|---|---|
| CIFAR10 | 1 | 44.0← 56.7 →65.0 | 50.9← 55.5 →59.5 |
| CIFAR10 | 5 | 85.3← 87.0 →88.2 | 85.3← 85.8 →88.6 |
| CIFAR10 | 满的 | 98.5 | 98.5← 98.5 →98.6 |
| CIFAR100 | 1 | 36.4← 37.2 →38.9 | 34.3← 36.8 →39.0 |
| CIFAR100 | 5 | 69.3← 70.5 →72.0 | 70.3← 72.0 →72.3 |
| CIFAR100 | 满的 | 91.2 | 91.2← 91.3 →91.4 |
(TF2型号尚未可用。)
| 数据集 | ex/cls | TF2 | JAX | Pytorch |
|---|---|---|---|---|
| CIFAR10 | 1 | 49.9← 54.4 →60.2 | 48.4← 54.1 →66.1 | 45.8← 57.9 →65.7 |
| CIFAR10 | 5 | 80.8← 83.3 →85.5 | 76.7← 82.4 →85.4 | 80.3← 82.3 →84.9 |
| CIFAR10 | 满的 | 97.2 | 97.3 | 97.4 |
| CIFAR100 | 1 | 35.3← 37.1 →38.2 | 32.0← 35.2 →37.8 | 34.6← 35.2 →38.6 |
| CIFAR100 | 5 | 63.8← 65.0 →66.5 | 63.4← 64.8 →66.5 | 64.7← 65.5 →66.0 |
| CIFAR100 | 满的 | 86.5 | 86.4 | 86.6 |
这些结果是使用钻头获得的。但是,由于这会导致大批量和大分辨率,因此内存可能是一个问题。 pytorch代码支持分类的分类,因此,我们仍然可以通过添加--batch_split N命令而在不诉诸云TPU的情况下运行其中的东西,其中N是两个的幂。例如,以下命令在具有8 V100 GPU的机器上产生80.68的验证精度:
python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4
使用4 V100 GPU运行时,进一步增加到--batch_split 8 ,等等。
在某些测试中,完全取得的结果是:
| ex/cls | R50x1 | R152x2 | R101x3 |
|---|---|---|---|
| 1 | 18.36 | 24.5 | 25.55 |
| 5 | 50.64 | 64.5 | 64.18 |
| 满的 | 80.68 | 85.15 | WIP |
这些是重新运行的,而不是确切的纸质模型。两个模型的预期VTAB分数是:
| 模型 | 满的 | 自然的 | 结构 | 专门 |
|---|---|---|---|---|
| BIT-M-R152X4 | 73.51 | 80.77 | 61.08 | 85.67 |
| BIT-M-R101X3 | 72.65 | 80.29 | 59.40 | 85.75 |
在我们论文的附录G中,我们研究了BIT是否改善了固有性的鲁棒性。为此,我们创建了一个数据集,该数据集包含前景对象,该对象对应于21 ILSVRC-2012类粘贴到41个其他背景上。
要下载数据集,请运行
wget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip
来自21个类中每个类别的图像都保存在一个名称为“类名称”的目录中。
我们在纸上蒸馏中发布了“知识蒸馏:优秀的老师是耐心且一致的”中,我们从论文“知识蒸馏:优秀的老师都是耐心和一致的”中释放出表现最佳的压缩位模型。特别是,我们将BIT-M-R152X2模型(在Imagenet-21K上预先训练)提炼为Bit-R50x1模型。结果,我们获得了具有竞争性能的紧凑型模型。
| 模型 | 下载链接 | 解决 | Imagenet Top-1 Acc。 (纸) |
|---|---|---|---|
| 位r50x1 | 关联 | 224 | 82.8 |
| 位r50x1 | 关联 | 160 | 80.5 |
为了获得可重复性,我们还释放了两个BIT-M-R152X2教师模型的权重:在第224号决议和第384号决议上进行了预估计。有关如何使用这些教师的详细信息,请参见本文。
我们没有发布蒸馏代码的具体计划,因为食谱很简单,我们想象大多数人会将其集成到现有的培训代码中。但是,Sayak Paul已独立地重新实现了Tensorflow中的蒸馏设置,并在几种设置中几乎重现了我们的结果。