Los bloques de construcción principales son las arquitecturas de decodificador de codificadores RNN y el mecanismo de atención.
El paquete se implementó en gran medida utilizando los últimos módulos (1.2) tf.contrib.seq2seq
El paquete es compatible
Para preprocesar datos paralelos sin procesar de sample_data.src y sample_data.trg , simplemente ejecute
cd data /
. / preprocess . sh src trg sample_data $ { max_seq_len }La ejecución del código anterior realiza pasos de preprocesamiento ampliamente utilizados para la traducción automática (MT).
Para entrenar un modelo SEQ2SEQ,
$ python train . py -- cell_type 'lstm'
-- attention_type 'luong'
-- hidden_units 1024
-- depth 2
-- embedding_size 500
-- num_encoder_symbols 30000
-- num_decoder_symbols 30000 ...Para ejecutar el modelo entrenado para decodificar,
$ python decode . py -- beam_width 5
-- decode_batch_size 30
-- model_path $PATH_TO_A_MODEL_CHECKPOINT ( e . g . model / translate . ckpt - 100 )
-- max_decode_step 300
-- write_n_best False
-- decode_input $PATH_TO_DECODE_INPUT
-- decode_output $PATH_TO_DECODE_OUTPUT
Si --beam_width=1 , la decodificación codiciosa se realiza en cada paso de tiempo.
Parámetros de datos
--source_vocabulary : Vocabulario de ruta al origen--target_vocabulary : Vocabulario de ruta a Target--source_train_data : datos de entrenamiento de ruta a fuente--target_train_data : datos de entrenamiento de ruta a objetivo--source_valid_data : datos de validación de la ruta a la fuente--target_valid_data : datos de validación de ruta a destinoParámetros de red
--cell_type : Cell RNN para usar para codificador y decodificador (predeterminado: LSTM)--attention_type : Mecanismo de atención (Bahdanau, Luong), (predeterminado: Bahdanau)--depth : número de unidades ocultas para cada capa en el modelo (predeterminado: 2)--embedding_size : incrustación de dimensiones de las entradas del codificador y del decodificador (predeterminado: 500)--num_encoder_symbols : tamaño de vocabulario de origen para usar (predeterminado: 30000)--num_decoder_symbols : tamaño de vocabulario objetivo para usar (predeterminado: 30000)--use_residual : use conexión residual entre capas (predeterminada: verdadero)--attn_input_feeding : Use el método de alimentación de entrada en el decodificador de atención (Luong et al., 2015) (predeterminado: Verdadero)--use_dropout : use la salida en la salida de la celda RNN (predeterminado: verdadero)--dropout_rate : probabilidad de abandono para salidas de celdas (0.0: sin abandono) (predeterminado: 0.3)Parámetros de entrenamiento
--learning_rate : número de unidades ocultas para cada capa en el modelo (predeterminado: 0.0002)--max_gradient_norm : clip gradientes a esta norma (predeterminado 1.0)--batch_size : tamaño por lotes--max_epochs : épocas de entrenamiento máximas--max_load_batches : número máximo de lotes para complacer a la vez.--max_seq_length : longitud de secuencia máxima--display_freq : muestre el estado de capacitación cada iteración--save_freq : Guardar el punto de control del modelo cada esta iteración--valid_freq : evalúe el modelo cada iteración: válido_data necesario--optimizer : optimizador para el entrenamiento: (Adadelta, Adam, RMSProp) (predeterminado: Adam)--model_dir : ruta para guardar los puntos de control del modelo--model_name : nombre de archivo utilizado para los puntos de control del modelo--shuffle_each_epoch : conjunto de datos de entrenamiento de shuffle para cada época (predeterminado: verdadero)--sort_by_length : ordene minibatchs prefabistados por sus longitudes de secuencia de destino (predeterminado: verdadero)Parámetros de decodificación
--beam_width : ancho del haz utilizado en BeamSearch (predeterminado: 1)--decode_batch_size : tamaño por lotes utilizado en la decodificación--max_decode_step : límite de paso de tiempo máximo en la decodificación (predeterminado: 500)--write_n_best : escribir beamsearch n-best list (n = beam_width) (predeterminado: falso)--decode_input : ruta de archivo de entrada para decodificar--decode_output : ruta de archivo de salida de la salida de decodificaciónParámetros de tiempo de ejecución
--allow_soft_placement : Permitir la colocación suave del dispositivo--log_device_placement : colocación de registro de OPS en dispositivos La implementación se basa en los siguientes proyectos:
Para obtener comentarios y comentarios, envíeme un correo electrónico a [email protected] o abra un problema aquí.