การใช้งาน Pytorch Biggan อย่างเป็นทางการของผู้เขียนอย่างเป็นทางการ

repo นี้มีรหัสสำหรับการฝึกอบรม GPU 4-8 ของ Biggans จากการฝึกอบรม Gan ขนาดใหญ่สำหรับการสังเคราะห์ภาพธรรมชาติที่มีความเที่ยงตรงสูงโดย Andrew Brock, Jeff Donahue และ Karen Simonyan
รหัสนี้คือโดย Andy Brock และ Alex Andonian
คุณจะต้อง:
ก่อนอื่นคุณอาจเลือกเตรียมชุดข้อมูลเป้าหมาย HDF5 ที่ผ่านการประมวลผลไว้ล่วงหน้าสำหรับ I/O ที่เร็วขึ้น การติดตามสิ่งนี้ (หรือไม่) คุณจะต้องมีช่วงเวลาเริ่มต้นที่จำเป็นในการคำนวณ FID สิ่งเหล่านี้สามารถทำได้โดยการปรับเปลี่ยนและทำงาน
sh scripts/utils/prepare_data.sh ซึ่งโดยค่าเริ่มต้นจะถือว่าชุดการฝึกอบรม Imagenet ของคุณจะถูกดาวน์โหลดลงใน data โฟลเดอร์รูทในไดเรกทอรีนี้และจะเตรียม HDF5 ที่แคชไว้ที่ความละเอียด 128x128 พิกเซล
ในโฟลเดอร์สคริปต์มีสคริปต์ Bash หลายตัวซึ่งจะฝึกอบรม Biggans ที่มีขนาดแบทช์ที่แตกต่างกัน รหัสนี้จะถือว่าคุณไม่สามารถเข้าถึงฝัก TPU เต็มรูปแบบและการปลอมแปลงแบทช์ขนาดใหญ่โดยใช้การสะสมการไล่ระดับสี โดยค่าเริ่มต้นสคริปต์ launch_BigGAN_bs256x8.sh ฝึกอบรมโมเดล Biggan ขนาดเต็มขนาดเต็มขนาด 256 และ 8 การสะสมการไล่ระดับสีสำหรับขนาดแบทช์ทั้งหมด 2048 ใน 8xv100 พร้อมการฝึกอบรมเต็มความแม่นยำ (ไม่มีแกนเทนเซอร์)
ก่อนอื่นคุณจะต้องหาขนาดแบทช์สูงสุดการตั้งค่าของคุณสามารถรองรับได้ แบบจำลองที่ผ่านการฝึกอบรมมาก่อนที่จะได้รับการฝึกฝนใน 8xv100 (16GB VRAM แต่ละอัน) ซึ่งสามารถรองรับได้มากกว่า BS256 ที่ใช้โดยค่าเริ่มต้นเล็กน้อย เมื่อคุณพิจารณาสิ่งนี้แล้วคุณควรแก้ไขสคริปต์เพื่อให้ขนาดแบทช์คูณจำนวนการสะสมการไล่ระดับสีเท่ากับขนาดแบทช์ทั้งหมดที่คุณต้องการ (ค่าเริ่มต้น Biggan ถึง 2048)
โปรดทราบด้วยว่าสคริปต์นี้ใช้ --load_in_mem arg ซึ่งโหลดไฟล์ทั้งหมด (~ 64GB) i128.hdf5 ทั้งหมดลงใน RAM เพื่อโหลดข้อมูลที่เร็วขึ้น หากคุณมี RAM ไม่เพียงพอที่จะสนับสนุนสิ่งนี้ (อาจเป็น 96GB+) ให้ลบอาร์กิวเมนต์นี้ออก

