Por Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly, Neil Houlsby
ACTUALIZACIÓN 18/06/2021: Lanzamos nuevos modelos BIT-R50X1 de alto rendimiento, que se destilaron de BIT-M-R152x2, consulte esta sección. Más detalles en nuestro artículo "Destilación del conocimiento: un buen maestro es paciente y consistente".
Actualización 08/02/2021: También lanzamos todos los modelos BIT-M ajustados en los 19 conjuntos de datos VTAB-1K, ver a continuación.
En este repositorio lanzamos múltiples modelos de la gran transferencia (bit): papel de aprendizaje de representación visual general que se capacitó en los conjuntos de datos ILSVRC-2012 e ImageNet-21K. Proporcionamos el código para ajustar los modelos lanzados en los principales marcos de aprendizaje profundo Tensorflow 2, Pytorch y Jax/Flax.
Esperamos que la comunidad de visión por computadora se beneficie empleando modelos previos a la petróleo de ImageNet-21K más potentes en lugar de los modelos convencionales previamente entrenados en el conjunto de datos ILSVRC-2012.
También proporcionamos Colabs para un uso interactivo más exploratorio: un tensorflow 2 colab, un pytorch colab y un jax colab.
Asegúrese de tener Python>=3.6 instalado en su máquina.
Para configurar TensorFlow 2, Pytorch o Jax, siga las instrucciones proporcionadas en el repositorio correspondiente vinculado aquí.
Además, instale dependencias de Python ejecutándose (seleccione tf2 , pytorch o jax en el siguiente comando):
pip install -r bit_{tf2|pytorch|jax}/requirements.txt
Primero, descargue el modelo Bit. Proporcionamos modelos previamente entrenados en ILSVRC-2012 (BIT-S) o ImageNet-21K (BIT-M) para 5 arquitecturas diferentes: ResNet-50x1, ResNet-101x1, ResNet-50x3, ResNet-101x3 y ResNet-152x4.
Por ejemplo, si desea descargar el resnet-50x1 previamente entrenado en ImageNet-21k, ejecute el siguiente comando:
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}
Otros modelos se pueden descargar en consecuencia conectando el nombre del modelo (bit-s o bit-m) y arquitectura en el comando anterior. Tenga en cuenta que proporcionamos modelos en dos formatos: npz (para Pytorch y Jax) y h5 (para TF2). Por defecto, esperamos que los pesos del modelo se almacenen en la carpeta raíz de este repositorio.
Luego, puede ejecutar el ajuste del modelo descargado en su conjunto de datos de interés en cualquiera de los tres marcos. Todos los marcos comparten la interfaz de línea de comandos
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
Actualmente. Todos los marcos descargarán automáticamente los conjuntos de datos CIFAR-10 y CIFAR-100. Otros conjuntos de datos públicos o personalizados se pueden integrar fácilmente: en TF2 y Jax confiamos en la Biblioteca extensible de conjuntos de datos TensorFlow. En Pytorch, utilizamos la tubería de entrada de datos de TorchVision.
Tenga en cuenta que nuestro código usa todas las GPU disponibles para ajustar.
También apoyamos la capacitación en el régimen de datos bajos: la opción --examples_per_class <K> dibujará aleatoriamente K muestras por clase para el entrenamiento.
Para ver una lista detallada de todas las banderas disponibles, ejecute python3 -m bit_{pytorch|jax|tf2}.train --help .
Por conveniencia, proporcionamos modelos BIT-M que ya estaban ajustados en el conjunto de datos ILSVRC-2012. Los modelos se pueden descargar agregando el postfix -ILSVRC2012 , por ejemplo
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz
Lanzamos todas las arquitecturas mencionadas en el documento, de modo que puede elegir entre precisión o velocidad: R50x1, R101x1, R50x3, R101x3, R152x4. En la ruta anterior al archivo modelo, simplemente reemplace R50x1 por su arquitectura de elección.
Investigamos más a fondo más arquitecturas después de la publicación del periódico y encontramos que R152x2 tiene una buena compensación entre la velocidad y la precisión, por lo tanto, también incluimos esto en el lanzamiento y proporcionamos algunos números a continuación.
También lanzamos los modelos ajustados para cada una de las 19 tareas incluidas en el punto de referencia VTAB-1K. Ejecutamos cada modelo tres veces y lanzamos cada una de estas ejecuciones. Esto significa que liberamos un total de modelos 5x19x3 = 285, y esperamos que estos puedan ser útiles en un análisis posterior del aprendizaje de transferencia.
Los archivos se pueden descargar a través del siguiente patrón:
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz
No convertimos estos modelos en TF2 (por lo tanto, no hay un archivo .h5 correspondiente), sin embargo, también cargamos modelos TFHUB que se pueden usar en TF1 y TF2. Una secuencia de comandos de ejemplo para descargar uno de esos modelos es:
mkdir BiT-M-R50x1-run0-caltech101.tfhub && cd BiT-M-R50x1-run0-caltech101.tfhub
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}
mkdir variables && cd variables
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}
Para la reproducibilidad, nuestro script de entrenamiento utiliza hiper-parametros (bit-hiperrule) que se usaron en el documento original. Sin embargo, tenga en cuenta que los modelos de bits fueron entrenados y finetos con hardware de TPU en la nube, por lo que para una configuración típica de GPU, nuestros hiperparametros predeterminados podrían requerir demasiada memoria o dar como resultado un progreso muy lento. Además, Bit-Hyperrule está diseñado para generalizar en muchos conjuntos de datos, por lo que generalmente es posible diseñar hiperparametros más eficientes específicos de la aplicación. Por lo tanto, alentamos al usuario a probar más configuraciones de peso ligero, ya que requieren mucho menos recursos y, a menudo, dan como resultado una precisión similar.
Por ejemplo, probamos nuestro código utilizando una máquina GPU 8XV100 en los conjuntos de datos CIFAR-10 y CIFAR-100, al tiempo que reduce el tamaño de lotes de 512 a 128 y la tasa de aprendizaje de 0.003 a 0.001. Esta configuración resultó en un rendimiento casi idéntico (ver los resultados esperados a continuación) en comparación con Bit-Hyperrule, a pesar de ser menos exigente computacionalmente.
A continuación, proporcionamos más sugerencias sobre cómo optimizar la configuración de nuestro artículo.
El bit-hyperrule predeterminado se desarrolló en TPUS de la nube y tiene bastante hambre de la memoria. Esto se debe principalmente a la gran resolución del tamaño de lotes (512) y la resolución de imágenes (hasta 480x480). Aquí hay algunos consejos si se está quedando sin memoria:
bit_hyperrule.py especificamos la resolución de entrada. Al reducirlo, uno puede ahorrar mucha memoria y calcular, a expensas de la precisión.--batch_split . Por ejemplo, ejecutar el ajuste fino con --batch_split 8 reduce el requisito de memoria en un factor de 8. Verificamos que al usar el bit-hyperrule, el código en este repositorio reproduce los resultados del documento.
Para estos puntos de referencia comunes, los cambios antes mencionados en el bit-hyperrule ( --batch 128 --base_lr 0.001 ) conducen a los siguientes resultados muy similares. La tabla muestra el resultado Min ← Mediana → Max de al menos cinco carreras. Nota : Esta no es una comparación de los marcos, solo evidencia de que se puede confiar en todas las bases de código para reproducir los resultados.
| Conjunto de datos | Ex/CLS | TF2 | Jax | Pytorch |
|---|---|---|---|---|
| Cifar10 | 1 | 52.5 ← 55.8 → 60.2 | 48.7 ← 53.9 → 65.0 | 56.4 ← 56.7 → 73.1 |
| Cifar10 | 5 | 85.3 ← 87.2 → 89.1 | 80.2 ← 85.8 → 88.6 | 84.8 ← 85.8 → 89.6 |
| Cifar10 | lleno | 98.5 | 98.4 | 98.5 ← 98.6 → 98.6 |
| Cifar100 | 1 | 34.8 ← 35.7 → 37.9 | 32.1 ← 35.0 → 37.1 | 31.6 ← 33.8 → 36.9 |
| Cifar100 | 5 | 68.8 ← 70.4 → 71.4 | 68.6 ← 70.8 → 71.6 | 70.6 ← 71.6 → 71.7 |
| Cifar100 | lleno | 90.8 | 91.2 | 91.1 ← 91.2 → 91.4 |
| Conjunto de datos | Ex/CLS | Jax | Pytorch |
|---|---|---|---|
| Cifar10 | 1 | 44.0 ← 56.7 → 65.0 | 50.9 ← 55.5 → 59.5 |
| Cifar10 | 5 | 85.3 ← 87.0 → 88.2 | 85.3 ← 85.8 → 88.6 |
| Cifar10 | lleno | 98.5 | 98.5 ← 98.5 → 98.6 |
| Cifar100 | 1 | 36.4 ← 37.2 → 38.9 | 34.3 ← 36.8 → 39.0 |
| Cifar100 | 5 | 69.3 ← 70.5 → 72.0 | 70.3 ← 72.0 → 72.3 |
| Cifar100 | lleno | 91.2 | 91.2 ← 91.3 → 91.4 |
(Los modelos TF2 aún no están disponibles).
| Conjunto de datos | Ex/CLS | TF2 | Jax | Pytorch |
|---|---|---|---|---|
| Cifar10 | 1 | 49.9 ← 54.4 → 60.2 | 48.4 ← 54.1 → 66.1 | 45.8 ← 57.9 → 65.7 |
| Cifar10 | 5 | 80.8 ← 83.3 → 85.5 | 76.7 ← 82.4 → 85.4 | 80.3 ← 82.3 → 84.9 |
| Cifar10 | lleno | 97.2 | 97.3 | 97.4 |
| Cifar100 | 1 | 35.3 ← 37.1 → 38.2 | 32.0 ← 35.2 → 37.8 | 34.6 ← 35.2 → 38.6 |
| Cifar100 | 5 | 63.8 ← 65.0 → 66.5 | 63.4 ← 64.8 → 66.5 | 64.7 ← 65.5 → 66.0 |
| Cifar100 | lleno | 86.5 | 86.4 | 86.6 |
Estos resultados se obtuvieron usando bit-hyperrule. Sin embargo, debido a que esto da como resultado un gran tamaño por lotes y una gran resolución, la memoria puede ser un problema. El código PyTorch admite la división por lotes y, por lo tanto, aún podemos ejecutar cosas allí sin recurrir a las TPU de la nube agregando el comando --batch_split N donde N es un poder de dos. Por ejemplo, el siguiente comando produce una precisión de validación de 80.68 en una máquina con 8 GPU V100:
python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4
Aumento adicional a --batch_split 8 cuando se ejecuta con 4 GPU de V100, etc.
Los resultados completos logrados de esa manera en algunas pruebas fueron:
| Ex/CLS | R50X1 | R152x2 | R101x3 |
|---|---|---|---|
| 1 | 18.36 | 24.5 | 25.55 |
| 5 | 50.64 | 64.5 | 64.18 |
| lleno | 80.68 | 85.15 | Ceñudo |
Estos son reiniciados y no los modelos de papel exactos. Los puntajes VTAB esperados para dos de los modelos son:
| Modelo | Lleno | Natural | Estructurado | Especializado |
|---|---|---|---|---|
| Bit-m-r152x4 | 73.51 | 80.77 | 61.08 | 85.67 |
| Bit-m-r101x3 | 72.65 | 80.29 | 59.40 | 85.75 |
En el Apéndice G de nuestro artículo, investigamos si Bit mejora la robustez fuera de contexto. Para hacer esto, creamos un conjunto de datos que comprende objetos de primer plano correspondientes a 21 clases ILSVRC-2012 pegadas en 41 fondos diversos.
Para descargar el conjunto de datos, ejecutar
wget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip
Las imágenes de cada una de las 21 clases se mantienen en un directorio con el nombre de la clase.
Liberamos modelos de bits comprimidos de alto rendimiento de nuestro artículo "Destilación del conocimiento: un buen maestro es paciente y consistente" en la destilación KnowEldge. En particular, destilamos el modelo BIT-M-R152X2 (que se entrenó previamente en los modelos ImageNet-21K) a los modelos BIT-R50X1. Como resultado, obtenemos modelos compactos con un rendimiento muy competitivo.
| Modelo | Enlace de descarga | Resolución | Imagenet Top-1 Acc. (papel) |
|---|---|---|---|
| Bit-r50x1 | enlace | 224 | 82.8 |
| Bit-r50x1 | enlace | 160 | 80.5 |
Para la reproducibilidad, también liberamos pesos de dos modelos de maestros BIT-M-R152x2: antes de la resolución 224 y la Resolución 384. Vea el artículo para obtener detalles sobre cómo se usaron estos maestros.
No tenemos planes concretos para publicar el código de destilación, ya que la receta es simple e imaginamos que la mayoría de las personas la integrarían en su código de capacitación existente. Sin embargo, Sayak Paul ha vuelto a implementar de forma independiente la configuración de destilación en TensorFlow y casi reproducido nuestros resultados en varios entornos.