OML เป็นเฟรมเวิร์กที่ใช้ Pytorch ในการฝึกอบรมและตรวจสอบความถูกต้องของโมเดลที่ผลิตการฝังที่มีคุณภาพสูง
มีผู้คนจำนวนมากจากมหาวิทยาลัยออกซ์ฟอร์ดและ HSE ที่ใช้ OML ในวิทยานิพนธ์ของพวกเขา [1] [2] [3]
การอัปเดตมุ่งเน้นไปที่องค์ประกอบหลายอย่าง:
เราเพิ่มข้อความสนับสนุน "อย่างเป็นทางการ" และตัวอย่างหลามที่เกี่ยวข้อง (หมายเหตุการสนับสนุนข้อความในท่อยังไม่รองรับ)
เราแนะนำคลาส RetrievalResults ( RR ) - คอนเทนเนอร์เพื่อจัดเก็บรายการแกลเลอรี่ที่เรียกคืนสำหรับการสืบค้นที่กำหนด RR เป็นวิธีที่รวมกันในการแสดงภาพการคาดการณ์และการคำนวณตัวชี้วัด (หากเป็นที่รู้จักความจริงภาคพื้นดิน) นอกจากนี้ยังทำให้การประมวลผลหลังการประมวลผลง่ายขึ้นซึ่งวัตถุ RR ถูกนำมาเป็นอินพุตและ RR_upd อื่นที่ผลิตเป็นเอาต์พุต การมีวัตถุทั้งสองนี้ช่วยให้ผลลัพธ์การดึงข้อมูลการเปรียบเทียบด้วยสายตาหรือโดยตัวชี้วัด นอกจากนี้คุณสามารถสร้างห่วงโซ่ของโพสต์โปรเซสเซอร์ดังกล่าวได้อย่างง่ายดาย
RR เป็นหน่วยความจำที่ได้รับการปรับให้เหมาะสมเนื่องจากการใช้แบตช์: กล่าวอีกนัยหนึ่งมันไม่ได้เก็บเมทริกซ์เต็มของระยะทางคิวรี-แกลเลอรี่ (มันไม่ได้ทำให้การค้นหาโดยประมาณ) เราสร้าง Model และ Dataset เป็นคลาสเดียวที่รับผิดชอบในการประมวลผลตรรกะเฉพาะของโมดอล Model มีหน้าที่ในการตีความขนาดอินพุต: ตัวอย่างเช่น BxCxHxW สำหรับรูปภาพหรือ BxLxD สำหรับลำดับเช่นข้อความ Dataset มีหน้าที่เตรียมรายการ: อาจใช้ Transforms สำหรับรูปภาพหรือ Tokenizer สำหรับข้อความ ฟังก์ชั่นการคำนวณตัวชี้วัดเช่น calc_retrieval_metrics_rr , RetrievalResults , PairwiseReranker และคลาสและฟังก์ชั่นอื่น ๆ จะรวมเป็นหนึ่งเดียวเพื่อทำงานกับวิธีใด ๆ
IVisualizableDataset มีวิธีการ .visaulize() ที่แสดงรายการเดียว หากนำไปใช้งาน RetrievalResults สามารถแสดงเค้าโครงของผลลัพธ์ที่ดึงมาได้ วิธีที่ง่ายที่สุดในการติดตามการเปลี่ยนแปลงคือการอ่านตัวอย่างอีกครั้ง!
วิธีการตรวจสอบที่แนะนำคือการใช้ RetrievalResults และฟังก์ชั่นเช่น calc_retrieval_metrics_rr , calc_fnmr_at_fmr_rr และอื่น ๆ คลาส EmbeddingMetrics ถูกเก็บไว้เพื่อใช้กับ Pytorch Lightning และภายในท่อ หมายเหตุลายเซ็นของวิธีการ EmbeddingMetrics มีการเปลี่ยนแปลงเล็กน้อยดูตัวอย่างสายฟ้าสำหรับสิ่งนั้น
เนื่องจากตรรกะเฉพาะรูปแบบถูก จำกัด อยู่ที่ Dataset จึงไม่เอา PATHS_KEY , X1_KEY , X2_KEY , Y1_KEY และ Y2_KEY อีกต่อไป คีย์ที่ไม่เฉพาะรูปแบบเช่น LABELS_KEY , IS_GALLERY , IS_QUERY_KEY , CATEGORIES_KEY ยังคงใช้งานอยู่
inference_on_images ตอนนี้ inference และทำงานกับรูปแบบใด ๆ
อินเทอร์เฟซที่เปลี่ยนแปลงเล็กน้อยของ Datasets. ตัวอย่างเช่นเรามี IQueryGalleryDataset และ IQueryGalleryLabeledDataset อินเตอร์เฟส สิ่งแรกจะต้องใช้สำหรับการอนุมานอันที่สองสำหรับการตรวจสอบความถูกต้อง นอกจากนี้ยังเพิ่มอินเทอร์เฟซ IVisualizableDataset
ลบภายในบางส่วนเช่น IMetricDDP , EmbeddingMetricsDDP , calc_distance_matrix , calc_gt_mask , calc_mask_to_ignore , apply_mask_to_ignore การเปลี่ยนแปลงเหล่านี้ไม่ควรส่งผลกระทบต่อคุณ นอกจากนี้ยังลบรหัสที่เกี่ยวข้องกับไปป์ไลน์ที่มี triplets ล่วงหน้า
การแยกคุณสมบัติ: ไม่มีการเปลี่ยนแปลงยกเว้นการเพิ่มอาร์กิวเมนต์เสริม - mode_for_checkpointing = (min | max) มันอาจจะเป็นประโยชน์ในการสลับระหว่าง ส่วนล่าง, ดีกว่า และ ยิ่งใหญ่กว่าประเภทของตัวชี้วัดที่ดีกว่า
Paipeline Pairwise-Processing: เปลี่ยนชื่อและอาร์กิวเมนต์เล็กน้อยของการกำหนดค่า postprocessor sub- pairwise_images ตอนนี้เป็น pairwise_reranker และไม่ต้องการการแปลง
คุณอาจคิดว่า "ถ้าฉันต้องการภาพฝังตัวฉันก็สามารถฝึกอบรมตัวแยกประเภทวานิลลาและใช้เลเยอร์สุดท้าย" มันสมเหตุสมผลแล้วเป็นจุดเริ่มต้น แต่มีข้อเสียที่เป็นไปได้หลายประการ:
หากคุณต้องการใช้ Embeddings เพื่อทำการค้นหาคุณต้องคำนวณระยะห่างระหว่างพวกเขา (ตัวอย่างเช่น Cosine หรือ L2) โดยปกติ คุณจะไม่ปรับระยะทางเหล่านี้ให้เหมาะสมระหว่างการฝึกอบรมในการตั้งค่าการจำแนกประเภทโดยตรง ดังนั้นคุณสามารถหวังได้ว่าการฝังตัวสุดท้ายจะมีคุณสมบัติที่ต้องการ
ปัญหาที่สองคือกระบวนการตรวจสอบความถูกต้อง ในการตั้งค่าการค้นหาคุณมักจะใส่ใจว่าเอาต์พุต -N อันดับต้น ๆ ของคุณเกี่ยวข้องกับการสืบค้นอย่างไร วิธีที่เป็นธรรมชาติในการประเมินแบบจำลองคือการจำลองคำขอค้นหาชุดอ้างอิงและใช้หนึ่งในตัวชี้วัดการดึงข้อมูล ดังนั้นจึงไม่มีการรับประกันว่าความแม่นยำในการจำแนกประเภทจะสัมพันธ์กับตัวชี้วัดเหล่านี้
ในที่สุดคุณอาจต้องการใช้ไปป์ไลน์การเรียนรู้แบบเมตริกด้วยตัวเอง มีงานจำนวนมาก : ในการใช้การสูญเสียแฝดคุณต้องสร้างแบทช์ด้วยวิธีที่เฉพาะเจาะจงใช้การทำเหมือง triplets ประเภทต่าง ๆ การติดตามระยะทาง ฯลฯ สำหรับการตรวจสอบความถูกต้องคุณต้องใช้ตัวชี้วัดการดึงซึ่งรวมถึงการสะสมการฝังตัวที่มีประสิทธิภาพ คุณอาจต้องการเห็นภาพคำขอค้นหาของคุณโดยเน้นผลการค้นหาที่ดีและไม่ดี แทนที่จะทำด้วยตัวเองคุณสามารถใช้ OML เพื่อจุดประสงค์ของคุณ
PML เป็นห้องสมุดยอดนิยมสำหรับการเรียนรู้แบบเมตริกและมีคอลเลกชันที่หลากหลายของการสูญเสียคนงานเหมืองระยะทางและตัวลด; นั่นคือเหตุผลที่เราให้ตัวอย่างที่ตรงไปตรงมาของการใช้กับ OML ในขั้นต้นเราพยายามใช้ PML แต่ในที่สุดเราก็มาพร้อมกับห้องสมุดของเราซึ่งเป็นท่อ / สูตรอาหารมากขึ้น นั่นคือวิธีที่ OML แตกต่างจาก PML:
OML มีท่อที่อนุญาตให้โมเดลการฝึกอบรมโดยการเตรียมการกำหนดค่าและข้อมูลของคุณในรูปแบบที่ต้องการ (เหมือนกับการแปลงข้อมูลเป็นรูปแบบ Coco เพื่อฝึกเครื่องตรวจจับจาก MMDetection)
OML มุ่งเน้นไปที่ท่อส่งข้อมูลแบบ end-to-end และกรณีการใช้งานจริง มันมีตัวอย่างตามการกำหนดค่าเกี่ยวกับเกณฑ์มาตรฐานที่เป็นที่นิยมใกล้เคียงกับชีวิตจริง (เช่นภาพถ่ายของผลิตภัณฑ์ของ ID นับพัน) เราพบการผสมผสานที่ดีของพารามิเตอร์ที่ดีในชุดข้อมูลเหล่านี้โมเดลที่ผ่านการฝึกอบรมและเผยแพร่และการกำหนดค่าของพวกเขา ดังนั้นมันจึงทำให้ OML มีสูตรมากกว่า PML และผู้เขียนยืนยันว่าคำกล่าวนี้ว่าห้องสมุดของเขาเป็นชุดเครื่องมือแทนที่จะเป็นสูตรอาหารนอกจากนี้ตัวอย่างใน PML ส่วนใหญ่เป็นชุดข้อมูล CIFAR และ MNIST
OML มีสวนสัตว์ของรุ่นที่ผ่านการฝึกอบรมซึ่งสามารถเข้าถึงได้ง่ายจากรหัสในลักษณะเดียวกับใน torchvision (เมื่อคุณพิมพ์ resnet50(pretrained=True) )
OML ถูกรวมเข้ากับ Pytorch Lightning ดังนั้นเราสามารถใช้พลังของผู้ฝึกสอนได้ สิ่งนี้มีประโยชน์อย่างยิ่งเมื่อเราทำงานกับ DDP ดังนั้นคุณจะเปรียบเทียบตัวอย่าง DDP ของเราและ PMLS หนึ่ง โดยวิธีการที่ PML ยังมีผู้ฝึกสอน แต่ก็ไม่ได้ใช้กันอย่างแพร่หลายในตัวอย่างและใช้ฟังก์ชั่น train / test ที่กำหนดเองแทน
เราเชื่อว่าการมีท่อ, ตัวอย่างที่ไร้เดียงสาและสวนสัตว์ของแบบจำลองที่ผ่านการฝึกอบรมทำให้เกณฑ์รายการเป็นค่าต่ำจริงๆ
ปัญหาการเรียนรู้แบบเมตริก (หรือที่เรียกว่าปัญหา การจำแนกประเภทสุดขีด ) หมายถึงสถานการณ์ที่เรามี ID หลายพันรายการของบางหน่วยงาน แต่มีเพียงไม่กี่ตัวอย่างสำหรับทุกหน่วยงาน บ่อยครั้งที่เราสมมติว่าในระหว่างขั้นตอนการทดสอบ (หรือการผลิต) เราจะจัดการกับหน่วยงานที่มองไม่เห็นซึ่งทำให้เป็นไปไม่ได้ที่จะใช้ไปป์ไลน์การจำแนกวานิลลาโดยตรง ในหลายกรณีที่ได้รับการฝังจะใช้ในการดำเนินการค้นหาหรือการจับคู่ขั้นตอนเหนือพวกเขา
นี่คือตัวอย่างบางส่วนของงานดังกล่าวจาก Sphere Vision Computer:
embedding - เอาต์พุตของโมเดล (หรือที่เรียกว่า features vector หรือ descriptor )query - ตัวอย่างที่ใช้เป็นคำขอในขั้นตอนการดึงข้อมูลgallery set - ชุดของเอนทิตีเพื่อค้นหารายการที่คล้ายกับ query (หรือที่เรียกว่า reference หรือ index )Sampler - อาร์กิวเมนต์สำหรับ DataLoader ซึ่งใช้ในการสร้างแบทช์Miner - วัตถุในการสร้างคู่หรือสามเท่าหลังจากแบทช์ถูกสร้างขึ้นโดย Sampler ไม่จำเป็นต้องสร้างการรวมกันของตัวอย่างภายในชุดปัจจุบันเท่านั้นดังนั้นธนาคารหน่วยความจำอาจเป็นส่วนหนึ่งของ MinerSamples / Labels / Instances - เป็นตัวอย่างลองพิจารณาชุดข้อมูล DeepFashion มันมีรหัสรายการแฟชั่นหลายพันรายการ (เราตั้งชื่อ labels ) และรูปภาพหลายรูปสำหรับแต่ละรายการ ID (เราตั้งชื่อรูปถ่ายแต่ละภาพเป็น instance หรือ sample ) รหัสไอเท็มแฟชั่นทั้งหมดมีกลุ่มของพวกเขาเช่น "กระโปรง", "แจ็คเก็ต", "กางเกงขาสั้น" และอื่น ๆ (เราตั้งชื่อพวกเขา categories ) หมายเหตุเราหลีกเลี่ยงการใช้ class คำเพื่อหลีกเลี่ยงความเข้าใจผิดtraining epoch - ตัวอย่างแบตช์ที่เราใช้สำหรับการสูญเสียที่ใช้การรวมกันมักจะมีความยาวเท่ากับ [number of labels in training dataset] / [numbers of labels in one batch] หมายความว่าเราไม่ได้สังเกตตัวอย่างการฝึกอบรมทั้งหมดที่มีอยู่ในยุคเดียว (ตรงข้ามกับการจำแนกประเภทวานิลลา) แต่เราสังเกตเห็นป้ายกำกับทั้งหมดที่มีอยู่มันอาจจะเทียบได้กับวิธี SOTA (2022 ปี) ปัจจุบันตัวอย่างเช่น HYP-VIT (ไม่กี่คำเกี่ยวกับวิธีการนี้: มันเป็นสถาปัตยกรรม VIT ที่ได้รับการฝึกฝนด้วยการสูญเสียความคมชัด แต่การฝังตัวถูกฉายเข้าไปในพื้นที่ไฮเปอร์โบลิกบางอย่างตามที่ผู้เขียนอ้างว่าพื้นที่ดังกล่าวสามารถอธิบายโครงสร้างที่ซ้อนกันของข้อมูลโลกแห่งความจริงได้
เราได้รับการฝึกฝนสถาปัตยกรรมเดียวกันกับการสูญเสียแฝดการแก้ไขส่วนที่เหลือของพารามิเตอร์: การฝึกอบรมและการแปลงการทดสอบขนาดภาพและเครื่องมือเพิ่มประสิทธิภาพ ดูการกำหนดค่าในสวนสัตว์รุ่น เคล็ดลับอยู่ในฮิวริสติกในคนงานเหมืองและตัวอย่างของเรา:
Sampler ยอดคงเหลือหมวดหมู่จะสร้างแบทช์ที่ จำกัด จำนวนหมวดหมู่ C ในนั้น ตัวอย่างเช่นเมื่อ C = 1 มันใส่แจ็คเก็ตเพียงชุดเดียวและกางเกงยีนส์เท่านั้นลงในชุดอื่น (เพียงตัวอย่าง) มันทำให้คู่ลบโดยอัตโนมัติ: มันมีความหมายมากกว่าสำหรับนางแบบที่จะตระหนักว่าทำไมแจ็คเก็ตสองตัวจึงแตกต่างจากการเข้าใจเหมือนกันเกี่ยวกับแจ็คเก็ตและเสื้อยืด
Hard Triplets Miner ทำให้งานหนักยิ่งขึ้นการรักษาเพียง Triplets ที่ยากที่สุดเท่านั้น
นี่คือ CMC@1 คะแนนสำหรับ 2 เกณฑ์มาตรฐานยอดนิยม ชุดข้อมูล SOP: HYP-VIT-85.9, เรา-86.6 ชุดข้อมูล DeepFashion: HYP-VIT-92.5, เรา-92.1 ดังนั้นการใช้ฮิวริสติกง่าย ๆ และหลีกเลี่ยงคณิตศาสตร์หนักเราสามารถทำงานในระดับ SOTA ได้
การวิจัยล่าสุดใน SSL ได้รับผลลัพธ์ที่ยอดเยี่ยมอย่างแน่นอน ปัญหาคือวิธีการเหล่านี้ต้องใช้การคำนวณจำนวนมหาศาลเพื่อฝึกอบรมแบบจำลอง แต่ในกรอบของเราเราพิจารณากรณีที่พบบ่อยที่สุดเมื่อผู้ใช้เฉลี่ยมี GPU ไม่เกินสองสามตัว
ในขณะเดียวกันก็ไม่ฉลาดที่จะเพิกเฉยต่อความสำเร็จในทรงกลมนี้ดังนั้นเรายังคงใช้ประโยชน์จากมันในสองวิธี:
ไม่คุณไม่ OML เป็นเฟรมเวิร์กที่ไม่เชื่อเรื่องพระเจ้า แม้เราจะใช้ Pytorch Lightning เป็นนักวิ่งวนรอบสำหรับการทดลอง แต่เรายังรักษาความเป็นไปได้ที่จะทำงานทุกอย่างบน Pytorch บริสุทธิ์ ดังนั้นเฉพาะส่วนเล็ก ๆ ของ OML เท่านั้นที่เป็นสายฟ้าและเราเก็บตรรกะนี้แยกต่างหากจากรหัสอื่น (ดู oml.lightning ) แม้ว่าคุณจะใช้ฟ้าผ่าคุณก็ไม่จำเป็นต้องรู้เพราะเราพร้อมที่จะใช้ท่อ
ความเป็นไปได้ในการใช้ pytorch บริสุทธิ์และโครงสร้างแบบแยกส่วนของรหัสออกจากห้องสำหรับใช้ OML กับกรอบที่คุณชื่นชอบหลังจากการใช้งานห่อหุ้มที่จำเป็น
ใช่. ในการเรียกใช้การทดลองด้วยท่อคุณจะต้องเขียนตัวแปลงเป็นรูปแบบของเรา (หมายถึงการเตรียมตาราง .csv ด้วยคอลัมน์ที่กำหนดไว้ล่วงหน้าสองสามคอลัมน์) แค่ไหน!
อาจเป็นไปได้ว่าเรามีรูปแบบที่ได้รับการฝึกอบรมล่วงหน้าที่เหมาะสมสำหรับโดเมนของคุณใน สวนสัตว์รุ่น ของเรา ในกรณีนี้คุณไม่จำเป็นต้องฝึกด้วยซ้ำ
ขณะนี้เราไม่สนับสนุนโมเดลการส่งออกไปยัง ONNX โดยตรง อย่างไรก็ตามคุณสามารถใช้ความสามารถในตัว Pytorch ในตัวเพื่อให้ได้สิ่งนี้ สำหรับข้อมูลเพิ่มเติมโปรดดูปัญหานี้
เอกสาร
บทช่วยสอนเริ่มต้นด้วย: ภาษาอังกฤษ | รัสเซีย ชาวจีน
การสาธิตสำหรับการผัดกระดาษของเรา: Siamese Transformers สำหรับการดึงภาพหลังการประมวลผล
พบกับ OpenMetriclearning (OML) บน MarkTechPost
รายงานการพบปะสังสรรค์ในกรุงเบอร์ลิน: "คอมพิวเตอร์วิสัยทัศน์ในการผลิต" พฤศจิกายน, 2022. ลิงค์
pip install -U open-metric-learning ; # minimum dependencies
pip install -U open-metric-learning[nlp]
pip install -U open-metric-learning[audio]docker pull omlteam/oml:gpu
docker pull omlteam/oml:cpu การสูญเสีย | คนงานเหมือง miner = AllTripletsMiner ()
miner = NHardTripletsMiner ()
miner = MinerWithBank ()
...
criterion = TripletLossWithMiner ( 0.1 , miner )
criterion = ArcFaceLoss ()
criterion = SurrogatePrecision () | ตัวอย่าง labels = train . get_labels ()
l2c = train . get_label2category ()
sampler = BalanceSampler ( labels )
sampler = CategoryBalanceSampler ( labels , l2c )
sampler = DistinctCategoryBalanceSampler ( labels , l2c ) |
รองรับการกำหนดค่า max_epochs : 10
sampler :
name : balance
args :
n_labels : 2
n_instances : 2 | รุ่นที่ผ่านการฝึกอบรมมาก่อน model_hf = AutoModel . from_pretrained ( "roberta-base" )
tokenizer = AutoTokenizer . from_pretrained ( "roberta-base" )
extractor_txt = HFWrapper ( model_hf )
extractor_img = ViTExtractor . from_pretrained ( "vits16_dino" )
transforms , _ = get_transforms_for_pretrained ( "vits16_dino" ) |
การโพสต์ emb = inference ( extractor , dataset )
rr = RetrievalResults . from_embeddings ( emb , dataset )
postprocessor = AdaptiveThresholding ()
rr_upd = postprocessor . process ( rr , dataset ) | โพสต์การประมวลผลโดย NN | กระดาษ embeddings = inference ( extractor , dataset )
rr = RetrievalResults . from_embeddings ( embeddings , dataset )
postprocessor = PairwiseReranker ( ConcatSiamese (), top_n = 3 )
rr_upd = postprocessor . process ( rr , dataset ) |
การตัดไม้ logger = TensorBoardPipelineLogger ()
logger = NeptunePipelineLogger ()
logger = WandBPipelineLogger ()
logger = MLFlowPipelineLogger ()
logger = ClearMLPipelineLogger () | PML from pytorch_metric_learning import losses
criterion = losses . TripletMarginLoss ( 0.2 , "all" )
pred = ViTExtractor ()( data )
criterion ( pred , gts ) |
หมวดหมู่สนับสนุน # train
loader = DataLoader ( CategoryBalanceSampler ())
# validation
rr = RetrievalResults . from_embeddings ()
m . calc_retrieval_metrics_rr ( rr , query_categories ) | การวัดอื่น ๆ embeddigs = inference ( model , dataset )
rr = RetrievalResults . from_embeddings ( embeddings , dataset )
m . calc_retrieval_metrics_rr ( rr , precision_top_k = ( 5 ,))
m . calc_fnmr_at_fmr_rr ( rr , fmr_vals = ( 0.1 ,))
m . calc_topological_metrics ( embeddings , pcf_variance = ( 0.5 ,)) |
ฟ้าผ่า import pytorch_lightning as pl
model = ViTExtractor . from_pretrained ( "vits16_dino" )
clb = MetricValCallback ( EmbeddingMetrics ( dataset ))
module = ExtractorModule ( model , criterion , optimizer )
trainer = pl . Trainer ( max_epochs = 3 , callbacks = [ clb ])
trainer . fit ( module , train_loader , val_loader ) | สายฟ้า DDP clb = MetricValCallback ( EmbeddingMetrics ( val ))
module = ExtractorModuleDDP (
model , criterion , optimizer , train , val
)
ddp = { "devices" : 2 , "strategy" : DDPStrategy ()}
trainer = pl . Trainer ( max_epochs = 3 , callbacks = [ clb ], ** ddp )
trainer . fit ( module ) |
นี่คือตัวอย่างของวิธีการฝึกอบรมตรวจสอบและโพสต์ประมวลผลโมเดลบนชุดข้อมูลหรือข้อความเล็ก ๆ ดูรายละเอียดเพิ่มเติมเกี่ยวกับรูปแบบชุดข้อมูล
| ภาพ | ตำรา |
from torch . optim import Adam
from torch . utils . data import DataLoader
from oml import datasets as d
from oml . inference import inference
from oml . losses import TripletLossWithMiner
from oml . metrics import calc_retrieval_metrics_rr
from oml . miners import AllTripletsMiner
from oml . models import ViTExtractor
from oml . registry import get_transforms_for_pretrained
from oml . retrieval import RetrievalResults , AdaptiveThresholding
from oml . samplers import BalanceSampler
from oml . utils import get_mock_images_dataset
model = ViTExtractor . from_pretrained ( "vits16_dino" ). to ( "cpu" ). train ()
transform , _ = get_transforms_for_pretrained ( "vits16_dino" )
df_train , df_val = get_mock_images_dataset ( global_paths = True )
train = d . ImageLabeledDataset ( df_train , transform = transform )
val = d . ImageQueryGalleryLabeledDataset ( df_val , transform = transform )
optimizer = Adam ( model . parameters (), lr = 1e-4 )
criterion = TripletLossWithMiner ( 0.1 , AllTripletsMiner (), need_logs = True )
sampler = BalanceSampler ( train . get_labels (), n_labels = 2 , n_instances = 2 )
def training ():
for batch in DataLoader ( train , batch_sampler = sampler ):
embeddings = model ( batch [ "input_tensors" ])
loss = criterion ( embeddings , batch [ "labels" ])
loss . backward ()
optimizer . step ()
optimizer . zero_grad ()
print ( criterion . last_logs )
def validation ():
embeddings = inference ( model , val , batch_size = 4 , num_workers = 0 )
rr = RetrievalResults . from_embeddings ( embeddings , val , n_items = 3 )
rr = AdaptiveThresholding ( n_std = 2 ). process ( rr )
rr . visualize ( query_ids = [ 2 , 1 ], dataset = val , show = True )
print ( calc_retrieval_metrics_rr ( rr , map_top_k = ( 3 ,), cmc_top_k = ( 1 ,)))
training ()
validation () | from torch . optim import Adam
from torch . utils . data import DataLoader
from transformers import AutoModel , AutoTokenizer
from oml import datasets as d
from oml . inference import inference
from oml . losses import TripletLossWithMiner
from oml . metrics import calc_retrieval_metrics_rr
from oml . miners import AllTripletsMiner
from oml . models import HFWrapper
from oml . retrieval import RetrievalResults , AdaptiveThresholding
from oml . samplers import BalanceSampler
from oml . utils import get_mock_texts_dataset
model = HFWrapper ( AutoModel . from_pretrained ( "bert-base-uncased" ), 768 ). to ( "cpu" ). train ()
tokenizer = AutoTokenizer . from_pretrained ( "bert-base-uncased" )
df_train , df_val = get_mock_texts_dataset ()
train = d . TextLabeledDataset ( df_train , tokenizer = tokenizer )
val = d . TextQueryGalleryLabeledDataset ( df_val , tokenizer = tokenizer )
optimizer = Adam ( model . parameters (), lr = 1e-4 )
criterion = TripletLossWithMiner ( 0.1 , AllTripletsMiner (), need_logs = True )
sampler = BalanceSampler ( train . get_labels (), n_labels = 2 , n_instances = 2 )
def training ():
for batch in DataLoader ( train , batch_sampler = sampler ):
embeddings = model ( batch [ "input_tensors" ])
loss = criterion ( embeddings , batch [ "labels" ])
loss . backward ()
optimizer . step ()
optimizer . zero_grad ()
print ( criterion . last_logs )
def validation ():
embeddings = inference ( model , val , batch_size = 4 , num_workers = 0 )
rr = RetrievalResults . from_embeddings ( embeddings , val , n_items = 3 )
rr = AdaptiveThresholding ( n_std = 2 ). process ( rr )
rr . visualize ( query_ids = [ 2 , 1 ], dataset = val , show = True )
print ( calc_retrieval_metrics_rr ( rr , map_top_k = ( 3 ,), cmc_top_k = ( 1 ,)))
training ()
validation () |
เอาท์พุท{ 'active_tri' : 0.125 , 'pos_dist' : 82.5 , 'neg_dist' : 100.5 } # batch 1
{ 'active_tri' : 0.0 , 'pos_dist' : 36.3 , 'neg_dist' : 56.9 } # batch 2
{ 'cmc' : { 1 : 0.75 }, 'precision' : { 5 : 0.75 }, 'map' : { 3 : 0.8 }} | เอาท์พุท{ 'active_tri' : 0.0 , 'pos_dist' : 8.5 , 'neg_dist' : 11.0 } # batch 1
{ 'active_tri' : 0.25 , 'pos_dist' : 8.9 , 'neg_dist' : 9.8 } # batch 2
{ 'cmc' : { 1 : 0.8 }, 'precision' : { 5 : 0.7 }, 'map' : { 3 : 0.9 }} |
ภาพประกอบพิเศษคำอธิบายและเคล็ดลับสำหรับรหัสด้านบน
นี่คือตัวอย่างเวลาการอนุมาน (กล่าวอีกนัยหนึ่งคือการดึงข้อมูลในชุดทดสอบ) รหัสด้านล่างใช้งานได้ทั้งข้อความและรูปภาพ
from oml . datasets import ImageQueryGalleryDataset
from oml . inference import inference
from oml . models import ViTExtractor
from oml . registry import get_transforms_for_pretrained
from oml . utils import get_mock_images_dataset
from oml . retrieval import RetrievalResults , AdaptiveThresholding
_ , df_test = get_mock_images_dataset ( global_paths = True )
del df_test [ "label" ] # we don't need gt labels for doing predictions
extractor = ViTExtractor . from_pretrained ( "vits16_dino" ). to ( "cpu" )
transform , _ = get_transforms_for_pretrained ( "vits16_dino" )
dataset = ImageQueryGalleryDataset ( df_test , transform = transform )
embeddings = inference ( extractor , dataset , batch_size = 4 , num_workers = 0 )
rr = RetrievalResults . from_embeddings ( embeddings , dataset , n_items = 5 )
rr = AdaptiveThresholding ( n_std = 3.5 ). process ( rr )
rr . visualize ( query_ids = [ 0 , 1 ], dataset = dataset , show = True )
# you get the ids of retrieved items and the corresponding distances
print ( rr )นี่คือตัวอย่างที่การสืบค้นและแกลเลอรี่ประมวลผลแยกต่างหาก
import pandas as pd
from oml . datasets import ImageBaseDataset
from oml . inference import inference
from oml . models import ViTExtractor
from oml . registry import get_transforms_for_pretrained
from oml . retrieval import RetrievalResults , ConstantThresholding
from oml . utils import get_mock_images_dataset
extractor = ViTExtractor . from_pretrained ( "vits16_dino" ). to ( "cpu" )
transform , _ = get_transforms_for_pretrained ( "vits16_dino" )
paths = pd . concat ( get_mock_images_dataset ( global_paths = True ))[ "path" ]
galleries , queries1 , queries2 = paths [: 20 ], paths [ 20 : 22 ], paths [ 22 : 24 ]
# gallery is huge and fixed, so we only process it once
dataset_gallery = ImageBaseDataset ( galleries , transform = transform )
embeddings_gallery = inference ( extractor , dataset_gallery , batch_size = 4 , num_workers = 0 )
# queries come "online" in stream
for queries in [ queries1 , queries2 ]:
dataset_query = ImageBaseDataset ( queries , transform = transform )
embeddings_query = inference ( extractor , dataset_query , batch_size = 4 , num_workers = 0 )
# for the operation below we are going to provide integrations with vector search DB like QDrant or Faiss
rr = RetrievalResults . from_embeddings_qg (
embeddings_query = embeddings_query , embeddings_gallery = embeddings_gallery ,
dataset_query = dataset_query , dataset_gallery = dataset_gallery
)
rr = ConstantThresholding ( th = 80 ). process ( rr )
rr . visualize_qg ([ 0 , 1 ], dataset_query = dataset_query , dataset_gallery = dataset_gallery , show = True )
print ( rr )Pipelines ให้วิธีการทดลองการเรียนรู้การเรียนรู้ผ่านการเปลี่ยนเฉพาะไฟล์ config สิ่งที่คุณต้องมีคือการเตรียมชุดข้อมูลของคุณในรูปแบบที่ต้องการ
ดูโฟลเดอร์ Pipelines สำหรับรายละเอียดเพิ่มเติม:
นี่คือการรวมที่มีน้ำหนักเบากับโมเดล HuggingFace Transformers คุณสามารถแทนที่ด้วยแบบจำลองโดยพลการอื่น ๆ ที่สืบทอดมาจาก IExtractor
หมายเหตุเราไม่มีสวนสัตว์รุ่นข้อความของเราเองในขณะนี้
pip install open-metric-learning[nlp] from transformers import AutoModel , AutoTokenizer
from oml . models import HFWrapper
model = AutoModel . from_pretrained ( 'bert-base-uncased' ). eval ()
tokenizer = AutoTokenizer . from_pretrained ( 'bert-base-uncased' )
extractor = HFWrapper ( model = model , feat_dim = 768 )
inp = tokenizer ( text = "Hello world" , return_tensors = "pt" , add_special_tokens = True )
embeddings = extractor ( inp )คุณสามารถใช้รูปแบบรูปภาพจากสวนสัตว์ของเราหรือใช้โมเดลอื่น ๆ หลังจากที่คุณสืบทอดมาจาก IExtractor
from oml . const import CKPT_SAVE_ROOT as CKPT_DIR , MOCK_DATASET_PATH as DATA_DIR
from oml . models import ViTExtractor
from oml . registry import get_transforms_for_pretrained
model = ViTExtractor . from_pretrained ( "vits16_dino" ). eval ()
transforms , im_reader = get_transforms_for_pretrained ( "vits16_dino" )
img = im_reader ( DATA_DIR / "images" / "circle_1.jpg" ) # put path to your image here
img_tensor = transforms ( img )
# img_tensor = transforms(image=img)["image"] # for transforms from Albumentations
features = model ( img_tensor . unsqueeze ( 0 ))
# Check other available models:
print ( list ( ViTExtractor . pretrained_models . keys ()))
# Load checkpoint saved on a disk:
model_ = ViTExtractor ( weights = CKPT_DIR / "vits16_dino.ckpt" , arch = "vits16" , normalise_features = False )แบบจำลองที่ได้รับการฝึกฝนโดยเรา ตัวชี้วัดด้านล่างนี้มีราคา 224 x 224 :
| แบบอย่าง | CMC1 | ชุดข้อมูล | น้ำหนัก | การทดลอง |
|---|---|---|---|---|
ViTExtractor.from_pretrained("vits16_inshop") | 0.921 | deepfashion inshop | การเชื่อมโยง | การเชื่อมโยง |
ViTExtractor.from_pretrained("vits16_sop") | 0.866 | ผลิตภัณฑ์ออนไลน์ของ Stanford | การเชื่อมโยง | การเชื่อมโยง |
ViTExtractor.from_pretrained("vits16_cars") | 0.907 | รถยนต์ 196 | การเชื่อมโยง | การเชื่อมโยง |
ViTExtractor.from_pretrained("vits16_cub") | 0.837 | Cub 200 2011 | การเชื่อมโยง | การเชื่อมโยง |
แบบจำลองที่ได้รับการฝึกฝนโดยนักวิจัยคนอื่น ๆ โปรดทราบว่าตัวชี้วัดบางอย่างเกี่ยวกับเกณฑ์มาตรฐานโดยเฉพาะนั้นสูงมากเพราะเป็นส่วนหนึ่งของชุดข้อมูลการฝึกอบรม (เช่น unicom ) ตัวชี้วัดด้านล่างนี้มีราคา 224 x 224:
| แบบอย่าง | ผลิตภัณฑ์ออนไลน์ของ Stanford | deepfashion inshop | Cub 200 2011 | รถยนต์ 196 |
|---|---|---|---|---|
ViTUnicomExtractor.from_pretrained("vitb16_unicom") | 0.700 | 0.734 | 0.847 | 0.916 |
ViTUnicomExtractor.from_pretrained("vitb32_unicom") | 0.690 | 0.722 | 0.796 | 0.893 |
ViTUnicomExtractor.from_pretrained("vitl14_unicom") | 0.726 | 0.790 | 0.868 | 0.922 |
ViTUnicomExtractor.from_pretrained("vitl14_336px_unicom") | 0.745 | 0.810 | 0.875 | 0.924 |
ViTCLIPExtractor.from_pretrained("sber_vitb32_224") | 0.547 | 0.514 | 0.448 | 0.618 |
ViTCLIPExtractor.from_pretrained("sber_vitb16_224") | 0.565 | 0.565 | 0.524 | 0.648 |
ViTCLIPExtractor.from_pretrained("sber_vitl14_224") | 0.512 | 0.555 | 0.606 | 0.707 |
ViTCLIPExtractor.from_pretrained("openai_vitb32_224") | 0.612 | 0.491 | 0.560 | 0.693 |
ViTCLIPExtractor.from_pretrained("openai_vitb16_224") | 0.648 | 0.606 | 0.665 | 0.767 |
ViTCLIPExtractor.from_pretrained("openai_vitl14_224") | 0.670 | 0.675 | 0.745 | 0.844 |
ViTExtractor.from_pretrained("vits16_dino") | 0.648 | 0.509 | 0.627 | 0.265 |
ViTExtractor.from_pretrained("vits8_dino") | 0.651 | 0.524 | 0.661 | 0.315 |
ViTExtractor.from_pretrained("vitb16_dino") | 0.658 | 0.514 | 0.541 | 0.288 |
ViTExtractor.from_pretrained("vitb8_dino") | 0.689 | 0.599 | 0.506 | 0.313 |
ViTExtractor.from_pretrained("vits14_dinov2") | 0.566 | 0.334 | 0.797 | 0.503 |
ViTExtractor.from_pretrained("vits14_reg_dinov2") | 0.566 | 0.332 | 0.795 | 0.740 |
ViTExtractor.from_pretrained("vitb14_dinov2") | 0.565 | 0.342 | 0.842 | 0.644 |
ViTExtractor.from_pretrained("vitb14_reg_dinov2") | 0.557 | 0.324 | 0.833 | 0.828 |
ViTExtractor.from_pretrained("vitl14_dinov2") | 0.576 | 0.352 | 0.844 | 0.692 |
ViTExtractor.from_pretrained("vitl14_reg_dinov2") | 0.571 | 0.340 | 0.840 | 0.871 |
ResnetExtractor.from_pretrained("resnet50_moco_v2") | 0.493 | 0.267 | 0.264 | 0.149 |
ResnetExtractor.from_pretrained("resnet50_imagenet1k_v1") | 0.515 | 0.284 | 0.455 | 0.247 |
ตัวชี้วัดอาจแตกต่างจากเอกสารที่รายงานเนื่องจากเวอร์ชันของการแยกรถไฟ/Val และการใช้งานกล่องขอบเขตอาจแตกต่างกัน
เรายินดีต้อนรับผู้มีส่วนร่วมใหม่! โปรดดูของเรา:
โครงการเริ่มต้นในปี 2020 เป็นโมดูลสำหรับห้องสมุดตัวเร่งปฏิกิริยา ฉันอยากจะขอบคุณคนที่ทำงานกับฉันในโมดูลนั้น: Julia Shenshina, Nikita Balagansky, Sergey Kolesnikov และคนอื่น ๆ
ฉันขอขอบคุณผู้คนที่ยังคงทำงานในท่อนี้ต่อไปเมื่อมันกลายเป็นโครงการแยกต่างหาก: Julia Shenshina, Misha Kindulov, Aron Dik, Aleksei Tarasov และ Verkhovtsev Leonid
ฉันยังต้องการขอบคุณ Newyorker เนื่องจากส่วนหนึ่งของฟังก์ชั่นได้รับการพัฒนา (และใช้) โดยทีมวิสัยทัศน์คอมพิวเตอร์ที่นำโดยฉัน