ในระหว่างการฝึกอบรมสคริปต์นี้จะส่งออกบันทึกด้วยตัวชี้วัดการฝึกอบรมและตัวชี้วัดการทดสอบจะช่วยประหยัดสำเนาหลายชุด (2 ครั้งล่าสุดและ 5 คะแนนสูงสุด) ของพารามิเตอร์น้ำหนัก/เครื่องมือเพิ่มประสิทธิภาพและจะผลิตตัวอย่างและการแก้ไขทุกครั้งที่ประหยัดน้ำหนัก โฟลเดอร์บันทึกมีสคริปต์เพื่อประมวลผลบันทึกเหล่านี้และพล็อตผลลัพธ์โดยใช้ MATLAB (ขออภัยไม่ขอโทษ)
หลังจากการฝึกอบรมเราสามารถใช้ sample.py เพื่อสร้างตัวอย่างเพิ่มเติมและการแก้ไขทดสอบด้วยค่าการตัดทอนที่แตกต่างกันขนาดแบทช์จำนวนการสะสมสถิติยืน ฯลฯ ดูตัวอย่างสคริปต์ sample_BigGAN_bs256x8.sh สำหรับตัวอย่าง
โดยค่าเริ่มต้นทุกอย่างจะถูกบันทึกไว้ในน้ำหนัก/ตัวอย่าง/บันทึก/โฟลเดอร์ข้อมูลซึ่งสันนิษฐานว่าอยู่ในโฟลเดอร์เดียวกับ repo นี้ คุณสามารถชี้สิ่งเหล่านี้ทั้งหมดไปยังโฟลเดอร์พื้นฐานที่แตกต่างกันโดยใช้อาร์กิวเมนต์ --base_root หรือเลือกตำแหน่งเฉพาะสำหรับแต่ละรายการด้วยอาร์กิวเมนต์ที่เกี่ยวข้อง (เช่น --logs_root )
เรารวมสคริปต์เพื่อเรียกใช้ Biggan-Deep แต่เรายังไม่ได้ฝึกฝนแบบจำลองที่ใช้พวกเขาอย่างเต็มที่ดังนั้นให้พิจารณาพวกเขาที่ยังไม่ทดลอง นอกจากนี้เรายังรวมสคริปต์เพื่อเรียกใช้โมเดลบน CIFAR และเพื่อเรียกใช้ Sa-Gan (กับ EMA) และ SN-GAN บน Imagenet รหัส SA-GAN จะถือว่าคุณมี 4xtitanx (หรือเทียบเท่าในแง่ของ GPU RAM) และจะทำงานด้วยขนาดแบทช์ 128 และ 2 การสะสมการไล่ระดับสี
repo นี้ใช้เครือข่าย Pytorch In Inception ที่สร้างขึ้นเพื่อคำนวณคือและ FID คะแนนเหล่านี้แตกต่างจากคะแนนที่คุณจะได้รับการใช้รหัสการลงทะเบียน TF อย่างเป็นทางการและมีไว้เพื่อการตรวจสอบเท่านั้น! เรียกใช้ sample.py บนโมเดลของคุณด้วยอาร์กิวเมนต์ --sample_npz จากนั้นเรียกใช้ inception_tf13 เพื่อคำนวณ tensorflow จริงคือ โปรดทราบว่าคุณจะต้องมี TensorFlow 1.3 หรือติดตั้งก่อนหน้านี้เนื่องจาก TF1.4+ Breaks ต้นฉบับคือรหัส
เรารวมจุดตรวจสอบแบบจำลองสองแบบ (พร้อม G, D, สำเนา EMA ของ G, Optimizers และ State DICT):
แบบจำลองที่ได้รับการฝึกฝนสำหรับ Places-365 เร็ว ๆ นี้
repo นี้ยังมีสคริปต์สำหรับการพอร์ตน้ำหนักตัวกำเนิด TFHUB Biggan ดั้งเดิมไปยัง Pytorch ดูสคริปต์ในโฟลเดอร์ TFHUB สำหรับรายละเอียดเพิ่มเติม

