ttach
v0.0.3
Image test-time augmentation with PyTorch!
Similar to what data augmentation does to the training set, the goal of test-time augmentation (TTA) is to apply random modifications to the test images. So instead of showing the regular, "clean" image to the trained model only once, we show it the augmented images several times. We then average the predictions for each corresponding image and take that as our final guess [1].
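To make the averaging idea concrete, here is a minimal sketch of test-time augmentation with a single fixed horizontal flip. It uses plain Python lists instead of tensors so it runs without PyTorch. `hflip`, `toy_model`, and `tta_predict` are hypothetical helpers for illustration, not part of ttach.

```python
def hflip(image):
    """Mirror every row left-to-right; applying it twice restores the input."""
    return [list(reversed(row)) for row in image]


def toy_model(image):
    """Hypothetical per-pixel 'model'. The column term makes it NOT flip-invariant,
    so predictions on the original and flipped image differ."""
    return [[2 * v + col for col, v in enumerate(row)] for row in image]


def tta_predict(image, model):
    # 1) prediction on the original image
    plain = model(image)
    # 2) prediction on the flipped image, flipped back to the original layout
    flipped_back = hflip(model(hflip(image)))
    # 3) average the two aligned predictions pixel by pixel
    return [
        [(a + b) / 2 for a, b in zip(row_a, row_b)]
        for row_a, row_b in zip(plain, flipped_back)
    ]


print(tta_predict([[1, 2, 3]], toy_model))  # [[3.0, 5.0, 7.0]]
```

The de-augmentation step (flipping the prediction back) is what keeps each output pixel aligned with the input pixel before averaging.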
```
           Input
             |           # input batch of images
        / / /|\ \ \      # apply augmentations (flips, rotation, scale, etc.)
       | | | | | | |     # pass augmented batches through model
       | | | | | | |     # reverse transformations for each batch of masks/labels
        \ \ \ / / /      # merge predictions (mean, max, gmean, etc.)
             |           # output batch of masks/labels
           Output
```
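The last stage of the pipeline merges the de-augmented predictions. A minimal sketch of the three merge modes named in the diagram, applied position by position to plain Python lists; the `merge` function and its mode strings are assumptions for illustration, not ttach's internal implementation.

```python
import math


def merge(predictions, mode="mean"):
    """Merge equally sized 1-D prediction lists position by position."""
    n = len(predictions)
    columns = list(zip(*predictions))  # the n values predicted for each position
    if mode == "mean":
        return [sum(col) / n for col in columns]
    if mode == "max":
        return [max(col) for col in columns]
    if mode == "gmean":
        # geometric mean: n-th root of the product (assumes non-negative values)
        return [math.prod(col) ** (1.0 / n) for col in columns]
    raise ValueError(f"unknown merge mode: {mode!r}")


# hypothetical de-augmented outputs from two augmented passes
preds = [[1.0, 2.0], [4.0, 8.0]]
print(merge(preds, "mean"))   # [2.5, 5.0]
print(merge(preds, "max"))    # [4.0, 8.0]
print(merge(preds, "gmean"))  # [2.0, 4.0]
```

The geometric mean is pulled toward the smaller values, so it penalizes augmentations on which the model is uncertain more than the arithmetic mean does.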
```python
import ttach as tta

# Segmentation model wrapping
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')

# Classification model wrapping (five_crop_transform needs the crop size)
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform(crop_height, crop_width))

# Keypoints model wrapping
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)
```

Note: the model must return keypoints in the format `torch([x1, y1, ..., xn, yn])`, i.e. a flat tensor of x, y pairs.
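The keypoint format in the note above is one flat sequence of x, y pairs. A minimal sketch of why that layout is convenient: undoing a horizontal flip only changes every other value. It assumes coordinates normalized to [0, 1], which is our reading of `scaled=True`; `to_pairs` and `hflip_keypoints` are hypothetical helpers, not ttach functions.

```python
def to_pairs(flat):
    """Reshape [x1, y1, ..., xn, yn] into [(x1, y1), ..., (xn, yn)]."""
    if len(flat) % 2:
        raise ValueError("expected an even number of values")
    return list(zip(flat[0::2], flat[1::2]))


def hflip_keypoints(flat):
    """Mirror normalized keypoints horizontally: x -> 1 - x, y unchanged.
    Assumption: all coordinates lie in [0, 1]."""
    out = []
    for x, y in to_pairs(flat):
        out.extend([1.0 - x, y])
    return out


# hypothetical model output on a horizontally flipped image
pred_on_flipped = [0.25, 0.5, 0.75, 0.125]
print(hflip_keypoints(pred_on_flipped))  # [0.75, 0.5, 0.25, 0.125]
```

With pixel coordinates instead of normalized ones, the flip would also need the image width, which is one reason a normalized output is simpler to de-augment.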
```python
# defines 2 * 2 * 3 * 3 = 36 augmentations!
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 180]),
        tta.Scale(scales=[1, 2, 4]),
        tta.Multiply(factors=[0.9, 1, 1.1]),
    ]
)

tta_model = tta.SegmentationTTAWrapper(model, transforms)
```

Example of how to process ONE batch of images with TTA by hand:
```python
import torch

# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is a 2D tensor (B, N)
labels, masks = [], []

for transformer in transforms:  # custom transforms or e.g. tta.aliases.d4_transform()
    # augment image
    augmented_image = transformer.augment_image(image)

    # pass to model
    model_output = model(augmented_image, another_input_data)

    # reverse augmentation for mask and label
    deaug_mask = transformer.deaugment_mask(model_output['mask'])
    deaug_label = transformer.deaugment_label(model_output['label'])

    # save results (each de-augmented output goes to its own list)
    masks.append(deaug_mask)
    labels.append(deaug_label)

# reduce results as you want, e.g. mean/max/min
label = torch.stack(labels).mean(dim=0)
mask = torch.stack(masks).mean(dim=0)
```

| Transform | Parameters | Values |
|---|---|---|
| HorizontalFlip | - | - |
| VerticalFlip | - | - |
| Rotate90 | angles | List[0, 90, 180, 270] |
| Scale | scales; interpolation | List[float]; "nearest"/"linear" |
| Resize | sizes; original_size; interpolation | List[Tuple[int, int]]; Tuple[int, int]; "nearest"/"linear" |
| Add | values | List[float] |
| Multiply | factors | List[float] |
| FiveCrops | crop_height; crop_width | int; int |
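The custom `Compose` example counts 2 * 2 * 3 * 3 = 36 augmentations: each transform contributes one option per value in its parameter list, and every combination is applied. A sketch of that counting with `itertools.product`, assuming `HorizontalFlip` has the two options "no flip" and "flip"; this only illustrates the arithmetic and is not ttach's internal code.

```python
from itertools import product

# Parameter lists mirroring the custom Compose example.
# Assumption: HorizontalFlip is modelled as [False, True] (no flip / flip).
params = {
    "HorizontalFlip": [False, True],
    "Rotate90": [0, 180],
    "Scale": [1, 2, 4],
    "Multiply": [0.9, 1, 1.1],
}

# Every combination of one value per transform is one augmented pass.
combinations = list(product(*params.values()))
print(len(combinations))  # 36
print(combinations[0])    # (False, 0, 1, 0.9)
```

Because the count is a product, adding one more transform with k values multiplies the number of forward passes by k, so inference cost grows quickly.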
PyPI:

```bash
$ pip install ttach
```

Source:

```bash
$ pip install git+https://github.com/qubvel/ttach
```

Run tests:

```bash
docker build -f Dockerfile.dev -t ttach:dev . && docker run --rm ttach:dev pytest -p no:cacheprovider
```