ttach
v0.0.3
Bildtestzeit Augmentation mit Pytorch!
Ähnlich wie bei der Erhöhung der Daten für den Trainingssatz besteht der Zweck der Testzeitvergrößerung darin, zufällige Modifikationen an den Testbildern durchzuführen. Anstatt die regulären, sauberen Bilder zu zeigen, werden wir ihm die erweiterten Bilder mehrmals zeigen. Wir werden dann die Vorhersagen jedes entsprechenden Bildes durchschnitt und dies als endgültige Vermutung [1].
Input
| # input batch of images
/ / /| # apply augmentations (flips, rotation, scale, etc.)
| | | | | | | # pass augmented batches through model
| | | | | | | # reverse transformations for each batch of masks/labels
/ / / # merge predictions (mean, max, gmean, etc.)
| # output batch of masks/labels
Output
import ttach as tta
tta_model = tta . SegmentationTTAWrapper ( model , tta . aliases . d4_transform (), merge_mode = 'mean' ) tta_model = tta . ClassificationTTAWrapper ( model , tta . aliases . five_crop_transform ()) tta_model = tta . KeypointsTTAWrapper ( model , tta . aliases . flip_transform (), scaled = True ) HINWEIS : Das Modell muss Tastoint in der Format torch([x1, y1, ..., xn, yn])
# defined 2 * 2 * 3 * 3 = 36 augmentations !
transforms = tta . Compose (
[
tta . HorizontalFlip (),
tta . Rotate90 ( angles = [ 0 , 180 ]),
tta . Scale ( scales = [ 1 , 2 , 4 ]),
tta . Multiply ( factors = [ 0.9 , 1 , 1.1 ]),
]
)
tta_model = tta . SegmentationTTAWrapper ( model , transforms ) # Example how to process ONE batch on images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is 2D tensor (B, N)
for transformer in transforms : # custom transforms or e.g. tta.aliases.d4_transform()
# augment image
augmented_image = transformer . augment_image ( image )
# pass to model
model_output = model ( augmented_image , another_input_data )
# reverse augmentation for mask and label
deaug_mask = transformer . deaugment_mask ( model_output [ 'mask' ])
deaug_label = transformer . deaugment_label ( model_output [ 'label' ])
# save results
labels . append ( deaug_mask )
masks . append ( deaug_label )
# reduce results as you want, e.g mean/max/min
label = mean ( labels )
mask = mean ( masks )| Verwandeln | Parameter | Werte |
|---|---|---|
| Horizontalflip | - - | - - |
| Vertikalflip | - - | - - |
| Drehen90 | Winkel | Liste [0, 90, 180, 270] |
| Skala | Waage Interpolation | Liste [Float] "nächstes"/"linear" |
| Größenänderung | Größen original_size Interpolation | Liste [Tuple [int, int]] Tuple [int, int] "nächstes"/"linear" |
| Hinzufügen | Werte | Liste [Float] |
| Multiplizieren | Faktoren | Liste [Float] |
| Fivecrops | Crop_Height Crop_width | int int |
Pypi:
$ pip install ttachQuelle:
$ pip install git+https://github.com/qubvel/ttachdocker build -f Dockerfile.dev -t ttach:dev . && docker run --rm ttach:dev pytest -p no:cacheprovider