หากคุณต้องการกลับมาฝึกซ้อมที่ถูกขัดจังหวะหรือปรับแต่งโมเดลที่ผ่านการฝึกอบรมมาก่อนให้เรียกใช้สคริปต์การเปิดตัวเดียวกัน แต่ด้วยการเพิ่มอาร์กิวเมนต์ --resume ความสัมพันธ์ ชื่อการทดลองจะถูกสร้างขึ้นโดยอัตโนมัติจากการกำหนดค่า แต่สามารถแทนที่ได้โดยใช้ --experiment_name arg (ตัวอย่างเช่นหากคุณต้องการปรับแต่งโมเดลโดยใช้การตั้งค่าเครื่องมือเพิ่มประสิทธิภาพที่แก้ไขแล้ว)
ในการเตรียมชุดข้อมูลของคุณเองคุณจะต้องเพิ่มลงใน Datasets.py และแก้ไขคำสั่งความสะดวกสบายใน utils.py (dset_dict, imsize_dict, root_dict, nclass_dict, classes_per_sheet_dict) เพื่อให้มีเมตาดาต้าที่เหมาะสมสำหรับข้อมูลของคุณ ทำซ้ำกระบวนการใน PREPAL_DATA.SH (เป็นทางเลือกผลิตสำเนาที่ประมวลผลล่วงหน้า HDF5 และคำนวณช่วงเวลาเริ่มต้นสำหรับ FID)
โดยค่าเริ่มต้นสคริปต์การฝึกอบรมจะบันทึกจุดตรวจที่ดีที่สุด 5 อันดับแรกซึ่งวัดจากคะแนนเริ่มต้น สำหรับชุดข้อมูลอื่นนอกเหนือจาก Imagenet คะแนนการลงทะเบียนอาจเป็นตัวชี้วัดคุณภาพที่แย่มากดังนั้นคุณอาจต้องการใช้ --which_best FID แทน
หากต้องการใช้ฟังก์ชั่นการฝึกอบรมของคุณเอง (เช่นฝึก Bigvae): แก้ไข train_fns.gan_training_function หรือเพิ่มรถไฟขบวนใหม่ FN และเพิ่มหลังจาก if config['which_train_fn'] == 'GAN': line in train.py
--num_G_SVs รหัสนี้ได้รับการออกแบบจากพื้นดินเพื่อใช้เป็นฐานที่ขยายได้และแฮ็กได้สำหรับรหัสการวิจัยเพิ่มเติม เราได้คิดมากเป็นจำนวนมากในการทำให้แน่ใจว่านามธรรมนั้นมีความหนา ที่เหมาะสม สำหรับการวิจัย-ไม่หนาพอที่จะยอมรับไม่ได้ แต่ไม่ผอมเท่าที่ไร้ประโยชน์ แนวคิดหลักคือถ้าคุณต้องการทดสอบกับการตั้งค่า SOTA และทำการปรับเปลี่ยน (ลองใช้ฟังก์ชั่นการสูญเสียใหม่สถาปัตยกรรมบล็อกการใส่ใจในตัวเอง ฯลฯ ) คุณควรทำอย่างง่ายดายเพียงแค่วางรหัสของคุณในหนึ่งหรือสองสถานที่โดยไม่ต้องกังวลเกี่ยวกับส่วนที่เหลือของ codebase สิ่งต่าง ๆ เช่นการใช้ตัวเอง which_conv และ functools.partial ในคำจำกัดความโมเดล biggan.py ถูกนำมารวมกันในใจเช่นเดียวกับการออกแบบของการสืบทอดระดับสเปกตรัมระดับ
ด้วยที่กล่าวว่านี่เป็นรหัสฐานที่ค่อนข้างใหญ่สำหรับโครงการเดียว ในขณะที่เราพยายามอย่างถี่ถ้วนกับความคิดเห็นหากมีบางสิ่งที่คุณคิดว่าอาจชัดเจนขึ้นเขียนดีขึ้นหรือ refactored ดีกว่าโปรดอย่าลังเลที่จะยกปัญหาหรือคำขอดึง
ต้องการทำงานหรือปรับปรุงรหัสนี้? มีสองสิ่งที่ repo นี้จะได้รับประโยชน์จาก แต่ยังไม่ได้ผล
ดูไดเรกทอรีนี้สำหรับฉลาก Imagenet
หากคุณใช้รหัสนี้โปรดอ้างอิง
@inproceedings{
brock2018large,
title={Large Scale {GAN} Training for High Fidelity Natural Image Synthesis},
author={Andrew Brock and Jeff Donahue and Karen Simonyan},
booktitle={International Conference on Learning Representations},
year={2019},
url={https://openreview.net/forum?id=B1xsqj09Fm},
}
ขอบคุณ Google สำหรับการบริจาคเครดิตคลาวด์ที่ใจกว้าง
Syncbn โดย Jiayuan Mao และ Tete Xiao
แถบความคืบหน้ามาจาก Jan Schlüter
Test Metrics Logger จาก VoxNet
การใช้ Pytorch ของ COV จาก Modar M. Alfadly
Pytorch Fast Matrix SQRT สำหรับ FID จาก Tsung-yu Lin และ Subhransu Maji
รหัสคะแนนการลงทะเบียนเรียน Tensorflow จาก Openai ที่ได้รับการปรับปรุง-GAN