ttach
v0.0.3
Pytorch와의 이미지 테스트 시간 확대!
교육 세트에 대한 데이터 확대가 수행하는 것과 유사하게 테스트 시간 확대의 목적은 테스트 이미지에 대한 무작위 수정을 수행하는 것입니다. 따라서 훈련 된 모델에 한 번만 정기적 인 "깨끗한"이미지를 표시하는 대신, 증강 된 이미지를 여러 번 보여줄 것입니다. 그런 다음 각 해당 이미지의 예측을 평균화하고 최종 추측으로 받아 들일 것입니다 [1].
Input
| # input batch of images
/ / /| # apply augmentations (flips, rotation, scale, etc.)
| | | | | | | # pass augmented batches through model
| | | | | | | # reverse transformations for each batch of masks/labels
/ / / # merge predictions (mean, max, gmean, etc.)
| # output batch of masks/labels
Output
import ttach as tta
tta_model = tta . SegmentationTTAWrapper ( model , tta . aliases . d4_transform (), merge_mode = 'mean' ) tta_model = tta . ClassificationTTAWrapper ( model , tta . aliases . five_crop_transform ()) tta_model = tta . KeypointsTTAWrapper ( model , tta . aliases . flip_transform (), scaled = True ) 참고 : 모델은 torch([x1, y1, ..., xn, yn])
# defined 2 * 2 * 3 * 3 = 36 augmentations !
transforms = tta . Compose (
[
tta . HorizontalFlip (),
tta . Rotate90 ( angles = [ 0 , 180 ]),
tta . Scale ( scales = [ 1 , 2 , 4 ]),
tta . Multiply ( factors = [ 0.9 , 1 , 1.1 ]),
]
)
tta_model = tta . SegmentationTTAWrapper ( model , transforms ) # Example how to process ONE batch on images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is 2D tensor (B, N)
for transformer in transforms : # custom transforms or e.g. tta.aliases.d4_transform()
# augment image
augmented_image = transformer . augment_image ( image )
# pass to model
model_output = model ( augmented_image , another_input_data )
# reverse augmentation for mask and label
deaug_mask = transformer . deaugment_mask ( model_output [ 'mask' ])
deaug_label = transformer . deaugment_label ( model_output [ 'label' ])
# save results
labels . append ( deaug_mask )
masks . append ( deaug_label )
# reduce results as you want, e.g mean/max/min
label = mean ( labels )
mask = mean ( masks )| 변환 | 매개 변수 | 값 |
|---|---|---|
| HORIZONTALFLIP | - | - |
| verticalflip | - | - |
| 회전 90 | 각도 | 나열 [0, 90, 180, 270] |
| 규모 | 저울 보간 | 목록 [float] "가장 가까운"/"선형" |
| 크기를 조정하십시오 | 크기 original_size 보간 | 목록 [튜플 [int, int]] 튜플 [int, int] "가장 가까운"/"선형" |
| 추가하다 | 값 | 목록 [float] |
| 곱하다 | 요인 | 목록 [float] |
| fivecrops | crop_height crop_width | int int |
pypi :
$ pip install ttach원천:
$ pip install git+https://github.com/qubvel/ttachdocker build -f Dockerfile.dev -t ttach:dev . && docker run --rm ttach:dev pytest -p no:cacheprovider