โดย Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly, Neil Houlsby
อัปเดต 18/06/2021: เราเปิดตัวรุ่น BIT-R50X1 ที่มีประสิทธิภาพสูงใหม่ซึ่งถูกกลั่นจาก BIT-M-R152X2 ดูส่วนนี้ รายละเอียดเพิ่มเติมในบทความของเรา "การกลั่นความรู้: ครูที่ดีคือผู้ป่วยและสม่ำเสมอ"
อัปเดต 08/02/2021: เรายังเปิดตัวรุ่นบิต M ทั้งหมดปรับแต่งบนชุดข้อมูล VTAB-1K ทั้งหมด 19 ชุดดูด้านล่าง
ในที่เก็บนี้เราปล่อยหลายรุ่นจากการถ่ายโอนขนาดใหญ่ (BIT): กระดาษการเรียนรู้การแสดงภาพทั่วไปที่ได้รับการฝึกอบรมล่วงหน้าบนชุดข้อมูล ILSVRC-2012 และ ImagEnet-21K เราให้รหัสเพื่อปรับแต่งโมเดลที่ปล่อยออกมาในกรอบการเรียนรู้ที่สำคัญที่สำคัญ Tensorflow 2, Pytorch และ Jax/Flax
เราหวังว่าชุมชนวิสัยทัศน์คอมพิวเตอร์จะได้รับประโยชน์จากการใช้โมเดลที่ได้รับการฝึกฝนมาก่อน Imagenet-21K ที่ทรงพลังกว่าเมื่อเทียบกับแบบจำลองทั่วไปที่ผ่านการฝึกอบรมไว้ล่วงหน้าในชุดข้อมูล ILSVRC-2012
นอกจากนี้เรายังให้บริการ colabs สำหรับการใช้งานเชิงโต้ตอบมากขึ้น: tensorflow 2 colab, pytorch colab และ jax colab
ตรวจสอบให้แน่ใจว่าคุณมี Python>=3.6 ติดตั้งบนเครื่องของคุณ
ในการตั้งค่า TensorFlow 2, Pytorch หรือ Jax ให้ทำตามคำแนะนำที่ให้ไว้ในที่เก็บข้อมูลที่เกี่ยวข้องที่เชื่อมโยงที่นี่
นอกจากนี้ติดตั้งการพึ่งพา Python โดยใช้งาน (โปรดเลือก tf2 , pytorch หรือ jax ในคำสั่งด้านล่าง):
pip install -r bit_{tf2|pytorch|jax}/requirements.txt
ก่อนอื่นให้ดาวน์โหลดรุ่นบิต เราให้บริการแบบจำลองที่ผ่านการฝึกอบรมล่วงหน้าบน ILSVRC-2012 (BIT-S) หรือ ImagEnet-21K (BIT-M) สำหรับสถาปัตยกรรมที่แตกต่างกัน 5 รายการ: RESNET-50X1, RESNET-101X1, RESNET-50X3, RESNET-101X3 และ RESNET-152X4
ตัวอย่างเช่นหากคุณต้องการดาวน์โหลด RESNET-50X1 ที่ผ่านการฝึกอบรมล่วงหน้าบน ImageNet-21K ให้เรียกใช้คำสั่งต่อไปนี้:
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}
รุ่นอื่น ๆ สามารถดาวน์โหลดได้โดยการเสียบชื่อของรุ่น (bit-s หรือ bit-m) และสถาปัตยกรรมในคำสั่งด้านบน โปรดทราบว่าเราให้แบบจำลองในสองรูปแบบ: npz (สำหรับ Pytorch และ Jax) และ h5 (สำหรับ TF2) โดยค่าเริ่มต้นเราคาดหวังว่าน้ำหนักรุ่นจะถูกเก็บไว้ในโฟลเดอร์รูทของที่เก็บนี้
จากนั้นคุณสามารถใช้การปรับแต่งรุ่นที่ดาวน์โหลดได้อย่างละเอียดในชุดข้อมูลที่คุณสนใจในสามเฟรมเวิร์กใด ๆ เฟรมเวิร์กทั้งหมดแชร์อินเทอร์เฟซบรรทัดคำสั่ง
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
ตอนนี้. เฟรมเวิร์กทั้งหมดจะดาวน์โหลดชุดข้อมูล CIFAR-10 และ CIFAR-100 โดยอัตโนมัติ ชุดข้อมูลสาธารณะหรือแบบกำหนดเองอื่น ๆ สามารถรวมเข้าด้วยกันได้อย่างง่ายดาย: ใน TF2 และ JAX เราพึ่งพาไลบรารีชุดข้อมูล TensorFlow ที่ขยายได้ ใน Pytorch เราใช้ท่อส่งข้อมูลของ Torchvision
โปรดทราบว่ารหัสของเราใช้ GPU ที่มีอยู่ทั้งหมดสำหรับการปรับแต่ง
นอกจากนี้เรายังสนับสนุนการฝึกอบรมในระบอบการปกครองต่ำ: ตัวเลือก --examples_per_class <K> จะสุ่มเลือกตัวอย่าง k ต่อชั้นเรียนสำหรับการฝึกอบรม
หากต้องการดูรายการโดยละเอียดของธงที่มีอยู่ทั้งหมดให้เรียกใช้ python3 -m bit_{pytorch|jax|tf2}.train --help
เพื่อความสะดวกเรามีรุ่น Bit-M ที่ได้รับการปรับแต่งแล้วในชุดข้อมูล ILSVRC-2012 โมเดลสามารถดาวน์โหลดได้โดยการเพิ่ม -ILSVRC2012 postfix เช่น
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz
เราปล่อยสถาปัตยกรรมทั้งหมดที่กล่าวถึงในกระดาษเพื่อให้คุณสามารถเลือกระหว่างความแม่นยำหรือความเร็ว: R50x1, R101x1, R50x3, R101x3, R152x4 ในพา ธ ด้านบนไปยังไฟล์รุ่นเพียงแทนที่ R50x1 ด้วยสถาปัตยกรรมที่คุณเลือก
เราตรวจสอบสถาปัตยกรรมเพิ่มเติมหลังจากการตีพิมพ์ของกระดาษและพบว่า R152x2 มีการแลกเปลี่ยนที่ดีระหว่างความเร็วและความแม่นยำดังนั้นเราจึงรวมสิ่งนี้ไว้ในการเปิดตัวและให้ตัวเลขด้านล่าง
นอกจากนี้เรายังเปิดตัวโมเดลที่ปรับแต่งอย่างละเอียดสำหรับแต่ละงาน 19 งานที่รวมอยู่ในเกณฑ์มาตรฐาน VTAB-1K เราวิ่งแต่ละรุ่นสามครั้งและปล่อยแต่ละการวิ่งเหล่านี้ ซึ่งหมายความว่าเราปล่อยโมเดลทั้งหมด 5x19x3 = 285 และหวังว่าสิ่งเหล่านี้จะเป็นประโยชน์ในการวิเคราะห์การเรียนรู้การถ่ายโอนเพิ่มเติม
ไฟล์สามารถดาวน์โหลดได้ผ่านรูปแบบต่อไปนี้:
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz
เราไม่ได้แปลงโมเดลเหล่านี้เป็น TF2 (ดังนั้นจึงไม่มีไฟล์ .h5 ที่สอดคล้องกัน) อย่างไรก็ตามเรายังอัปโหลดโมเดล TFHUB ซึ่งสามารถใช้ใน TF1 และ TF2 ได้ ลำดับตัวอย่างของคำสั่งสำหรับการดาวน์โหลดหนึ่งโมเดลดังกล่าวคือ:
mkdir BiT-M-R50x1-run0-caltech101.tfhub && cd BiT-M-R50x1-run0-caltech101.tfhub
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}
mkdir variables && cd variables
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}
สำหรับการทำซ้ำสคริปต์การฝึกอบรมของเราใช้พารามิเตอร์ไฮเปอร์ (บิต-ไฮเปอร์) ที่ใช้ในกระดาษต้นฉบับ อย่างไรก็ตามโปรดทราบว่าโมเดลบิตได้รับการฝึกฝนและ finetuned โดยใช้ฮาร์ดแวร์คลาวด์ TPU ดังนั้นสำหรับการตั้งค่า GPU ทั่วไปของพารามิเตอร์ไฮเปอร์เริ่มต้นของเราอาจต้องใช้หน่วยความจำมากเกินไปหรือส่งผลให้ความคืบหน้าช้ามาก ยิ่งไปกว่านั้น Bit-HyperRule ได้รับการออกแบบมาเพื่อทั่วไปในชุดข้อมูลหลายชุดดังนั้นโดยทั่วไปแล้วจึงเป็นไปได้ที่จะกำหนดพารามิเตอร์ไฮเปอร์พารามิเตอร์เฉพาะแอปพลิเคชันที่มีประสิทธิภาพมากขึ้น ดังนั้นเราขอแนะนำให้ผู้ใช้ลองการตั้งค่าน้ำหนักเบามากขึ้นเนื่องจากพวกเขาต้องการทรัพยากรน้อยลงและมักจะส่งผลให้เกิดความแม่นยำคล้ายกัน
ตัวอย่างเช่นเราทดสอบรหัสของเราโดยใช้เครื่อง 8xv100 GPU บนชุดข้อมูล CIFAR-10 และ CIFAR-100 ในขณะที่ลดขนาดแบทช์จาก 512 เป็น 128 และอัตราการเรียนรู้จาก 0.003 เป็น 0.001 การตั้งค่านี้ส่งผลให้ประสิทธิภาพเกือบเหมือนกัน (ดูผลลัพธ์ที่คาดหวังด้านล่าง) เมื่อเปรียบเทียบกับบิต-ไฮเปอร์รูลแม้จะมีความต้องการการคำนวณน้อยลง
ด้านล่างนี้เราให้คำแนะนำเพิ่มเติมเกี่ยวกับวิธีเพิ่มประสิทธิภาพการตั้งค่ากระดาษของเรา
Bit-HyperRule เริ่มต้นได้รับการพัฒนาบนคลาวด์ TPUs และค่อนข้างหิวโหย ส่วนใหญ่เป็นเพราะขนาดแบทช์ขนาดใหญ่ (512) และความละเอียดของภาพ (สูงถึง 480x480) นี่คือเคล็ดลับบางประการหากคุณหมดหน่วยความจำ:
bit_hyperrule.py เราระบุความละเอียดอินพุต ด้วยการลดลงเราสามารถบันทึกหน่วยความจำและการคำนวณจำนวนมากได้ด้วยค่าใช้จ่ายของความแม่นยำ--batch_split ตัวเลือก ตัวอย่างเช่นการรันการปรับแต่งด้วย --batch_split 8 ลดความต้องการหน่วยความจำโดยปัจจัย 8 เราตรวจสอบว่าเมื่อใช้ bit-hyperrule รหัสในที่เก็บนี้จะทำซ้ำผลลัพธ์ของกระดาษ
สำหรับเกณฑ์มาตรฐานทั่วไปเหล่านี้การเปลี่ยนแปลงดังกล่าวไปยังบิต-เส้นผม ( --batch 128 --base_lr 0.001 ) นำไปสู่ผลลัพธ์ต่อไปนี้ผลลัพธ์ที่คล้ายกันมาก ตารางแสดงค่าเฉลี่ย← ค่ามัธยฐาน →สูงสุดของการวิ่งอย่างน้อยห้าครั้ง หมายเหตุ : นี่ไม่ใช่การเปรียบเทียบเฟรมเวิร์กเพียงหลักฐานว่าฐานรหัสทั้งหมดสามารถเชื่อถือได้ในการทำซ้ำผลลัพธ์
| ชุดข้อมูล | อดีต/CLS | TF2 | คนขี้ขลาด | pytorch |
|---|---|---|---|---|
| CIFAR10 | 1 | 52.5 ← 55.8 → 60.2 | 48.7 ← 53.9 → 65.0 | 56.4 ← 56.7 → 73.1 |
| CIFAR10 | 5 | 85.3 ← 87.2 → 89.1 | 80.2 ← 85.8 → 88.6 | 84.8 ← 85.8 → 89.6 |
| CIFAR10 | เต็ม | 98.5 | 98.4 | 98.5 ← 98.6 → 98.6 |
| CIFAR100 | 1 | 34.8 ← 35.7 → 37.9 | 32.1 ← 35.0 → 37.1 | 31.6 ← 33.8 → 36.9 |
| CIFAR100 | 5 | 68.8 ← 70.4 → 71.4 | 68.6 ← 70.8 → 71.6 | 70.6 ← 71.6 → 71.7 |
| CIFAR100 | เต็ม | 90.8 | 91.2 | 91.1 ← 91.2 → 91.4 |
| ชุดข้อมูล | อดีต/CLS | คนขี้ขลาด | pytorch |
|---|---|---|---|
| CIFAR10 | 1 | 44.0 ← 56.7 → 65.0 | 50.9 ← 55.5 → 59.5 |
| CIFAR10 | 5 | 85.3 ← 87.0 → 88.2 | 85.3 ← 85.8 → 88.6 |
| CIFAR10 | เต็ม | 98.5 | 98.5 ← 98.5 → 98.6 |
| CIFAR100 | 1 | 36.4 ← 37.2 → 38.9 | 34.3 ← 36.8 → 39.0 |
| CIFAR100 | 5 | 69.3 ← 70.5 → 72.0 | 70.3 ← 72.0 → 72.3 |
| CIFAR100 | เต็ม | 91.2 | 91.2 ← 91.3 → 91.4 |
(รุ่น TF2 ยังไม่พร้อมใช้งาน)
| ชุดข้อมูล | อดีต/CLS | TF2 | คนขี้ขลาด | pytorch |
|---|---|---|---|---|
| CIFAR10 | 1 | 49.9 ← 54.4 → 60.2 | 48.4 ← 54.1 → 66.1 | 45.8 ← 57.9 → 65.7 |
| CIFAR10 | 5 | 80.8 ← 83.3 → 85.5 | 76.7 ← 82.4 → 85.4 | 80.3 ← 82.3 → 84.9 |
| CIFAR10 | เต็ม | 97.2 | 97.3 | 97.4 |
| CIFAR100 | 1 | 35.3 ← 37.1 → 38.2 | 32.0 ← 35.2 → 37.8 | 34.6 ← 35.2 → 38.6 |
| CIFAR100 | 5 | 63.8 ← 65.0 → 66.5 | 63.4 ← 64.8 → 66.5 | 64.7 ← 65.5 → 66.0 |
| CIFAR100 | เต็ม | 86.5 | 86.4 | 86.6 |
ผลลัพธ์เหล่านี้ได้รับโดยใช้บิต-ไฮเปอร์ อย่างไรก็ตามเนื่องจากสิ่งนี้ส่งผลให้เกิดความละเอียดขนาดใหญ่และความละเอียดขนาดใหญ่หน่วยความจำจึงเป็นปัญหา รหัส pytorch รองรับการแยกแบทช์และด้วยเหตุนี้เรายังสามารถเรียกใช้สิ่งที่อยู่ที่นั่นได้โดยไม่ต้องหันไปใช้คลาวด์ TPUs โดยการเพิ่มคำสั่ง --batch_split N โดยที่ N คือพลังของสอง ตัวอย่างเช่นคำสั่งต่อไปนี้สร้างความแม่นยำในการตรวจสอบความถูกต้องของ 80.68 บนเครื่องที่มี 8 V100 GPU:
python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4
เพิ่มขึ้นเป็น --batch_split 8 เมื่อทำงานด้วย 4 V100 GPU ฯลฯ
ผลลัพธ์เต็มรูปแบบที่ประสบความสำเร็จในการทดสอบบางอย่างคือ:
| อดีต/CLS | R50X1 | R152x2 | R101x3 |
|---|---|---|---|
| 1 | 18.36 | 24.5 | 25.55 |
| 5 | 50.64 | 64.5 | 64.18 |
| เต็ม | 80.68 | 85.15 | เช็ด |
สิ่งเหล่านี้เป็นการเรียกใช้ใหม่และไม่ใช่แบบจำลองกระดาษที่แน่นอน คะแนน VTAB ที่คาดหวังสำหรับสองรุ่นคือ:
| แบบอย่าง | เต็ม | เป็นธรรมชาติ | ที่มีโครงสร้าง | ที่มีความเชี่ยวชาญ |
|---|---|---|---|---|
| Bit-M-R152x4 | 73.51 | 80.77 | 61.08 | 85.67 |
| Bit-M-R101x3 | 72.65 | 80.29 | 59.40 | 85.75 |
ในภาคผนวก G ของกระดาษของเราเราตรวจสอบว่าบิตปรับปรุงความทนทานนอกบริบทหรือไม่ ในการทำเช่นนี้เราได้สร้างชุดข้อมูลที่ประกอบด้วยวัตถุเบื้องหน้าที่สอดคล้องกับ 21 ILSVRC-2012 คลาสที่วางลงบนพื้นหลังเบ็ดเตล็ด 41
หากต้องการดาวน์โหลดชุดข้อมูลให้เรียกใช้
wget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip
รูปภาพจากแต่ละคลาส 21 คลาสจะถูกเก็บไว้ในไดเรกทอรีที่มีชื่อของชั้นเรียน
เราปล่อยโมเดลบิตบีบอัดที่มีประสิทธิภาพสูงสุดจากกระดาษ "การกลั่นความรู้: ครูที่ดีมีความอดทนและสม่ำเสมอ" ในการกลั่นความรู้ โดยเฉพาะอย่างยิ่งเรากลั่นโมเดล Bit-M-R152x2 (ซึ่งได้รับการฝึกอบรมล่วงหน้าเกี่ยวกับ ImageNet-21K) เป็นรุ่น Bit-R50x1 เป็นผลให้เราได้รับโมเดลขนาดกะทัดรัดที่มีประสิทธิภาพการแข่งขันมาก
| แบบอย่าง | ลิงค์ดาวน์โหลด | ปณิธาน | ImageNet Top-1 ACC (กระดาษ) |
|---|---|---|---|
| Bit-R50x1 | การเชื่อมโยง | 224 | 82.8 |
| Bit-R50x1 | การเชื่อมโยง | 160 | 80.5 |
สำหรับการทำซ้ำเรายังปล่อยน้ำหนักของสอง Bit-M-R152x2 รุ่นครู: ได้รับการแก้ไขที่ความละเอียด 224 และความละเอียด 384 ดูกระดาษเพื่อดูรายละเอียดเกี่ยวกับวิธีการใช้ครูเหล่านี้
เราไม่มีแผนการที่เป็นรูปธรรมสำหรับการเผยแพร่รหัสการกลั่นเนื่องจากสูตรนั้นง่ายและเราจินตนาการว่าคนส่วนใหญ่จะรวมเข้ากับรหัสฝึกอบรมที่มีอยู่ อย่างไรก็ตาม Sayak Paul ได้ทำการตั้งค่าการกลั่นใน TensorFlow อีกครั้งและเกือบจะทำซ้ำผลลัพธ์ของเราในการตั้งค่าหลายครั้ง