CTCDecode هو تنفيذ فك تشفير البحث عن شعاع CTC (CONCLINEST الزماني) لتفكك PYTORCH. تم استعارة C ++ رمزًا حرًا من DeepSpeesh Paddle Paddles. ويشمل دعم هداف قابلة للتبديل تمكين البحث القياسي في الحزمة ، وفك تشفير KENLM. إذا كنت جديدًا على مفاهيم CTC و Search Beam ، فيرجى زيارة قسم الموارد حيث نربط بعض البرامج التعليمية التي تشرح سبب الحاجة إليها.
المكتبة مستقلة إلى حد كبير وتتطلب فقط pytorch. يتطلب بناء مكتبة C ++ GCC أو Clang. يتم أيضًا تضمين دعم نمذجة اللغة Kenlm اختياريًا وتمكينه افتراضيًا.
يعمل التثبيت أدناه مع Google Colab.
# get the code
git clone --recursive https://github.com/parlance/ctcdecode.git
cd ctcdecode && pip install . from ctcdecode import CTCBeamDecoder
decoder = CTCBeamDecoder (
labels ,
model_path = None ,
alpha = 0 ,
beta = 0 ,
cutoff_top_n = 40 ,
cutoff_prob = 1.0 ,
beam_width = 100 ,
num_processes = 4 ,
blank_id = 0 ,
log_probs_input = False
)
beam_results , beam_scores , timesteps , out_lens = decoder . decode ( output )CTCBeamDecoderlabels هي الرموز التي استخدمتها لتدريب النموذج الخاص بك. يجب أن تكون في نفس الترتيب مثل المخرجات الخاصة بك. على سبيل المثال ، إذا كانت الرموز المميزة الخاصة بك هي الحروف الإنجليزية واستخدمت 0 كرمز فارغ الخاص بك ، فستمرر في القائمة ("_ ABCDEFGHIJKLMOPQRSTUVWXYZ") كوسيطة إلى التسمياتmodel_path هو المسار إلى نموذج لغة Kenlm الخارجي (LM). الافتراضي هو لا شيء.alpha المرتبط باحتمالات LMS. وزن 0 يعني أن LM ليس له أي تأثير.beta المرتبط بعدد الكلمات داخل شعاعنا.cutoff_top_n رقم قطع في التقليم. سيتم استخدام أحرف Cutoff_TOP_N العليا فقط مع أعلى احتمال في المفردات في البحث عن الشعاع.cutoff_prob احتمال قطع في التقليم. 1.0 يعني عدم التقليم.beam_width يتحكم في مدى عرض البحث عن الشعاع. من المرجح أن تجد القيم الأعلى الحزم العليا ، ولكنها ستجعل بحثك عن شعاعك أبطأ بشكل كبير. علاوة على ذلك ، كلما طالت مخرجاتك ، كلما استغرقت المزيد من الحزم الكبيرة. هذه معلمة مهمة تمثل مفاضلة تحتاج إلى صنعها بناءً على مجموعة البيانات واحتياجاتك.num_processes يموضع الدفعة باستخدام العمال Num_processes. ربما تريد تمرير عدد وحدات المعالجة المركزية التي يحتوي عليها جهاز الكمبيوتر الخاص بك. يمكنك العثور على هذا في Python مع import multiprocessing ثم n_cpus = multiprocessing.cpu_count() . الافتراضي 4.blank_id يجب أن يكون هذا فهرس الرمز المميز CTC فارغ (ربما 0).log_probs_input إذا كانت مخرجاتك قد مرت عبر softmax وتمثل الاحتمالات ، فيجب أن يكون هذا خطأ ، إذا تم تمريره عبر logsoftmax وتمثل احتمال السجل السلبي ، فأنت بحاجة إلى المرور بشكل صحيح. إذا كنت لا تفهم هذا ، فقم بتشغيل print(output[0][0].sum()) ، إذا كان رقمًا سالبًا ، فمن المحتمل أن تحصل على nll وتحتاج إلى المرور بشكل صحيح ، إذا كان ذلك يلخص إلى ~ 1.0 ، فيجب أن تمر خاطئًا. الافتراضي خطأ.decodeoutput هو تنشيط الإخراج من النموذج الخاص بك. إذا كان الإخراج الخاص بك قد مرت من خلال طبقة SoftMax ، فلن تحتاج إلى تغييره (باستثناء ربما للانتقال) ، ولكن إذا كان output الخاص بك يمثل احتمالات سجل سلبية (سجلات RAW) ، فأنت إما بحاجة إلى تمريره من خلال torch.nn.functional.softmax أو يمكنك تمرير log_probs_input=False في Decon. يجب أن يكون إخراجك محجوزًا x n_timesteps x n_labels بحيث تحتاج إلى تحويله قبل نقله إلى وحدة فك الترميز. لاحظ أنه إذا قمت بتمرير الأشياء بالترتيب الخاطئ ، فمن المحتمل أن يتم تشغيل البحث عن الشعاع ، فسوف تحصل على نتائج هراء.decode 4 أشياء يتم إرجاعها من decode
beam_results - الشكل: BatchSize X n_beams x n_timesteps مجموعة تحتوي على سلسلة من الأحرف (هذه هي ints ، لا تزال بحاجة إلى فك تشفيرها إلى النص الخاص بك) تمثل نتائج من بحث شعاع معين. لاحظ أن الحزم تكون دائمًا أقصر تقريبًا من العدد الإجمالي للأوقات الزمنية ، والبيانات الإضافية غير عادية ، بحيث ترى الحزمة العلوية (كعلامات int) من العنصر الأول في الدفعة ، تحتاج إلى تشغيل beam_results[0][0][:out_len[0][0]] .beam_scores - الشكل: BatchSize X n_beams مجموعة مع درجة CTC التقريبية لكل شعاع (انظر إلى الكود هنا لمزيد من المعلومات). إذا كان هذا صحيحًا ، فيمكنك الحصول على ثقة النموذج في أن الحزمة صحيحة مع p=1/np.exp(beam_score) .timesteps - الشكل: BatchSize x n_beams timestep التي يكون لها حرف الإخراج التاسع احتمال الذروة. يمكن استخدامها كمحاذاة بين الصوت والنسخة.out_lens - الشكل: BatchSize x n_beams. out_lens[i][j] هو طول JTH Beam_result ، من البند الأول من مجموعتك. from ctcdecode import OnlineCTCBeamDecoder
decoder = OnlineCTCBeamDecoder (
labels ,
model_path = None ,
alpha = 0 ,
beta = 0 ,
cutoff_top_n = 40 ,
cutoff_prob = 1.0 ,
beam_width = 100 ,
num_processes = 4 ,
blank_id = 0 ,
log_probs_input = False
)
state1 = ctcdecode . DecoderState ( decoder )
probs_seq = torch . FloatTensor ([ probs_seq ])
beam_results , beam_scores , timesteps , out_seq_len = decoder . decode ( probs_seq [:, : 2 ], [ state1 ], [ False ])
beam_results , beam_scores , timesteps , out_seq_len = decoder . decode ( probs_seq [:, 2 :], [ state1 ], [ True ])يقوم وحدة فك الترميز عبر الإنترنت بنسخ واجهة CTCBeamDecoder ، ولكنها تتطلب تسلسل الحالات وتسلسل IS_EOS_S.
يتم استخدام الحالات لتجميع تسلسل القطع ، كل مقابلة لمصدر بيانات واحد. يخبر IS_EOS_S وحدة فك الترميز ما إذا كانت القطع قد توقفت عن دفعها إلى الحالة المقابلة.
احصل على الحزمة العلوية للعنصر الأول في الدُفعة beam_results[0][0][:out_len[0][0]]
احصل على أفضل 50 عوارضًا للعنصر الأول في الدفعة الخاصة بك
for i in range ( 50 ):
print ( beam_results [ 0 ][ i ][: out_len [ 0 ][ i ]]) ملاحظة ، ستكون هذه قائمة من ints التي تحتاج إلى فك تشفير. من المحتمل أن يكون لديك بالفعل وظيفة لفك تشفيرها من int إلى النص ، ولكن إذا لم يكن بإمكانك فعل شيء مثل. ". join [abels [n CTCBeamDecoder "".join[labels[n] for n in beam_results[0][0][:out_len[0][0]]]