Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly, Neil Houlsby
업데이트 18/06/2021 : BIT-M-R152X2에서 증류 된 새로운 고성능 BIT-R50X1 모델을 출시합니다.이 섹션을 참조하십시오. 우리 논문의 자세한 내용은 "지식 증류 : 좋은 교사는 참을성 있고 일관성이있다".
업데이트 08/02/2021 : 19 개의 VTAB-1K 데이터 세트에서 모두 미세 조정 된 모든 비트 M 모델도 아래를 참조하십시오.
이 저장소에서 우리는 BIT (Big)에서 여러 모델을 출시합니다 : ILSVRC-2012 및 ImageNet-21K 데이터 세트에서 미리 훈련 된 일반적인 시각적 표현 학습 용지. 우리는 주요 딥 러닝 프레임 워크 Tensorflow 2, Pytorch 및 Jax/Flax에서 릴리스 된 모델을 미세 조정하기위한 코드를 제공합니다.
우리는 컴퓨터 비전 커뮤니티가 ILSVRC-2012 데이터 세트에서 미리 훈련 된 기존 모델과는 반대로보다 강력한 Imagenet-21K 사전 제한 모델을 사용함으로써 혜택을 받기를 바랍니다.
우리는 또한보다 탐색적인 대화식 사용을 위해 Colabs를 제공합니다 : Tensorflow 2 Colab, Pytorch Colab 및 Jax Colab.
컴퓨터에 Python>=3.6 설치되어 있는지 확인하십시오.
Tensorflow 2, Pytorch 또는 Jax를 설정하려면 여기에 연결된 해당 저장소에 제공된 지침을 따르십시오.
또한 실행하여 파이썬 종속성을 설치하십시오 (아래 명령에서 tf2 , pytorch 또는 jax 선택하십시오).
pip install -r bit_{tf2|pytorch|jax}/requirements.txt
먼저 비트 모델을 다운로드하십시오. 우리는 5 가지 아키텍처에 대해 ILSVRC-2012 (BIT-S) 또는 ImageNet-21K (BIT-M)에서 미리 훈련 된 모델을 제공합니다 : RESNET-50X1, RESNET-101X1, RESNET-50X3, RESNET-101X3 및 RESNET-152X4.
예를 들어, Imagenet-21K에서 미리 훈련 된 RESNET-50X1을 다운로드하려면 다음 명령을 실행하십시오.
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}
위의 명령에 모델 (비트 -S 또는 비트 -M)과 아키텍처의 이름을 연결하여 다른 모델을 다운로드 할 수 있습니다. npz (Pytorch 및 Jax의 경우)와 h5 (TF2)라는 두 가지 형식의 모델을 제공합니다. 기본적으로 모델 가중치는이 저장소의 루트 폴더에 저장 될 것으로 예상됩니다.
그런 다음 세 가지 프레임 워크 중 하나에 관심있는 데이터 세트에서 다운로드 된 모델을 미세 조정할 수 있습니다. 모든 프레임 워크는 명령 줄 인터페이스를 공유합니다
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
현재. 모든 프레임 워크는 CIFAR-10 및 CIFAR-100 데이터 세트를 자동으로 다운로드합니다. 다른 공개 또는 사용자 정의 데이터 세트는 쉽게 통합 될 수 있습니다. TF2 및 JAX에서는 확장 가능한 TensorFlow 데이터 세트 라이브러리에 의존합니다. Pytorch에서는 Torchvision의 데이터 입력 파이프 라인을 사용합니다.
당사 코드는 미세 조정에 사용 가능한 모든 GPU를 사용합니다.
우리는 또한 낮은 데이터 체제에서의 훈련을 지원합니다. --examples_per_class <K> 옵션은 교육을 위해 클래스 당 k 샘플을 무작위로 그립니다.
사용 가능한 모든 플래그의 자세한 목록을 보려면 python3 -m bit_{pytorch|jax|tf2}.train --help 실행하십시오.
편의를 위해 ILSVRC-2012 데이터 세트에서 이미 미세 조정 된 비트 M 모델을 제공합니다. 모델은 -ILSVRC2012 postfix, 예를 들어 추가하여 다운로드 할 수 있습니다.
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz
R50X1, R101X1, R50X3, R101X3, R152X4 : 정확도 또는 속도 중에서 선택할 수 있도록 논문에 언급 된 모든 아키텍처를 해제합니다. 위의 모델 파일 경로에서 선택한 아키텍처로 R50x1 교체하십시오.
우리는 논문 출판 후 더 많은 아키텍처를 조사하고 R152X2가 속도와 정확도 사이에 훌륭한 트레이드 오프를 갖는 것으로 나타 났으므로 릴리스에 이것을 포함시키고 아래 몇 가지 숫자를 제공합니다.
또한 VTAB-1K 벤치 마크에 포함 된 19 개의 작업 각각에 대해 미세 조정 된 모델을 출시합니다. 우리는 각 모델을 세 번 실행하고 각 실행을 릴리스합니다. 이것은 우리가 총 5x19x3 = 285 모델을 출시한다는 것을 의미하며, 전송 학습의 추가 분석에 유용 할 수 있기를 바랍니다.
파일은 다음 패턴을 통해 다운로드 할 수 있습니다.
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz
이 모델을 TF2로 변환하지는 않았지만 (따라서 해당 .h5 파일은 없음) TF1 및 TF2에서 사용할 수있는 TFHUB 모델도 업로드했습니다. 그러한 모델을 다운로드하기위한 명령의 예는 다음과 같습니다.
mkdir BiT-M-R50x1-run0-caltech101.tfhub && cd BiT-M-R50x1-run0-caltech101.tfhub
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}
mkdir variables && cd variables
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}
재현성을 위해, 우리의 훈련 스크립트는 원본 용지에 사용 된 과하계 (비트 하이퍼 룰)를 사용합니다. 그러나 BIT 모델이 클라우드 TPU 하드웨어를 사용하여 교육 및 양조장이되었으므로 일반적인 GPU 설정의 경우 기본 하이퍼 파라미터가 너무 많은 메모리가 필요하거나 매우 느리게 진행할 수 있습니다. 또한 Bit-Hyperrule은 많은 데이터 세트에 걸쳐 일반화하도록 설계되었으므로 일반적으로보다 효율적인 애플리케이션 별 하이퍼 파라미터를 고안 할 수 있습니다. 따라서, 우리는 사용자가 훨씬 적은 자원이 필요하고 종종 비슷한 정확도를 초래하기 때문에 더 많은 가벼운 설정을 시도하도록 권장합니다.
예를 들어, CIFAR-10 및 CIFAR-100 데이터 세트의 8xv100 GPU 머신을 사용하여 코드를 테스트 한 동시에 배치 크기를 512에서 128로, 학습 속도는 0.003에서 0.001로 줄였습니다. 이 설정은 계산적으로 덜 까다로운 상태에도 불구하고 비트 하이퍼 룰과 비교하여 거의 동일한 성능 (아래 예상 결과 참조)을 초래했습니다.
아래에서는 논문 설정을 최적화하는 방법에 대한 더 많은 제안을 제공합니다.
기본 비트 하이퍼 룰은 클라우드 TPU에서 개발되었으며 메모리 헝가리입니다. 이는 주로 큰 배치 크기 (512)와 이미지 해상도 (최대 480x480) 때문입니다. 메모리가 부족한 경우 몇 가지 팁이 있습니다.
bit_hyperrule.py 에서 입력 해상도를 지정합니다. 그것을 줄임으로써 정확성을 희생시키면서 많은 메모리와 계산을 저장할 수 있습니다.--batch_split 옵션을 통해 배치 분할 기술 ( "Micro-Batching")을 지원합니다. 예를 들어, --batch_split 8 으로 미세 조정을 실행하면 메모리 요구 사항이 8의 계수로 줄어 듭니다. Bit-Hyperrule을 사용할 때이 저장소의 코드가 논문의 결과를 재현하는지 확인했습니다.
이러한 일반적인 벤치 마크의 경우, 비트 하이퍼 룰 ( --batch 128 --base_lr 0.001 )에 대한 위에서 언급 한 변경은 다음과 같은 결과로 이어집니다. 표는 최소 5 번의 실행의 최소 ← 중간 → 최대 결과를 보여줍니다. 참고 : 이것은 프레임 워크의 비교가 아니라 모든 코드베이스가 결과를 재현하기 위해 신뢰할 수 있다는 증거입니다.
| 데이터 세트 | Ex/CLS | TF2 | Jax | Pytorch |
|---|---|---|---|---|
| cifar10 | 1 | 52.5 ← 55.8 → 60.2 | 48.7 ← 53.9 → 65.0 | 56.4 ← 56.7 → 73.1 |
| cifar10 | 5 | 85.3 ← 87.2 → 89.1 | 80.2 ← 85.8 → 88.6 | 84.8 ← 85.8 → 89.6 |
| cifar10 | 가득한 | 98.5 | 98.4 | 98.5 ← 98.6 → 98.6 |
| cifar100 | 1 | 34.8 ← 35.7 → 37.9 | 32.1 ← 35.0 → 37.1 | 31.6 ← 33.8 → 36.9 |
| cifar100 | 5 | 68.8 ← 70.4 → 71.4 | 68.6 ← 70.8 → 71.6 | 70.6 ← 71.6 → 71.7 |
| cifar100 | 가득한 | 90.8 | 91.2 | 91.1 ← 91.2 → 91.4 |
| 데이터 세트 | Ex/CLS | Jax | Pytorch |
|---|---|---|---|
| cifar10 | 1 | 44.0 ← 56.7 → 65.0 | 50.9 ← 55.5 → 59.5 |
| cifar10 | 5 | 85.3 ← 87.0 → 88.2 | 85.3 ← 85.8 → 88.6 |
| cifar10 | 가득한 | 98.5 | 98.5 ← 98.5 → 98.6 |
| cifar100 | 1 | 36.4 ← 37.2 → 38.9 | 34.3 ← 36.8 → 39.0 |
| cifar100 | 5 | 69.3 ← 70.5 → 72.0 | 70.3 ← 72.0 → 72.3 |
| cifar100 | 가득한 | 91.2 | 91.2 ← 91.3 → 91.4 |
(TF2 모델은 아직 사용할 수 없습니다.)
| 데이터 세트 | Ex/CLS | TF2 | Jax | Pytorch |
|---|---|---|---|---|
| cifar10 | 1 | 49.9 ← 54.4 → 60.2 | 48.4 ← 54.1 → 66.1 | 45.8 ← 57.9 → 65.7 |
| cifar10 | 5 | 80.8 ← 83.3 → 85.5 | 76.7 ← 82.4 → 85.4 | 80.3 ← 82.3 → 84.9 |
| cifar10 | 가득한 | 97.2 | 97.3 | 97.4 |
| cifar100 | 1 | 35.3 ← 37.1 → 38.2 | 32.0 ← 35.2 → 37.8 | 34.6 ← 35.2 → 38.6 |
| cifar100 | 5 | 63.8 ← 65.0 → 66.5 | 63.4 ← 64.8 → 66.5 | 64.7 ← 65.5 → 66.0 |
| cifar100 | 가득한 | 86.5 | 86.4 | 86.6 |
이 결과는 비트 하이퍼 룰을 사용하여 얻어졌다. 그러나 이로 인해 큰 배치 크기와 큰 해상도가 발생하기 때문에 메모리가 문제가 될 수 있습니다. Pytorch 코드는 배치 분할을 지원하므로 --batch_split N 명령을 추가하여 클라우드 tpus에 의지하지 않고도 그 물건을 실행할 수 있습니다. 여기서 N 2의 힘입니다. 예를 들어, 다음 명령은 8 v100 gpus가있는 기계에서 80.68 의 검증 정확도를 생성합니다.
python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4
4 v100 gpus 등으로 실행할 때 --batch_split 8 로 추가로 증가합니다.
일부 테스트 실행에서 그러한 방식으로 달성 된 전체 결과는 다음과 같습니다.
| Ex/CLS | R50X1 | R152X2 | R101X3 |
|---|---|---|---|
| 1 | 18.36 | 24.5 | 25.55 |
| 5 | 50.64 | 64.5 | 64.18 |
| 가득한 | 80.68 | 85.15 | 물티 |
이들은 정확한 종이 모델이 아닙니다. 두 모델에 대한 예상되는 VTAB 점수는 다음과 같습니다.
| 모델 | 가득한 | 자연스러운 | 구조 | 전문 |
|---|---|---|---|---|
| 비트 M-R152X4 | 73.51 | 80.77 | 61.08 | 85.67 |
| 비트 m-r101x3 | 72.65 | 80.29 | 59.40 | 85.75 |
논문의 부록 G에서는 비트가 텍스트 외의 견고성을 향상시키는 지 여부를 조사합니다. 이를 위해 41 개의 기타 배경에 붙여 넣은 21 개의 ILSVRC-2012 클래스에 해당하는 전경 객체로 구성된 데이터 세트를 만들었습니다.
데이터 세트를 다운로드하려면 실행하십시오
wget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip
21 개의 클래스 각각의 이미지는 클래스 이름의 디렉토리에 보관됩니다.
우리는 "지식 증류 : 좋은 교사가 참을성 있고 일관된"논문에서 최고 성능 압축 비트 모델을 방출합니다. 특히, 우리는 BIT-M-R152X2 모델 (ImageNet-21K에서 미리 훈련)을 BIT-R50X1 모델로 증류합니다. 결과적으로, 우리는 매우 경쟁력있는 성능을 가진 소형 모델을 얻습니다.
| 모델 | 링크 다운로드 | 해결 | Imagenet Top-1 Acc. (종이) |
|---|---|---|---|
| 비트 -r50x1 | 링크 | 224 | 82.8 |
| 비트 -r50x1 | 링크 | 160 | 80.5 |
재현성을 위해, 우리는 또한 2 비트 M-R152X2 교사 모델의 가중치를 방출합니다 : 해상도 224 및 해상도 384에서 사전에 사전.
레시피가 간단하고 대부분의 사람들이 기존 교육 코드에 통합 할 것이라고 생각하기 때문에 증류 코드를 게시 할 구체적인 계획이 없습니다. 그러나 Sayak Paul은 독립적으로 증류 설정을 Tensorflow에서 재 구현했으며 여러 설정에서 결과를 거의 재현했습니다.