Александр Колесников, Лукас Бейер, Сяохуа Чжая, Джоан Пуигсервер, Джессика Юнг, Сильвен Гелли, Нил Хоулсби
ОБНОВЛЕНИЕ 18/06/2021: мы выпускаем новые высокопроизводительные модели BIT-R50X1, которые были дистиллированы от BIT-M-R152X2, см. В этом разделе. Более подробная информация в нашей статье «Дистилляция знаний: хороший учитель терпелив и последователен».
Обновление 08/02/2021: Мы также выпускаем все модели BIT-M, настраиваемые на все 19 наборов данных VTAB-1K, см. Ниже.
В этом репозитории мы выпускаем несколько моделей из большой передачи (бит): общего визуального представления учебной бумаги, которые были предварительно обучены на наборах данных ILSVRC-2012 и ImageNet-21K. Мы предоставляем код для точной настройки выпущенных моделей в основных структурах глубокого обучения Tensorflow 2, Pytorch и Jax/лен.
Мы надеемся, что сообщество Computer Vision получит пользу, используя более мощные модели ImageNet-21K, предварительно предопределенные, в отличие от обычных моделей, предварительно обученных на наборе данных ILSVRC-2012.
Мы также предоставляем колола для более исследовательского интерактивного использования: Tensorflow 2 Colab, Pytorch Colab и Jax Colab.
Убедитесь, что у вас есть Python>=3.6 установлен на вашем компьютере.
Чтобы настроить Tensorflow 2, Pytorch или JAX, следуйте инструкциям, представленным в соответствующем хранилище, связанном здесь.
Кроме того, установите зависимости Python с помощью запуска (пожалуйста, выберите tf2 , pytorch или jax в команде ниже):
pip install -r bit_{tf2|pytorch|jax}/requirements.txt
Во -первых, загрузите модель бит. Мы предоставляем модели, предварительно обученные на ILSVRC-2012 (BIT-S) или ImageNet-21K (BIT-M) для 5 различных архитектур: RESNET-50x1, RESNET-101X1, RESNET-50x3, RESNET-101X3 и RESNET-152X4.
Например, если вы хотите загрузить Resnet-50x1, предварительно обученный на ImageNet-21K, запустите следующую команду:
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}
Другие модели могут быть загружены соответственно, подключив имя модели (бит-S или Bit-M) и архитектуры в вышеуказанной команде. Обратите внимание, что мы предоставляем модели в двух форматах: npz (для Pytorch и JAX) и h5 (для TF2). По умолчанию мы ожидаем, что веса модели хранятся в корневой папке этого репозитория.
Затем вы можете запустить точную настройку загруженной модели в вашем наборе данных, представляющего интерес в любой из трех структур. Все фреймворки делятся интерфейсом командной строки
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
В настоящее время. Все фреймворки будут автоматически загружать наборы данных CIFAR-10 и CIFAR-100. Другие общедоступные или пользовательские наборы данных могут быть легко интегрированы: в TF2 и JAX мы полагаемся на расширяемые библиотеки наборов данных TensorFlow. В Pytorch мы используем конвейер ввода данных Torchvision.
Обратите внимание, что наш код использует все доступные графические процессоры для точной настройки.
Мы также поддерживаем обучение в режиме с низкими данными: вариант --examples_per_class <K> будет случайным образом рисовать образцы K на класс для обучения.
Чтобы увидеть подробный список всех доступных флагов, запустите python3 -m bit_{pytorch|jax|tf2}.train --help .
Для удобства мы предоставляем модели BIT-M, которые уже были настраивались в наборе данных ILSVRC-2012. Модели можно загрузить, добавив Postfix -ILSVRC2012 , например,
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz
Мы выпускаем все архитектуры, упомянутые в статье, так что вы можете выбрать между точностью или скоростью: R50x1, R101x1, R50x3, R101x3, R152x4. В приведенном выше пути к модельному файлу просто замените R50x1 на выбор по архитектуре.
Мы также исследовали больше архитектур после публикации газеты и обнаружили, что R152x2 имел хороший компромисс между скоростью и точностью, поэтому мы также включим это в релиз и предоставляем несколько цифр ниже.
Мы также выпускаем тонкие модели для каждой из 19 задач, включенных в эталон VTAB-1K. Мы запускали каждую модель три раза и отпускаем каждый из этих прогонов. Это означает, что мы выпускаем в общей сложности 5x19x3 = 285 моделей, и надеемся, что они могут быть полезны при дальнейшем анализе обучения передачи.
Файлы можно загрузить по следующему шаблону:
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz
Мы не конвертировали эти модели в TF2 (следовательно, нет соответствующего файла .h5 ), однако мы также загрузили модели TFHUB, которые можно использовать в TF1 и TF2. Примером последовательности команд для загрузки одной такой модели является:
mkdir BiT-M-R50x1-run0-caltech101.tfhub && cd BiT-M-R50x1-run0-caltech101.tfhub
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}
mkdir variables && cd variables
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}
Для воспроизводимости в нашем тренировочном скрипте используются гиперпараметры (бит-гиперрул), которые использовались в оригинальной статье. Обратите внимание, однако, что модели битов были обучены и созданы с использованием облачного оборудования TPU, поэтому для типичной настройки GPU наши гиперпараметры по умолчанию могут потребовать слишком большого количества памяти или привести к очень медленному прогрессу. Более того, битовая гидромассажная промышленность предназначена для обобщения во многих наборах данных, поэтому обычно можно разработать более эффективные гиперпараметры, специфичные для приложения. Таким образом, мы рекомендуем пользователю попробовать больше настройки легкого веса, поскольку он требует гораздо меньше ресурсов и часто приводит к аналогичной точности.
Например, мы протестировали наш код, используя машину GPU 8xv100 на наборах данных CIFAR-10 и CIFAR-100, одновременно уменьшая размер партии с 512 до 128 и скорость обучения с 0,003 до 0,001. Эта настройка привела к почти идентичной производительности (см. Ожидаемые результаты ниже) по сравнению с битовой гидроциклами, несмотря на то, что они были менее требовательными.
Ниже мы предоставляем больше предложений о том, как оптимизировать настройку нашей статьи.
Бит-гиперрула по умолчанию была разработана на облачных TPU и очень жаждает памяти. Это в основном связано с большим размером партии (512) и разрешением изображения (до 480x480). Вот несколько советов, если у вас заканчивается память:
bit_hyperrule.py мы указываем входное разрешение. Сокращая его, можно сохранить много памяти и вычислить за счет точности.--batch_split . Например, запуск тонкой настройки с помощью --batch_split 8 снижает потребность в памяти в 8 раза. Мы подтвердили, что при использовании битовой гиперлулы код в этом репозитории воспроизводит результаты статьи.
Для этих общих контрольных показателей вышеупомянутые изменения в битовой гиперруле ( --batch 128 --base_lr 0.001 ) приводят к следующим, очень похожим результатам. В таблице показан мин ← Медиана → максимальный результат не менее пяти пробежек. Примечание . Это не сравнение структур, просто доказательство того, что все кодовые базы можно доверять для воспроизведения результатов.
| Набор данных | Ex/cls | TF2 | Джакс | Пирог |
|---|---|---|---|---|
| Cifar10 | 1 | 52,5 ← 55,8 → 60,2 | 48,7 ← 53,9 → 65,0 | 56,4 ← 56,7 → 73,1 |
| Cifar10 | 5 | 85,3 ← 87,2 → 89,1 | 80,2 ← 85,8 → 88,6 | 84,8 ← 85,8 → 89,6 |
| Cifar10 | полный | 98.5 | 98.4 | 98,5 ← 98,6 → 98,6 |
| CIFAR100 | 1 | 34,8 ← 35,7 → 37,9 | 32,1 ← 35,0 → 37,1 | 31,6 ← 33,8 → 36,9 |
| CIFAR100 | 5 | 68,8 ← 70,4 → 71,4 | 68,6 ← 70,8 → 71,6 | 70,6 ← 71,6 → 71,7 |
| CIFAR100 | полный | 90.8 | 91.2 | 91.1 ← 91.2 → 91.4 |
| Набор данных | Ex/cls | Джакс | Пирог |
|---|---|---|---|
| Cifar10 | 1 | 44,0 ← 56,7 → 65,0 | 50,9 ← 55,5 → 59,5 |
| Cifar10 | 5 | 85,3 ← 87,0 → 88,2 | 85,3 ← 85,8 → 88,6 |
| Cifar10 | полный | 98.5 | 98,5 ← 98,5 → 98,6 |
| CIFAR100 | 1 | 36,4 ← 37,2 → 38,9 | 34,3 ← 36,8 → 39,0 |
| CIFAR100 | 5 | 69,3 ← 70,5 → 72,0 | 70,3 ← 72,0 → 72,3 |
| CIFAR100 | полный | 91.2 | 91,2 ← 91,3 → 91,4 |
(Модели TF2 еще недоступны.)
| Набор данных | Ex/cls | TF2 | Джакс | Пирог |
|---|---|---|---|---|
| Cifar10 | 1 | 49,9 ← 54,4 → 60,2 | 48,4 ← 54,1 → 66,1 | 45,8 ← 57,9 → 65,7 |
| Cifar10 | 5 | 80,8 ← 83,3 → 85,5 | 76,7 ← 82,4 → 85,4 | 80,3 ← 82,3 → 84,9 |
| Cifar10 | полный | 97.2 | 97.3 | 97.4 |
| CIFAR100 | 1 | 35,3 ← 37,1 → 38,2 | 32,0 ← 35,2 → 37,8 | 34,6 ← 35,2 → 38,6 |
| CIFAR100 | 5 | 63,8 ← 65,0 → 66,5 | 63,4 ← 64,8 → 66,5 | 64,7 ← 65,5 → 66,0 |
| CIFAR100 | полный | 86.5 | 86.4 | 86.6 |
Эти результаты были получены с использованием битовой гиперрулы. Однако, поскольку это приводит к большому размеру партии и большему разрешению, память может быть проблемой. Код Pytorch поддерживает пакетное распределение, и, следовательно, мы все еще можем запускать там вещи, не прибегая к облачным TPU, добавив команду --batch_split N , где N -мощность двух. Например, следующая команда дает точность проверки 80.68 на машине с 8 v100 графическими процессорами:
python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4
Дальнейшее увеличение до --batch_split 8 при работе с 4 v100 графические процессоры и т. Д.
Полные результаты достигнуты таким образом в некоторых тестовых прогонах:
| Ex/cls | R50x1 | R152x2 | R101x3 |
|---|---|---|---|
| 1 | 18.36 | 24.5 | 25,55 |
| 5 | 50.64 | 64,5 | 64.18 |
| полный | 80.68 | 85,15 | Пари |
Это повторные заезды, а не точные бумажные модели. Ожидаемые оценки VTAB для двух моделей:
| Модель | Полный | Естественный | Структурированный | Специализирован |
|---|---|---|---|---|
| BIT-M-R152X4 | 73,51 | 80.77 | 61.08 | 85,67 |
| BIT-M-R101X3 | 72,65 | 80.29 | 59,40 | 85,75 |
В Приложении G о нашей статье мы исследуем, улучшает ли BIT вне контекста надежность. Для этого мы создали набор данных, включающий объекты переднего плана, соответствующие 21 классам ILSVRC-2012, вставленным на 41 Разное фоны.
Чтобы загрузить набор данных, запустите
wget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip
Изображения из каждого из 21 класса хранятся в каталоге с именем класса.
Мы выпускаем высокопроизводительные модели сжатых битов из нашей статьи «Дистилляция знаний: хороший учитель терпелив и последователен» при дистилляции Ноульджа. В частности, мы перегоняем модель BIT-M-R152X2 (которая была предварительно обучена на ImageNet-21K) в модели BIT-R50x1. В результате мы получаем компактные модели с очень конкурентной производительностью.
| Модель | Скачать ссылку | Разрешение | ImageNet Top-1 Acc. (бумага) |
|---|---|---|---|
| BIT-R50x1 | связь | 224 | 82,8 |
| BIT-R50x1 | связь | 160 | 80.5 |
Для воспроизводимости мы также выпускаем веса двух Bit-M-M-R152x2 для учителей: предварительно подготовленные в разрешении 224 и разрешение 384. См. Документ для получения подробной информации о том, как использовались эти учителя.
У нас нет конкретных планов по публикации кода дистилляции, так как рецепт прост, и мы представляем, что большинство людей интегрируют его в существующий код обучения. Тем не менее, Саяк Пол независимо переосмыслил настройку дистилляции в Tensorflow и почти воспроизвел наши результаты в нескольких настройках.