CTCDecodeは、PytorchのCTC(コネクショニスト時間分類)ビーム検索デコードの実装です。 C ++コードは、パドルパドルのディープスピーチから自由に借用されました。これには、標準ビーム検索を可能にするスワップ可能なスコアラーサポートと、KENLMベースのデコードが含まれています。 CTCとBeam検索の概念を初めて使用している場合は、必要な理由を説明するいくつかのチュートリアルをリンクするリソースセクションにアクセスしてください。
ライブラリは主に自己完結型であり、Pytorchのみが必要です。 C ++ライブラリを構築するには、GCCまたはClangが必要です。 KENLM言語モデリングサポートもオプションで含まれており、デフォルトで有効になっています。
以下のインストールは、Google Colabでも機能します。
# get the code
git clone --recursive https://github.com/parlance/ctcdecode.git
cd ctcdecode && pip install . from ctcdecode import CTCBeamDecoder
decoder = CTCBeamDecoder (
labels ,
model_path = None ,
alpha = 0 ,
beta = 0 ,
cutoff_top_n = 40 ,
cutoff_prob = 1.0 ,
beam_width = 100 ,
num_processes = 4 ,
blank_id = 0 ,
log_probs_input = False
)
beam_results , beam_scores , timesteps , out_lens = decoder . decode ( output )CTCBeamDecoderへの入力labels 、モデルをトレーニングするために使用したトークンです。それらはあなたの出力と同じ順序でなければなりません。たとえば、トークンが英語の文字であり、空白のトークンとして0を使用した場合、リストに渡されます( "_ abcdefghijklmopqrstuvwxyz")model_pathは、外部KENLM言語モデル(LM)へのパスです。デフォルトはなしです。alpha重み付け。 0の重量は、LMに効果がないことを意味します。beta重量。cutoff_top_nカットオフ番号。ボカブで最も高い確率を持つトップのcutoff_top_n文字のみがビーム検索で使用されます。cutoff_prob剪定におけるカットオフ確率。 1.0は剪定がないことを意味します。beam_widthこれにより、ビーム検索がどれほど広いかを制御します。値が高いほど、トップビームを見つける可能性が高くなりますが、ビーム検索が指数関数的に遅くなります。さらに、出力が長くなればなるほど、大きなビームがかかる時間が長くなります。これは、データセットとニーズに基づいて行う必要があるトレードオフを表す重要なパラメーターです。num_processes 、num_processesワーカーを使用してバッチに並行しています。おそらく、コンピューターが持っているCPUの数を渡したいと思うでしょう。これはimport multiprocessing then n_cpus = multiprocessing.cpu_count()でpythonで見つけることができます。デフォルト4。blank_idこれは、CTCブランクトークンのインデックスである必要があります(おそらく0)。log_probs_input出力がソフトマックスを通過し、確率を表す場合、これはfalseである必要があります。logsoftmaxを通過し、負のログの可能性を表す場合は、trueを渡す必要があります。これを理解していない場合は、 print(output[0][0].sum())を実行します。それが負の数値である場合は、おそらくNLLを持っていて、〜1.0に合計する場合は、falseを渡す必要があります。デフォルトのfalse。decodeメソッドへの入力output 、モデルからの出力アクティベーションである必要があります。出力がSoftMaxレイヤーを通過した場合、変更する必要はありません(転置する場合を除く)が、 output負のログの尤度(生のロジット)を表す場合、追加のtorch.nn.functional.softmaxを通過するか、 log_probs_input=False decoderに渡す必要があります。出力はバッチサイズx n_timesteps x n_labelsである必要があるため、デコーダーに渡す前に転置する必要があります。間違った順序で物事を渡すと、ビーム検索がまだ実行される場合、ナンセンスな結果が戻ってくることに注意してください。decodeメソッドからの出力4つのものがdecodeから返されます
beam_results -shape:batchsize x n_beams x n_timesteps特定のビーム検索の結果を表す一連の文字を含むバッチ(これらはテキストに戻す必要があります)。ビームはタイムステップの総数よりもほとんど常に短く、追加データは非寛容であるため、バッチの最初のアイテムから上部ビーム(intラベルとして)を見るには、 beam_results[0][0][:out_len[0][0]]実行する必要があります。beam_scores -shape:batchsize x n_beamは、各ビームのおおよそのCTCスコアを使用してバッチを使用します(詳細については、こちらのコードをご覧ください)。これが真実であれば、 p=1/np.exp(beam_score)でビームが正しいというモデルの自信を得ることができます。timesteps - 形状:batchsize x n_beams n番目の出力文字がピーク確率を持つタイムステップ。オーディオとトランスクリプトの間のアラインメントとして使用できます。out_lens -shape:batchsize x n_beams。 out_lens[i][j]バッチのアイテムIのjth beam_resultの長さです。 from ctcdecode import OnlineCTCBeamDecoder
decoder = OnlineCTCBeamDecoder (
labels ,
model_path = None ,
alpha = 0 ,
beta = 0 ,
cutoff_top_n = 40 ,
cutoff_prob = 1.0 ,
beam_width = 100 ,
num_processes = 4 ,
blank_id = 0 ,
log_probs_input = False
)
state1 = ctcdecode . DecoderState ( decoder )
probs_seq = torch . FloatTensor ([ probs_seq ])
beam_results , beam_scores , timesteps , out_seq_len = decoder . decode ( probs_seq [:, : 2 ], [ state1 ], [ False ])
beam_results , beam_scores , timesteps , out_seq_len = decoder . decode ( probs_seq [:, 2 :], [ state1 ], [ True ])オンラインデコーダーはCTCBeamDeCoderインターフェイスをコピーしていますが、状態とIS_EOS_Sシーケンスが必要です。
状態は、チャンクのシーケンスを蓄積するために使用され、それぞれが1つのデータソースに対応しています。 IS_EOS_Sは、チャンクが対応する状態に押し出されなくなったかどうかをデコーダーに伝えます。
バッチbeam_results[0][0][:out_len[0][0]]の最初のアイテムのトップビームを取得します。
バッチの最初のアイテムの上位50ビームを取得します
for i in range ( 50 ):
print ( beam_results [ 0 ][ i ][: out_len [ 0 ][ i ]])注意してください、これらはデコードが必要なINTのリストになります。 INTからテキストまでデコードする機能がすでにある可能性がありますが、そうでない場合は、ようなことができます。 "".join[labels[n] for n in beam_results[0][0][:out_len[0][0]]] CTCBeamDecoderに渡したラベルを使用して