tf seq2seq
1.0.0
核心構建塊是RNN編碼器架構和注意機制。
該軟件包主要使用最新(1.2)TF.Contrib.Seq2Seq模塊實現
包裝支持
對於sample_data.src和sample_data.trg的預處理原始並行數據,只需運行
cd data /
. / preprocess . sh src trg sample_data $ { max_seq_len }運行上面的代碼執行廣泛使用的機器翻譯(MT)的預處理步驟。
訓練SEQ2SEQ模型,
$ python train . py -- cell_type 'lstm'
-- attention_type 'luong'
-- hidden_units 1024
-- depth 2
-- embedding_size 500
-- num_encoder_symbols 30000
-- num_decoder_symbols 30000 ...為了運行訓練有素的解碼模型,
$ python decode . py -- beam_width 5
-- decode_batch_size 30
-- model_path $PATH_TO_A_MODEL_CHECKPOINT ( e . g . model / translate . ckpt - 100 )
-- max_decode_step 300
-- write_n_best False
-- decode_input $PATH_TO_DECODE_INPUT
-- decode_output $PATH_TO_DECODE_OUTPUT
如果--beam_width=1 ,則在每個時間步長執行貪婪的解碼。
數據參數
--source_vocabulary :源詞彙的途徑--target_vocabulary :目標詞彙的途徑--source_train_data :源培訓數據的途徑--target_train_data :目標訓練數據的路徑--source_valid_data :通往源驗證數據的路徑--target_valid_data :目標驗證數據的路徑網絡參數
--cell_type :用於編碼器和解碼器的RNN單元格(默認:LSTM)--attention_type :注意機制(Bahdanau,Luong),(默認:Bahdanau)--depth :模型中每一層的隱藏單元數量(默認:2)--embedding_size :嵌入編碼器和解碼器輸入的尺寸(默認:500)--num_encoder_symbols :要使用的源詞彙大小(默認:30000)--num_decoder_symbols :要使用的目標詞彙大小(默認:30000)--use_residual :使用圖層之間的殘差連接(默認:true)--attn_input_feeding :在註意解碼器中使用輸入餵養方法(Luong等,2015)(默認:true)--use_dropout :在RNN單元格輸出中使用輟學(默認:true)--dropout_rate :單元輸出的輟學概率(0.0:no dropout)(默認:0.3)訓練參數
--learning_rate :模型中每一層的隱藏單元數(默認:0.0002)--max_gradient_norm :此規範的剪輯梯度(默認1.0)--batch_size :批次大小--max_epochs :最大訓練時期--max_load_batches :一次預取的最大批次數。--max_seq_length :最大序列長度--display_freq :顯示培訓狀態每次迭代--save_freq :保存模型檢查點,每一次迭代--valid_freq :評估模型每一個迭代:有效_data需要--optimizer :用於培訓的優化器:( Adadelta,Adam,RMSProp)(默認:Adam)--model_dir :保存模型檢查點的路徑--model_name :用於模型檢查點的文件名--shuffle_each_epoch :每個時期的洗牌培訓數據集(默認:true)--sort_by_length :按其目標序列長度對預取的minibatches進行排序(默認:true)解碼參數
--beam_width :BeamSearch中使用的光束寬度(默認:1)--decode_batch_size :用於解碼的批處理大小--max_decode_step :解碼中的最大時間步長限制(默認:500)--write_n_best :Write BeamSearch n-best列表(n = beam_width)(默認值:false)--decode_input :輸入文件路徑去解碼--decode_output :解碼輸出的輸出文件路徑運行時參數
--allow_soft_placement :允許設備軟位置--log_device_placement :OPS在設備上的日誌放置該實施基於以下項目:
有關任何評論和反饋,請給我發送電子郵件至[email protected]或在此處打開問題。