ttach
v0.0.3
Augmentation du temps de test d'image avec Pytorch!
Semblable à ce que l'augmentation des données fait à l'ensemble de formation, le but de l'augmentation du temps de test est d'effectuer des modifications aléatoires aux images de test. Ainsi, au lieu de montrer les images régulières et «propres», une seule fois vers le modèle formé, nous le montrerons plusieurs fois les images augmentées. Nous allons ensuite en moyenne les prédictions de chaque image correspondante et prendre cela comme notre devine finale [1].
Input
| # input batch of images
/ / /| # apply augmentations (flips, rotation, scale, etc.)
| | | | | | | # pass augmented batches through model
| | | | | | | # reverse transformations for each batch of masks/labels
/ / / # merge predictions (mean, max, gmean, etc.)
| # output batch of masks/labels
Output
import ttach as tta
tta_model = tta . SegmentationTTAWrapper ( model , tta . aliases . d4_transform (), merge_mode = 'mean' ) tta_model = tta . ClassificationTTAWrapper ( model , tta . aliases . five_crop_transform ()) tta_model = tta . KeypointsTTAWrapper ( model , tta . aliases . flip_transform (), scaled = True ) Remarque : Le modèle doit renvoyer des points clés dans le format torch([x1, y1, ..., xn, yn])
# defined 2 * 2 * 3 * 3 = 36 augmentations !
transforms = tta . Compose (
[
tta . HorizontalFlip (),
tta . Rotate90 ( angles = [ 0 , 180 ]),
tta . Scale ( scales = [ 1 , 2 , 4 ]),
tta . Multiply ( factors = [ 0.9 , 1 , 1.1 ]),
]
)
tta_model = tta . SegmentationTTAWrapper ( model , transforms ) # Example how to process ONE batch on images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is 2D tensor (B, N)
for transformer in transforms : # custom transforms or e.g. tta.aliases.d4_transform()
# augment image
augmented_image = transformer . augment_image ( image )
# pass to model
model_output = model ( augmented_image , another_input_data )
# reverse augmentation for mask and label
deaug_mask = transformer . deaugment_mask ( model_output [ 'mask' ])
deaug_label = transformer . deaugment_label ( model_output [ 'label' ])
# save results
labels . append ( deaug_mask )
masks . append ( deaug_label )
# reduce results as you want, e.g mean/max/min
label = mean ( labels )
mask = mean ( masks )| Transformer | Paramètres | Valeurs |
|---|---|---|
| Horizontalflip | - | - |
| Vertical | - | - |
| Rotation90 | angle | Liste [0, 90, 180, 270] |
| Échelle | Balance interpolation | Liste [float] "le plus proche" / "linéaire" |
| Redimensionner | tailles original_size interpolation | List [tuple [int, int]] Tuple [int, int] "le plus proche" / "linéaire" |
| Ajouter | valeurs | Liste [float] |
| Multiplier | facteurs | Liste [float] |
| Fivecrops | Crop_height Crop_width | int int |
PYPI:
$ pip install ttachSource:
$ pip install git+https://github.com/qubvel/ttachdocker build -f Dockerfile.dev -t ttach:dev . && docker run --rm ttach:dev pytest -p no:cacheprovider