Implementação de Pytorch da rede de transformadores espaciais (STN) com spline de placa fina (TPS).
O STN é uma poderosa arquitetura de rede neural proposta por DeepMind em [1]. O STN alcança invariância espacial real, corrigindo automaticamente as imagens de entrada antes que elas sejam alimentadas em uma rede de classificação normal. A parte mais incrível do STN é que ele é diferencial de ponta a ponta e pode ser conectado diretamente às arquiteturas de rede existentes (Alexnet, Resnet etc.), sem qualquer supervisão extra.
Papel STN original [1] Experiências em três formas de transformação específicas: transformação afim, transformação projetiva e transformação de spline de placas finas (TPS) . Entre eles, acho que o TPS é a tradução mais poderosa, porque pode distorcer uma imagem de maneira arbitrária. Como mostrado abaixo, posso deformar meu avatar
em
O TPS-STN foi usado no aplicativo OCR [2]. Neste artigo, o TPS-STN é retificar automaticamente imagens de texto distorcidas, antes de serem alimentadas em um modelo normal de reconhecimento de texto OCR:
Eu uso imageio para criar visualização de GIF. Simplesmente instale -o por pip install imageio .
python mnist_train.py --model unbounded_stn --angle 90 --grid_size 4
python mnist_visualize.py --model unbounded_stn --angle 90 --grid_size 4
python mnist_make_gif.py --model unbounded_stn --angle 90 --grid_size 4
Em seguida, os resutls png e gif serão salvos em ./image/unbounded_stn_angle60_grid4/ e ./gif/unbounded_stn_angle60_grid4/ .
Você pode tentar outras combinações de arquitetura do modelo, ângulo de rotação aleatório do mnist e tamanho da grade TPS. Detalhes abaixo.
Existem três argumentos controláveis: --model , --angle , --grid_size .
--model : STR, necessário
no_stn , o módulo STN é descartado e apenas um único classificador da CNN permanece.bounded_stn , a saída da rede de localização é espremida para [-1, 1] por F.tanh , como foi feito em [2]unbounded_stn , a saída da rede de locolizaition não é espremida --angle : int, padrão = 60
[-angle, angle] --grid_size : int, padrão = 4
(grid_size x grid_size) pontos de controle para definir a transformação de spline de placa fina Os resultados com angle = 90 geralmente são ruins:
Resultados com bounded_stn são ruins se grid_size <= 3 :
Mas ok se grid_size >= 4 :
Resultados com unbounded_stn estão ok:
É claro que sempre há casos ruins em cada combinação. Você pode baixar todos os meus GIFs do Baidu NetDisk (tamanho do arquivo 2G).
[1] Redes de transformadores espaciais
[2] Reconhecimento robusto de texto da cena com retificação automática
[3] 数值方法 数值方法 数值方法 薄板样条插值 薄板样条插值 (spline de placa fina)