Implementasi Pytorch dari Spatial Transformer Network (STN) dengan Thin Plate Spline (TPS).
STN adalah arsitektur jaringan saraf yang kuat yang diusulkan oleh DeepMind di [1]. STN mencapai invarian spasial nyata dengan secara otomatis memperbaiki gambar input sebelum dimasukkan ke dalam jaringan klasifikasi normal. Bagian STN yang paling menakjubkan adalah diferensial ujung ke ujung dan dapat langsung terhubung ke arsitektur jaringan yang ada (Alexnet, ResNet, dll), tanpa pengawasan tambahan.
Kertas STN Asli [1] Eksperimen pada tiga bentuk transformasi spesifik: transformasi affine, transformasi proyektif dan transformasi spline pelat tipis (TPS) . Di antara mereka saya pikir TPS adalah terjemahan yang paling kuat karena dapat melengkungkan gambar dengan cara yang sewenang -wenang. Seperti yang ditunjukkan di bawah ini, saya bisa melengkung avatar saya
ke dalam
TPS-STN telah digunakan dalam aplikasi OCR [2]. Dalam makalah ini TPS-STN adalah untuk secara otomatis memperbaiki gambar teks yang terdistorsi, sebelum dimasukkan ke dalam model pengenalan teks OCR yang normal:
Saya menggunakan imageio untuk membuat visualisasi GIF. Cukup instal dengan pip install imageio .
python mnist_train.py --model unbounded_stn --angle 90 --grid_size 4
python mnist_visualize.py --model unbounded_stn --angle 90 --grid_size 4
python mnist_make_gif.py --model unbounded_stn --angle 90 --grid_size 4
Kemudian PNG dan GIF resutls akan disimpan di ./image/unbounded_stn_angle60_grid4/ dan ./gif/unbounded_stn_angle60_grid4/ .
Anda dapat mencoba kombinasi lain dari arsitektur model, sudut rotasi acak MNIST dan ukuran grid TPS. Detail di bawah ini.
Ada tiga argumen yang dapat dikendalikan: --model , --angle , --grid_size .
--model : str, wajib
no_stn , modul STN dibuang dan hanya satu classifier CNN yang tersisa.bounded_stn , output dari jaringan lokalisasi diperas ke [-1, 1] oleh F.tanh , seperti yang dilakukan pada [2]unbounded_stn , output dari jaringan locolizaition tidak diperas --angle : int, default = 60
[-angle, angle] --grid_size : int, default = 4
(grid_size x grid_size) titik kontrol untuk menentukan transformasi spline pelat tipis Hasil dengan angle = 90 umumnya buruk:
Hasil dengan bounded_stn buruk jika grid_size <= 3 :
Tapi oke jika grid_size >= 4 :
Hasil dengan unbounded_stn baik -baik saja:
Tentu saja selalu ada kasus yang buruk di setiap kombinasi. Anda dapat mengunduh semua gif saya dari Baidu NetDisk (ukuran file 2G).
[1] Jaringan Transformator Spasial
[2] Pengenalan teks adegan yang kuat dengan perbaikan otomatis
[3] 数值方法 —— 薄板样条插值( spline pelat tipis)