tf seq2seq
1.0.0
核心构建块是RNN编码器架构和注意机制。
该软件包主要使用最新(1.2)TF.Contrib.Seq2Seq模块实现
包装支持
对于sample_data.src和sample_data.trg的预处理原始并行数据,只需运行
cd data /
. / preprocess . sh src trg sample_data $ { max_seq_len }运行上面的代码执行广泛使用的机器翻译(MT)的预处理步骤。
训练SEQ2SEQ模型,
$ python train . py -- cell_type 'lstm'
-- attention_type 'luong'
-- hidden_units 1024
-- depth 2
-- embedding_size 500
-- num_encoder_symbols 30000
-- num_decoder_symbols 30000 ...为了运行训练有素的解码模型,
$ python decode . py -- beam_width 5
-- decode_batch_size 30
-- model_path $PATH_TO_A_MODEL_CHECKPOINT ( e . g . model / translate . ckpt - 100 )
-- max_decode_step 300
-- write_n_best False
-- decode_input $PATH_TO_DECODE_INPUT
-- decode_output $PATH_TO_DECODE_OUTPUT
如果--beam_width=1 ,则在每个时间步长执行贪婪的解码。
数据参数
--source_vocabulary :源词汇的途径--target_vocabulary :目标词汇的途径--source_train_data :源培训数据的途径--target_train_data :目标训练数据的路径--source_valid_data :通往源验证数据的路径--target_valid_data :目标验证数据的路径网络参数
--cell_type :用于编码器和解码器的RNN单元格(默认:LSTM)--attention_type :注意机制(Bahdanau,Luong),(默认:Bahdanau)--depth :模型中每一层的隐藏单元数量(默认:2)--embedding_size :嵌入编码器和解码器输入的尺寸(默认:500)--num_encoder_symbols :要使用的源词汇大小(默认:30000)--num_decoder_symbols :要使用的目标词汇大小(默认:30000)--use_residual :使用图层之间的残差连接(默认:true)--attn_input_feeding :在注意解码器中使用输入喂养方法(Luong等,2015)(默认:true)--use_dropout :在RNN单元格输出中使用辍学(默认:true)--dropout_rate :单元输出的辍学概率(0.0:no dropout)(默认:0.3)训练参数
--learning_rate :模型中每一层的隐藏单元数(默认:0.0002)--max_gradient_norm :此规范的剪辑梯度(默认1.0)--batch_size :批次大小--max_epochs :最大训练时期--max_load_batches :一次预取的最大批次数。--max_seq_length :最大序列长度--display_freq :显示培训状态每次迭代--save_freq :保存模型检查点,每一次迭代--valid_freq :评估模型每一个迭代:有效_data需要--optimizer :用于培训的优化器:( Adadelta,Adam,RMSProp)(默认:Adam)--model_dir :保存模型检查点的路径--model_name :用于模型检查点的文件名--shuffle_each_epoch :每个时期的洗牌培训数据集(默认:true)--sort_by_length :按其目标序列长度对预取的minibatches进行排序(默认:true)解码参数
--beam_width :BeamSearch中使用的光束宽度(默认:1)--decode_batch_size :用于解码的批处理大小--max_decode_step :解码中的最大时间步长限制(默认:500)--write_n_best :Write BeamSearch n-best列表(n = beam_width)(默认值:false)--decode_input :输入文件路径去解码--decode_output :解码输出的输出文件路径运行时参数
--allow_soft_placement :允许设备软位置--log_device_placement :OPS在设备上的日志放置该实施基于以下项目:
有关任何评论和反馈,请给我发送电子邮件至[email protected]或在此处打开问题。