ttach
v0.0.3
การเพิ่มเวลาทดสอบภาพด้วย pytorch!
เช่นเดียวกับการเพิ่มข้อมูลที่ทำกับชุดการฝึกอบรมวัตถุประสงค์ของการเพิ่มเวลาทดสอบคือการปรับเปลี่ยนแบบสุ่มไปยังภาพทดสอบ ดังนั้นแทนที่จะแสดงภาพปกติ“ สะอาด” เพียงครั้งเดียวกับโมเดลที่ผ่านการฝึกอบรมเพียงครั้งเดียวเราจะแสดงภาพที่เพิ่มขึ้นหลายครั้ง จากนั้นเราจะเฉลี่ยการคาดการณ์ของแต่ละภาพที่เกี่ยวข้องและใช้เป็นครั้งสุดท้ายที่เราคาดเดา [1]
Input
| # input batch of images
/ / /| # apply augmentations (flips, rotation, scale, etc.)
| | | | | | | # pass augmented batches through model
| | | | | | | # reverse transformations for each batch of masks/labels
/ / / # merge predictions (mean, max, gmean, etc.)
| # output batch of masks/labels
Output
import ttach as tta
tta_model = tta . SegmentationTTAWrapper ( model , tta . aliases . d4_transform (), merge_mode = 'mean' ) tta_model = tta . ClassificationTTAWrapper ( model , tta . aliases . five_crop_transform ()) tta_model = tta . KeypointsTTAWrapper ( model , tta . aliases . flip_transform (), scaled = True ) หมายเหตุ : รุ่นจะต้องส่งคืนจุดเริ่มต้นในรูปแบบ torch([x1, y1, ..., xn, yn])
# defined 2 * 2 * 3 * 3 = 36 augmentations !
transforms = tta . Compose (
[
tta . HorizontalFlip (),
tta . Rotate90 ( angles = [ 0 , 180 ]),
tta . Scale ( scales = [ 1 , 2 , 4 ]),
tta . Multiply ( factors = [ 0.9 , 1 , 1.1 ]),
]
)
tta_model = tta . SegmentationTTAWrapper ( model , transforms ) # Example how to process ONE batch on images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is 2D tensor (B, N)
for transformer in transforms : # custom transforms or e.g. tta.aliases.d4_transform()
# augment image
augmented_image = transformer . augment_image ( image )
# pass to model
model_output = model ( augmented_image , another_input_data )
# reverse augmentation for mask and label
deaug_mask = transformer . deaugment_mask ( model_output [ 'mask' ])
deaug_label = transformer . deaugment_label ( model_output [ 'label' ])
# save results
labels . append ( deaug_mask )
masks . append ( deaug_label )
# reduce results as you want, e.g mean/max/min
label = mean ( labels )
mask = mean ( masks )| เปลี่ยนรูป | พารามิเตอร์ | ค่า |
|---|---|---|
| แนวนอน | - | - |
| แนวดิ่ง | - | - |
| หมุน 90 | มุม | รายการ [0, 90, 180, 270] |
| มาตราส่วน | เครื่องชั่ง การแก้ไข | รายการ [ลอย] "ใกล้"/"linear" |
| ปรับขนาด | ขนาด Original_size การแก้ไข | รายการ [tuple [int, int]] tuple [int, int] "ใกล้"/"linear" |
| เพิ่ม | ค่า | รายการ [ลอย] |
| คูณ | ปัจจัย | รายการ [ลอย] |
| ความร้อนแรง | crop_height crop_width | int int |
pypi:
$ pip install ttachแหล่งที่มา:
$ pip install git+https://github.com/qubvel/ttachdocker build -f Dockerfile.dev -t ttach:dev . && docker run --rm ttach:dev pytest -p no:cacheprovider