CTCDecode เป็นการใช้งานของ CTC (การจำแนกประเภทการเชื่อมต่อชั่วคราว) การถอดรหัสการค้นหาลำแสงสำหรับ pytorch รหัส C ++ ยืมอย่างอิสระจาก Paddle Paddles 'Deepspeech มันรวมถึงการสนับสนุนผู้ทำประตูที่เปลี่ยนได้ซึ่งเปิดใช้งานการค้นหาลำแสงมาตรฐานและการถอดรหัสที่ใช้ Kenlm หากคุณยังใหม่กับแนวคิดของการค้นหา CTC และคานโปรดไปที่ส่วนทรัพยากรที่เราเชื่อมโยงบทเรียนสองสามข้ออธิบายว่าทำไมพวกเขาถึงจำเป็น
ห้องสมุดส่วนใหญ่อยู่ในตัวเองและต้องการเพียง pytorch การสร้างห้องสมุด C ++ ต้องใช้ GCC หรือ Clang การสนับสนุนการสร้างแบบจำลองภาษา Kenlm ยังรวมอยู่ด้วยและเปิดใช้งานโดยค่าเริ่มต้น
การติดตั้งด้านล่างยังใช้ได้กับ Google Colab
# get the code
git clone --recursive https://github.com/parlance/ctcdecode.git
cd ctcdecode && pip install . from ctcdecode import CTCBeamDecoder
decoder = CTCBeamDecoder (
labels ,
model_path = None ,
alpha = 0 ,
beta = 0 ,
cutoff_top_n = 40 ,
cutoff_prob = 1.0 ,
beam_width = 100 ,
num_processes = 4 ,
blank_id = 0 ,
log_probs_input = False
)
beam_results , beam_scores , timesteps , out_lens = decoder . decode ( output )CTCBeamDecoderlabels เป็นโทเค็นที่คุณใช้ในการฝึกอบรมแบบจำลองของคุณ พวกเขาควรอยู่ในลำดับเดียวกับผลลัพธ์ของคุณ ตัวอย่างเช่นหากโทเค็นของคุณเป็นตัวอักษรภาษาอังกฤษและคุณใช้ 0 เป็นโทเค็นว่างเปล่าของคุณคุณจะผ่านรายการ ("_ abcdefghijklmopqrstuvwxyz"model_path เป็นเส้นทางไปสู่รูปแบบภาษาเคนลม์ภายนอก (LM) ของคุณ ค่าเริ่มต้นคือไม่มีalpha ที่เกี่ยวข้องกับความน่าจะเป็นของ LMS น้ำหนัก 0 หมายถึง LM ไม่มีผลbeta ที่เกี่ยวข้องกับจำนวนคำภายในคานของเราcutoff_top_n หมายเลขคัตออฟในการตัดแต่ง เฉพาะอักขระ cutoff_top_n ด้านบนที่มีความน่าจะเป็นสูงสุดในคำศัพท์เท่านั้นที่จะใช้ในการค้นหาลำแสงcutoff_prob cutoff ในการตัดแต่งกิ่ง 1.0 หมายความว่าไม่มีการตัดแต่งกิ่งbeam_width สิ่งนี้ควบคุมว่าการค้นหาลำแสงนั้นกว้างแค่ไหน ค่าที่สูงขึ้นมีแนวโน้มที่จะพบคานด้านบน แต่พวกเขาจะทำให้การค้นหาลำแสงของคุณช้าลงอย่างทวีคูณ ยิ่งไปกว่านั้นเอาท์พุทของคุณนานเท่าไหร่คานขนาดใหญ่ก็จะยิ่งใช้เวลามากขึ้นเท่านั้น นี่เป็นพารามิเตอร์ที่สำคัญที่แสดงถึงการแลกเปลี่ยนที่คุณต้องทำตามชุดข้อมูลและความต้องการของคุณnum_processes ขนานกับแบทช์โดยใช้คนงาน num_processes คุณอาจต้องการผ่านจำนวนซีพียูที่คอมพิวเตอร์ของคุณมี คุณสามารถค้นหาสิ่งนี้ได้ใน Python ด้วย import multiprocessing จากนั้น n_cpus = multiprocessing.cpu_count() ค่าเริ่มต้น 4.blank_id นี่ควรเป็นดัชนีของ CTC blank token (อาจเป็น 0)log_probs_input หากเอาต์พุตของคุณผ่าน softmax และแสดงถึงความน่าจะเป็นสิ่งนี้ควรเป็นเท็จหากพวกเขาผ่าน logsoftmax และแสดงถึงโอกาสในการบันทึกเชิงลบคุณต้องส่งผ่านจริง หากคุณไม่เข้าใจสิ่งนี้ให้เรียกใช้ print(output[0][0].sum()) หากเป็นจำนวนลบคุณอาจมี NLL และต้องผ่านจริงถ้ามันรวมถึง ~ 1.0 คุณควรผ่านเท็จ ค่าเริ่มต้นเท็จdecodeoutput ควรเป็นการเปิดใช้งานเอาต์พุตจากโมเดลของคุณ หากเอาต์พุตของคุณผ่านเลเยอร์ Softmax คุณไม่ควรเปลี่ยนมัน (ยกเว้นอาจจะเปลี่ยน) แต่ถ้า output ของคุณแสดงถึงความน่าจะเป็นในการบันทึกเชิงลบ (บันทึกดิบ) คุณจะต้องส่งผ่าน torch.nn.functional.softmax log_probs_input=False เอาต์พุตของคุณควรเป็น batchsize x n_timesteps x n_labels ดังนั้นคุณอาจต้องเปลี่ยนมันก่อนที่จะส่งไปยังตัวถอดรหัส โปรดทราบว่าหากคุณผ่านสิ่งต่าง ๆ ตามลำดับที่ไม่ถูกต้องการค้นหาลำแสงอาจยังคงทำงานอยู่คุณจะได้รับผลลัพธ์ที่ไร้สาระdecode 4 สิ่งที่ถูกส่งคืนจาก decode
beam_results - รูปร่าง: batchsize x n_beams x n_timesteps ชุดที่มีชุดอักขระ (นี่คือ ints คุณยังต้องถอดรหัสกลับไปยังข้อความของคุณ) แสดงผลลัพธ์จากการค้นหาลำแสงที่กำหนด โปรดทราบว่าคานนั้นสั้นกว่าจำนวนเวลาทั้งหมดเสมอและข้อมูลเพิ่มเติมนั้นไม่ได้รับความรู้สึกดังนั้นเพื่อดูลำแสงด้านบน (เป็นฉลาก int) จากรายการแรกในชุดคุณต้องเรียกใช้ beam_results[0][0][:out_len[0][0]]beam_scores - รูปร่าง: batchsize x n_beams ชุดที่มีคะแนน CTC โดยประมาณของแต่ละลำแสง (ดูรหัสที่นี่สำหรับข้อมูลเพิ่มเติม) หากสิ่งนี้เป็นจริงคุณจะได้รับความมั่นใจของโมเดลว่าลำแสงนั้นถูกต้องด้วย p=1/np.exp(beam_score)timesteps - รูปร่าง: batchsize x n_beams timestep ที่อักขระเอาต์พุต nth มีความน่าจะเป็นสูงสุด สามารถใช้เป็นการจัดตำแหน่งระหว่างเสียงและการถอดเสียงout_lens - รูปร่าง: batchsize x n_beams out_lens[i][j] คือความยาวของ jth beam_result ของรายการ i ของแบทช์ของคุณ from ctcdecode import OnlineCTCBeamDecoder
decoder = OnlineCTCBeamDecoder (
labels ,
model_path = None ,
alpha = 0 ,
beta = 0 ,
cutoff_top_n = 40 ,
cutoff_prob = 1.0 ,
beam_width = 100 ,
num_processes = 4 ,
blank_id = 0 ,
log_probs_input = False
)
state1 = ctcdecode . DecoderState ( decoder )
probs_seq = torch . FloatTensor ([ probs_seq ])
beam_results , beam_scores , timesteps , out_seq_len = decoder . decode ( probs_seq [:, : 2 ], [ state1 ], [ False ])
beam_results , beam_scores , timesteps , out_seq_len = decoder . decode ( probs_seq [:, 2 :], [ state1 ], [ True ])ตัวถอดรหัสออนไลน์กำลังคัดลอกอินเตอร์เฟส CTCBeamDecoder แต่ต้องใช้ลำดับสถานะและลำดับ IS_EOS_S
สถานะจะใช้ในการสะสมลำดับของชิ้นส่วนที่สอดคล้องกับแหล่งข้อมูลเดียว IS_EOS_S บอกตัวถอดรหัสว่าชิ้นนั้นหยุดถูกผลักไปยังสถานะที่เกี่ยวข้องหรือไม่
รับลำแสงด้านบนสำหรับรายการแรกในแบทช์ beam_results[0][0][:out_len[0][0]]
รับคาน 50 อันดับแรกสำหรับรายการแรกในชุดของคุณ
for i in range ( 50 ):
print ( beam_results [ 0 ][ i ][: out_len [ 0 ][ i ]]) หมายเหตุสิ่งเหล่านี้จะเป็นรายการของ ints ที่ต้องการการถอดรหัส คุณน่าจะมีฟังก์ชั่นที่จะถอดรหัสจาก Int เป็นข้อความ แต่ถ้าไม่คุณสามารถทำอะไรได้ "".join[labels[n] for n in beam_results[0][0][:out_len[0][0]]] โดยใช้ป้ายกำกับที่คุณส่งผ่านไปยัง CTCBeamDecoder