Por Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly, Neil Houlsby
ATUALIZAÇÃO 18/06/2021: Lançamos novos modelos de bit-r50x1 de alto desempenho, que foram destilados do Bit-M-R152X2, consulte esta seção. Mais detalhes em nosso artigo "Destilação do conhecimento: um bom professor é paciente e consistente".
ATUALIZAÇÃO 08/02/2021: Também lançamos todos os modelos de bit-m ajustados em todos os 19 conjuntos de dados VTAB-1K, veja abaixo.
Neste repositório, lançamos vários modelos da grande transferência (bit): papel de aprendizado de representação visual geral que foram treinados nos conjuntos de dados ILSVRC-2012 e ImageNet-21K. Fornecemos o código para ajustar os modelos lançados nas principais estruturas de aprendizado profundo Tensorflow 2, Pytorch e Jax/Linho.
Esperamos que a comunidade de visão computacional se beneficie empregando modelos ImageNet-21K mais poderosos, em oposição aos modelos convencionais pré-treinados no conjunto de dados ILSVRC-2012.
Também fornecemos colabs para um uso interativo mais exploratório: um colab de tensorflow 2, um colab pytorch e um colab Jax.
Verifique se você tem Python>=3.6 instalado em sua máquina.
Para configurar o TensorFlow 2, Pytorch ou Jax, siga as instruções fornecidas no repositório correspondente vinculado aqui.
Além disso, instale as dependências do Python em execução (selecione tf2 , pytorch ou jax no comando abaixo):
pip install -r bit_{tf2|pytorch|jax}/requirements.txt
Primeiro, faça o download do modelo de bits. Fornecemos modelos pré-treinados em ILSVRC-2012 (bit-s) ou imagenet-21k (bit-m) para 5 arquiteturas diferentes: resnet-50x1, resnet-101x1, resnet-50x3, resnet-101x3 e resNet-152x4.
Por exemplo, se você deseja baixar o RESNET-50X1 pré-treinado no ImageNet-21K, execute o seguinte comando:
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}
Outros modelos podem ser baixados de acordo com o nome do nome do modelo (bit-s ou bit-m) e arquitetura no comando acima. Observe que fornecemos modelos em dois formatos: npz (para Pytorch e Jax) e h5 (para TF2). Por padrão, esperamos que os pesos do modelo sejam armazenados na pasta raiz deste repositório.
Em seguida, você pode executar o ajuste fino do modelo baixado no seu conjunto de dados de interesse em qualquer uma das três estruturas. Todas as estruturas compartilham a interface da linha de comando
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
Atualmente. Todas as estruturas baixarão automaticamente os conjuntos de dados CIFAR-10 e CIFAR-100. Outros conjuntos de dados públicos ou personalizados podem ser facilmente integrados: no TF2 e JAX, confiamos na biblioteca de conjuntos de dados de tensorflow extensível. Em Pytorch, usamos o pipeline de entrada de dados da Torchvision.
Observe que nosso código usa todas as GPUs disponíveis para ajuste fino.
Também apoiamos o treinamento no regime de baixo data: a opção --examples_per_class <K> desenhará aleatoriamente as amostras K por classe para treinamento.
Para ver uma lista detalhada de todos os sinalizadores disponíveis, execute python3 -m bit_{pytorch|jax|tf2}.train --help .
Por conveniência, fornecemos modelos Bit-M que já estavam ajustados no conjunto de dados ILSVRC-2012. Os modelos podem ser baixados adicionando o postfix -ILSVRC2012
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz
Lançamos todas as arquiteturas mencionadas no artigo, de modo que você pode escolher entre precisão ou velocidade: R50X1, R101X1, R50X3, R101X3, R152X4. No caminho acima para o arquivo de modelo, basta substituir R50x1 pela sua arquitetura de escolha.
Investigamos ainda mais arquiteturas após a publicação do artigo e descobrimos que o R152X2 tem uma boa troca entre velocidade e precisão; portanto, também incluímos isso no lançamento e fornecemos alguns números abaixo.
Também lançamos os modelos ajustados para cada uma das 19 tarefas incluídas no benchmark VTAB-1K. Executamos cada modelo três vezes e lançamos cada uma dessas corridas. Isso significa que liberamos um total de 5x19x3 = 285 modelos e esperamos que eles possam ser úteis na análise mais aprofundada do aprendizado de transferência.
Os arquivos podem ser baixados através do seguinte padrão:
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz
Não convertemos esses modelos em TF2 (portanto, não há arquivo .h5 correspondente), no entanto, também carregamos modelos TFHUB que podem ser usados em TF1 e TF2. Uma sequência de comandos de exemplo para baixar um desses modelos é:
mkdir BiT-M-R50x1-run0-caltech101.tfhub && cd BiT-M-R50x1-run0-caltech101.tfhub
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}
mkdir variables && cd variables
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}
Para reprodutibilidade, nosso script de treinamento usa hiper-parâmetros (bit-hiperrule) que foram usados no papel original. Observe, no entanto, que os modelos de bits foram treinados e finetunados usando o hardware de TPU em nuvem; portanto, para uma configuração típica de GPU, nossos hiper-parâmetros padrão podem exigir muita memória ou resultar em um progresso muito lento. Além disso, o bit-hiperrule foi projetado para generalizar em muitos conjuntos de dados, por isso é normalmente possível criar hiper-parâmetros específicos de aplicativos mais eficientes. Assim, incentivamos o usuário a tentar configurações mais leves, pois exigem muito menos recursos e geralmente resultam em uma precisão semelhante.
Por exemplo, testamos nosso código usando uma máquina de GPU 8xv100 nos conjuntos de dados CIFAR-10 e CIFAR-100, enquanto reduz o tamanho do lote de 512 para 128 e a taxa de aprendizado de 0,003 a 0,001. Essa configuração resultou em desempenho quase idêntico (veja os resultados esperados abaixo) em comparação com o bit-hiperrule, apesar de ser menos exigente computacionalmente.
Abaixo, fornecemos mais sugestões sobre como otimizar a configuração do nosso artigo.
O bit-hiperrule padrão foi desenvolvido no Cloud TPUS e tem muito faminto por memória. Isso se deve principalmente à grande resolução do tamanho do lote (512) e da imagem (até 480x480). Aqui estão algumas dicas se você estiver sem memória:
bit_hyperrule.py especificamos a resolução de entrada. Ao reduzi -lo, pode -se economizar muita memória e calcular, à custa da precisão.--batch_split . Por exemplo, a execução do ajuste fino com --batch_split 8 reduz o requisito de memória por um fator de 8. Verificamos que, ao usar o bit-hiperrule, o código neste repositório reproduz os resultados do artigo.
Para esses benchmarks comuns, as alterações acima mencionadas no bit-hiperrule ( --batch 128 --base_lr 0.001 ) levam aos seguintes resultados muito semelhantes. A tabela mostra o resultado Min ← mediano → máximo de pelo menos cinco corridas. Nota : Esta não é uma comparação das estruturas, apenas evidências de que todas as bases de código podem ser confiáveis para reproduzir resultados.
| Conjunto de dados | Ex/cls | TF2 | Jax | Pytorch |
|---|---|---|---|---|
| Cifar10 | 1 | 52,5 ← 55,8 → 60.2 | 48,7 ← 53,9 → 65,0 | 56,4 ← 56,7 → 73.1 |
| Cifar10 | 5 | 85,3 ← 87,2 → 89.1 | 80,2 ← 85,8 → 88,6 | 84,8 ← 85,8 → 89,6 |
| Cifar10 | completo | 98.5 | 98.4 | 98,5 ← 98,6 → 98,6 |
| Cifar100 | 1 | 34,8 ← 35,7 → 37,9 | 32,1 ← 35.0 → 37.1 | 31,6 ← 33,8 → 36.9 |
| Cifar100 | 5 | 68,8 ← 70,4 → 71.4 | 68,6 ← 70,8 → 71.6 | 70,6 ← 71,6 → 71.7 |
| Cifar100 | completo | 90.8 | 91.2 | 91.1 ← 91,2 → 91.4 |
| Conjunto de dados | Ex/cls | Jax | Pytorch |
|---|---|---|---|
| Cifar10 | 1 | 44,0 ← 56,7 → 65,0 | 50,9 ← 55,5 → 59,5 |
| Cifar10 | 5 | 85,3 ← 87,0 → 88.2 | 85,3 ← 85,8 → 88,6 |
| Cifar10 | completo | 98.5 | 98,5 ← 98,5 → 98,6 |
| Cifar100 | 1 | 36,4 ← 37,2 → 38.9 | 34,3 ← 36,8 → 39,0 |
| Cifar100 | 5 | 69,3 ← 70,5 → 72,0 | 70,3 ← 72,0 → 72.3 |
| Cifar100 | completo | 91.2 | 91,2 ← 91,3 → 91.4 |
(Os modelos TF2 ainda não estão disponíveis.)
| Conjunto de dados | Ex/cls | TF2 | Jax | Pytorch |
|---|---|---|---|---|
| Cifar10 | 1 | 49,9 ← 54,4 → 60.2 | 48,4 ← 54.1 → 66.1 | 45,8 ← 57,9 → 65.7 |
| Cifar10 | 5 | 80,8 ← 83,3 → 85.5 | 76,7 ← 82,4 → 85.4 | 80,3 ← 82,3 → 84.9 |
| Cifar10 | completo | 97.2 | 97.3 | 97.4 |
| Cifar100 | 1 | 35.3 ← 37,1 → 38.2 | 32,0 ← 35.2 → 37,8 | 34,6 ← 35.2 → 38,6 |
| Cifar100 | 5 | 63,8 ← 65,0 → 66,5 | 63,4 ← 64,8 → 66,5 | 64,7 ← 65,5 → 66.0 |
| Cifar100 | completo | 86.5 | 86.4 | 86.6 |
Esses resultados foram obtidos usando bit-hiperrule. No entanto, como isso resulta em grande e grande resolução em lote e uma grande resolução, a memória pode ser um problema. O código Pytorch suporta a divisão de lote e, portanto, ainda podemos executar as coisas lá sem recorrer ao Cloud TPUs, adicionando o comando --batch_split N onde N é uma potência de dois. Por exemplo, o comando a seguir produz uma precisão de validação de 80.68 em uma máquina com 8 GPUs V100:
python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4
Aumentar ainda mais para --batch_split 8 ao executar com 4 GPUs V100, etc.
Os resultados completos alcançados dessa maneira em algumas execuções de teste foram:
| Ex/cls | R50X1 | R152X2 | R101X3 |
|---|---|---|---|
| 1 | 18.36 | 24.5 | 25.55 |
| 5 | 50.64 | 64.5 | 64.18 |
| completo | 80,68 | 85.15 | WIP |
Estes são reencontros e não os modelos de papel exatos. As pontuações esperadas do VTAB para dois dos modelos são:
| Modelo | Completo | Natural | Estruturado | Especializado |
|---|---|---|---|---|
| Bit-M-R152X4 | 73.51 | 80,77 | 61.08 | 85.67 |
| Bit-M-R101x3 | 72.65 | 80.29 | 59.40 | 85.75 |
No Apêndice G do nosso artigo, investigamos se o bit melhora a robustez fora de contexto. Para fazer isso, criamos um conjunto de dados compreendendo objetos de primeiro plano correspondentes a 21 classes ILSVRC-2012 coladas em 41 antecedentes diversos.
Para baixar o conjunto de dados, execute
wget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip
Imagens de cada uma das 21 classes são mantidas em um diretório com o nome da classe.
Lançamos modelos de bits compactados com melhor desempenho da nossa "destilação de conhecimento: um bom professor é paciente e consistente" na destilação de Knoweldge. Em particular, destilarmos o modelo Bit-M-R152X2 (que foi pré-treinado nos modelos ImageNet-21K) aos bit-r50x1. Como resultado, obtemos modelos compactos com desempenho muito competitivo.
| Modelo | Baixar link | Resolução | ImageNet Top-1 acc. (papel) |
|---|---|---|---|
| Bit-r50x1 | link | 224 | 82.8 |
| Bit-r50x1 | link | 160 | 80.5 |
Para reprodutibilidade, também lançamos pesos de dois modelos de professores Bit-M-R152X2: pré-criados na Resolução 224 e na Resolução 384. Veja o artigo para obter detalhes sobre como esses professores foram usados.
Não temos planos concretos para publicar o código de destilação, pois a receita é simples e imaginamos que a maioria das pessoas o integraria em seu código de treinamento existente. No entanto, Sayak Paul reimplementou independentemente a configuração de destilação no Tensorflow e quase reproduziu nossos resultados em vários ambientes.