ttach
v0.0.3
¡Aumento del tiempo de prueba de imagen con Pytorch!
Similar a lo que el aumento de datos está haciendo al conjunto de capacitación, el propósito del aumento del tiempo de prueba es realizar modificaciones aleatorias a las imágenes de prueba. Por lo tanto, en lugar de mostrar las imágenes regulares "limpias", solo una vez al modelo entrenado, le mostraremos las imágenes aumentadas varias veces. Luego promediaremos las predicciones de cada imagen correspondiente y la tomaremos como nuestra suposición final [1].
Input
| # input batch of images
/ / /| # apply augmentations (flips, rotation, scale, etc.)
| | | | | | | # pass augmented batches through model
| | | | | | | # reverse transformations for each batch of masks/labels
/ / / # merge predictions (mean, max, gmean, etc.)
| # output batch of masks/labels
Output
import ttach as tta
tta_model = tta . SegmentationTTAWrapper ( model , tta . aliases . d4_transform (), merge_mode = 'mean' ) tta_model = tta . ClassificationTTAWrapper ( model , tta . aliases . five_crop_transform ()) tta_model = tta . KeypointsTTAWrapper ( model , tta . aliases . flip_transform (), scaled = True ) Nota : El modelo debe devolver puntos clave en la torch([x1, y1, ..., xn, yn])
# defined 2 * 2 * 3 * 3 = 36 augmentations !
transforms = tta . Compose (
[
tta . HorizontalFlip (),
tta . Rotate90 ( angles = [ 0 , 180 ]),
tta . Scale ( scales = [ 1 , 2 , 4 ]),
tta . Multiply ( factors = [ 0.9 , 1 , 1.1 ]),
]
)
tta_model = tta . SegmentationTTAWrapper ( model , transforms ) # Example how to process ONE batch on images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is 2D tensor (B, N)
for transformer in transforms : # custom transforms or e.g. tta.aliases.d4_transform()
# augment image
augmented_image = transformer . augment_image ( image )
# pass to model
model_output = model ( augmented_image , another_input_data )
# reverse augmentation for mask and label
deaug_mask = transformer . deaugment_mask ( model_output [ 'mask' ])
deaug_label = transformer . deaugment_label ( model_output [ 'label' ])
# save results
labels . append ( deaug_mask )
masks . append ( deaug_label )
# reduce results as you want, e.g mean/max/min
label = mean ( labels )
mask = mean ( masks )| Transformar | Parámetros | Valores |
|---|---|---|
| Horizontalflip | - | - |
| Vertical | - | - |
| Rotar90 | anglos | Lista [0, 90, 180, 270] |
| Escala | balanza interpolación | Lista [Float] "más cercano"/"lineal" |
| Cambiar de tamaño | tallas Original_size interpolación | Lista [Tuple [int, int]] Tuple [int, int] "más cercano"/"lineal" |
| Agregar | valores | Lista [Float] |
| Multiplicar | factores | Lista [Float] |
| Fivecrops | Crop_Height Crop_width | intencionalmente intencionalmente |
Pypi:
$ pip install ttachFuente:
$ pip install git+https://github.com/qubvel/ttachdocker build -f Dockerfile.dev -t ttach:dev . && docker run --rm ttach:dev pytest -p no:cacheprovider