Model bahasa pra-terlatih berskala besar telah terbukti bermanfaat dalam meningkatkan kealamian model teks-ke-pidato (TTS) dengan memungkinkan mereka menghasilkan pola prosodik yang lebih naturalistik. Namun, model-model ini biasanya tingkat kata atau tingkat-fonem dan dilatih bersama dengan fonem, membuatnya tidak efisien untuk tugas TTS hilir di mana hanya fonem yang diperlukan. Dalam karya ini, kami mengusulkan Bert tingkat fonem (PL-BERT) dengan tugas dalih untuk memprediksi grafem yang sesuai bersama dengan prediksi fonem bertopeng biasa. Evaluasi subyektif menunjukkan bahwa encoder Bert level fonem kami telah secara signifikan meningkatkan skor opini rata-rata (MOS) dari naturitas peringkat pidato yang disintesis dibandingkan dengan styelts styletts styeletts canggih (SOTA) pada teks out-of-distribusi (OOD).
Kertas: https://arxiv.org/abs/2301.08810
Sampel audio: https://pl-bert.github.io/
git clone https://github.com/yl4579/PL-BERT.git
cd PL-BERTconda create --name BERT python=3.8
conda activate BERT
python -m ipykernel install --user --name BERT --display-name " BERT "pip install pandas singleton-decorator datasets " transformers<4.33.3 " accelerate nltk phonemizer sacremoses pebbleSilakan merujuk ke notebook preprocess.ipynb untuk detail lebih lanjut. Preprocessing hanya untuk dataset wikipedia bahasa Inggris. Saya akan membuat cabang baru untuk orang Jepang jika saya punya waktu ekstra untuk mendemostrasi pelatihan pada bahasa lain. Anda juga dapat merujuk ke #6 untuk preprocessing dalam bahasa lain seperti Jepang.
Harap jalankan setiap sel di Notebook Train.ipynb. Anda perlu mengubah line config_path = "Configs/config.yml" di sel 2 jika Anda ingin menggunakan file konfigurasi yang berbeda. Kode pelatihan ada di Jupyter Notebook terutama karena EPXeriment awal dilakukan di Jupyter Notebook, tetapi Anda dapat dengan mudah menjadikannya skrip Python jika Anda mau.
Berikut adalah contoh cara menggunakannya untuk styletts finetuning. Anda dapat menggunakannya untuk model TTS lain dengan mengganti encoder teks dengan PL-Bert pra-terlatih.
from transformers import AlbertConfig , AlbertModel
log_dir = "YOUR PL-BERT CHECKPOINT PATH"
config_path = os . path . join ( log_dir , "config.yml" )
plbert_config = yaml . safe_load ( open ( config_path ))
albert_base_configuration = AlbertConfig ( ** plbert_config [ 'model_params' ])
bert = AlbertModel ( albert_base_configuration )
files = os . listdir ( log_dir )
ckpts = []
for f in os . listdir ( log_dir ):
if f . startswith ( "step_" ): ckpts . append ( f )
iters = [ int ( f . split ( '_' )[ - 1 ]. split ( '.' )[ 0 ]) for f in ckpts if os . path . isfile ( os . path . join ( log_dir , f ))]
iters = sorted ( iters )[ - 1 ]
checkpoint = torch . load ( log_dir + "/step_" + str ( iters ) + ".t7" , map_location = 'cpu' )
state_dict = checkpoint [ 'net' ]
from collections import OrderedDict
new_state_dict = OrderedDict ()
for k , v in state_dict . items ():
name = k [ 7 :] # remove `module.`
if name . startswith ( 'encoder.' ):
name = name [ 8 :] # remove `encoder.`
new_state_dict [ name ] = v
bert . load_state_dict ( new_state_dict )
nets = Munch ( bert = bert ,
# linear projection to match the hidden size (BERT 768, StyleTTS 512)
bert_encoder = nn . Linear ( plbert_config [ 'model_params' ][ 'hidden_size' ], args . hidden_dim ),
predictor = predictor ,
decoder = decoder ,
pitch_extractor = pitch_extractor ,
text_encoder = text_encoder ,
style_encoder = style_encoder ,
text_aligner = text_aligner ,
discriminator = discriminator ) # for stability
for g in optimizer . optimizers [ 'bert' ]. param_groups :
g [ 'betas' ] = ( 0.9 , 0.99 )
g [ 'lr' ] = 1e-5
g [ 'initial_lr' ] = 1e-5
g [ 'min_lr' ] = 0
g [ 'weight_decay' ] = 0.01 bert_dur = model . bert ( texts , attention_mask = ( ~ text_mask ). int ()). last_hidden_state
d_en = model . bert_encoder ( bert_dur ). transpose ( - 1 , - 2 )
d , _ = model . predictor ( d_en , s ,
input_lengths ,
s2s_attn_mono ,
m )Baris 257:
_ , p = model . predictor ( d_en , s ,
input_lengths ,
s2s_attn_mono ,
m )dan baris 415:
bert_dur = model . bert ( texts , attention_mask = ( ~ text_mask ). int ()). last_hidden_state
d_en = model . bert_encoder ( bert_dur ). transpose ( - 1 , - 2 )
d , p = model . predictor ( d_en , s ,
input_lengths ,
s2s_attn_mono ,
m ) optimizer . step ( 'bert_encoder' )
optimizer . step ( 'bert' )PL-Bert pra-terlatih di Wikipedia untuk langkah 1m dapat diunduh di: tautan PL-Bert.
Demo pada dataset LJSPEECH bersama dengan repo Styletts yang telah dimodifikasi dan model pra-terlatih dapat diunduh di sini: tautan styletts. File zip ini berisi modifikasi kode di atas, model PL-BERT yang terlatih yang tercantum di atas, styletts pra-terlatih dengan stylets pra-terlatih dengan pl-BERT dan hifigan pra-terlatih pada ljspeech dari repo styletts.