TF-seq2seq: sequence-to-sequence (seq2seq) learning with TensorFlow

Reader submission, 2022-11-03

TF-seq2seq

Sequence-to-sequence (seq2seq) learning using TensorFlow.

The core building blocks are RNN encoder-decoder architectures and attention mechanisms.
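
As a rough illustration of the attention mechanism, the following pure-Python sketch computes one attention step. Each encoder state is scored against the current decoder state with a dot product (the simplest of the Luong et al. scoring functions), the scores are normalized with a softmax, and the context vector is the weighted sum of the encoder states. The function names and toy vectors are hypothetical; this is not the package's implementation, which builds on tf.contrib.seq2seq.

```python
import math
from typing import List, Tuple

Vector = List[float]


def softmax(xs: Vector) -> Vector:
    """Numerically stable softmax."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot_attention(decoder_state: Vector, encoder_states: List[Vector]) -> Tuple[Vector, Vector]:
    """One attention step with dot-product scoring (hypothetical helper).

    Returns (weights, context): the attention weights over source positions
    and the context vector, i.e. the weighted sum of the encoder states.
    """
    # Alignment score for each source position: score_i = h_t . h_s_i
    scores = [sum(d * e for d, e in zip(decoder_state, h)) for h in encoder_states]
    weights = softmax(scores)
    dim = len(encoder_states[0])
    context = [sum(w * h[k] for w, h in zip(weights, encoder_states)) for k in range(dim)]
    return weights, context


if __name__ == "__main__":
    enc = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three toy encoder states
    dec = [2.0, 0.0]                            # toy decoder state
    w, c = dot_attention(dec, enc)
    print([round(x, 3) for x in w], [round(x, 3) for x in c])
```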

The package was largely implemented using the tf.contrib.seq2seq modules of TensorFlow 1.2 (the latest release at the time of writing):

- AttentionWrapper
- Decoder
- BasicDecoder
- BeamSearchDecoder

The package supports:

- Multi-layer GRU/LSTM
- Residual connections
- Dropout
- Attention and input feeding
- Beam search decoding
- Writing n-best lists
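
Residual connections add a layer's input to its output, which helps gradients flow through deep stacks. A minimal sketch follows, using plain Python functions as stand-ins for RNN layers. Where exactly the package inserts the skip connections is not documented here, so skipping the first layer (whose input width may differ from the hidden size) is an assumption of this sketch, as are the helper names.

```python
from typing import Callable, List

Vector = List[float]
Layer = Callable[[Vector], Vector]


def stacked_forward(x: Vector, layers: List[Layer], use_residual: bool = True) -> Vector:
    """Run x through a stack of layers (hypothetical helper).

    With use_residual=True, every layer after the first computes
    layer(h) + h. Exempting the first layer is an assumption.
    """
    h = x
    for i, layer in enumerate(layers):
        out = layer(h)
        if use_residual and i > 0:
            if len(out) != len(h):
                raise ValueError("residual connection needs equal input/output widths")
            # Element-wise skip connection.
            out = [o + r for o, r in zip(out, h)]
        h = out
    return h


def double(v: Vector) -> Vector:
    """Toy stand-in for an RNN layer (hypothetical)."""
    return [2.0 * a for a in v]


if __name__ == "__main__":
    print(stacked_forward([1.0, 2.0], [double, double, double]))
    print(stacked_forward([1.0, 2.0], [double, double, double], use_residual=False))
```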

Dependencies

- NumPy >= 1.11.1
- TensorFlow >= 1.2

History

- June 5, 2017: Major update
- June 6, 2017: Supports batched beam search decoding
- June 11, 2017: Separated training and decoding
- June 22, 2017: Supports TensorFlow 1.2 (contrib.rnn -> python.ops.rnn_cell)

Usage Instructions

Data Preparation

To preprocess the raw parallel data in sample_data.src and sample_data.trg, run:

$ cd data/
$ ./preprocess.sh src trg sample_data ${max_seq_len}

The script performs the following preprocessing steps, which are widely used in machine translation (MT):

- Normalizing punctuation
- Tokenizing
- Byte-pair encoding with 30,000 merges (Sennrich et al., 2016)
- Cleaning: removing sequences longer than ${max_seq_len}
- Shuffling
- Building dictionaries
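
Byte-pair encoding learns a fixed number of merge operations from word frequencies. The sketch below follows the algorithm described by Sennrich et al. (2016): every word starts as a sequence of characters plus an end-of-word marker, and the most frequent adjacent symbol pair is merged repeatedly. It is illustrative only; the actual preprocessing relies on the subword-nmt scripts, and the function name, the '</w>' marker, and the tie-breaking rule here are assumptions.

```python
import collections
from typing import Dict, List, Tuple

Symbols = Tuple[str, ...]


def learn_bpe(word_counts: Dict[str, int], num_merges: int) -> List[Tuple[str, str]]:
    """Learn up to num_merges BPE merge operations (hypothetical helper).

    Ties between equally frequent pairs go to the pair seen first.
    """
    # Each word becomes a tuple of characters plus an end-of-word marker.
    vocab: Dict[Symbols, int] = {
        tuple(word) + ("</w>",): count for word, count in word_counts.items()
    }
    merges: List[Tuple[str, str]] = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by word frequency.
        pair_counts: collections.Counter = collections.Counter()
        for symbols, count in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += count
        if not pair_counts:
            break  # every word is already a single symbol
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        # Rewrite every word, replacing occurrences of the best pair.
        new_vocab: Dict[Symbols, int] = {}
        for symbols, count in vocab.items():
            merged: List[str] = []
            i = 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            key = tuple(merged)
            new_vocab[key] = new_vocab.get(key, 0) + count
        vocab = new_vocab
    return merges


if __name__ == "__main__":
    print(learn_bpe({"low": 5, "lower": 2, "newest": 6, "widest": 3}, 3))
```

For these toy counts the first three merges are ('e', 's'), ('es', 't') and ('est', '</w>'); the preprocessing described above uses 30,000 merges instead of 3.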

Training

To train a seq2seq model:

$ python train.py --cell_type 'lstm' \
    --attention_type 'luong' \
    --hidden_units 1024 \
    --depth 2 \
    --embedding_size 500 \
    --num_encoder_symbols 30000 \
    --num_decoder_symbols 30000 ...

Decoding

To decode with a trained model:

$ python decode.py --beam_width 5 \
    --decode_batch_size 30 \
    --model_path $PATH_TO_A_MODEL_CHECKPOINT \
    --max_decode_step 300 \
    --write_n_best False \
    --decode_input $PATH_TO_DECODE_INPUT \
    --decode_output $PATH_TO_DECODE_OUTPUT

Here $PATH_TO_A_MODEL_CHECKPOINT is the path to a model checkpoint (e.g. model/translate.ckpt-100).

If --beam_width=1, greedy decoding is performed: at each time step, only the single most probable token is kept.
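
To illustrate the difference between greedy and beam search decoding, here is a minimal beam search over a hypothetical step function that returns next-token log-probabilities for a given prefix. With beam_width=1 it keeps only the best continuation at each time step, which is greedy decoding. This is a simplified sketch, not tf.contrib.seq2seq.BeamSearchDecoder: finished hypotheses leave the beam, there is no length normalization, and the toy model, token names and "</s>" end marker are all assumptions.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple

Hypothesis = Tuple[List[str], float]
StepFn = Callable[[Sequence[str]], Dict[str, float]]


def beam_search(step_fn: StepFn, beam_width: int, max_steps: int,
                eos: str = "</s>") -> List[Hypothesis]:
    """Return up to beam_width hypotheses as (tokens, log_prob), best first."""
    beams: List[Hypothesis] = [([], 0.0)]
    finished: List[Hypothesis] = []
    for _ in range(max_steps):
        # Expand every live hypothesis by every possible next token.
        candidates: List[Hypothesis] = []
        for tokens, score in beams:
            for tok, logp in step_fn(tokens).items():
                candidates.append((tokens + [tok], score + logp))
        # Keep the beam_width highest-scoring candidates.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_width]:
            if tokens[-1] == eos:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    # Hypotheses still open after max_steps are returned unfinished.
    finished.extend(beams)
    finished.sort(key=lambda c: c[1], reverse=True)
    return finished[:beam_width]


# Hypothetical toy model: next-token log-probabilities keyed by prefix.
TOY_MODEL: Dict[Tuple[str, ...], Dict[str, float]] = {
    (): {"A": math.log(0.6), "B": math.log(0.4)},
    ("A",): {"</s>": math.log(0.4), "z": math.log(0.35), "w": math.log(0.25)},
    ("B",): {"</s>": math.log(0.9), "z": math.log(0.1)},
}


def toy_step(prefix: Sequence[str]) -> Dict[str, float]:
    # Any prefix not listed ends deterministically.
    return TOY_MODEL.get(tuple(prefix), {"</s>": 0.0})


if __name__ == "__main__":
    print(beam_search(toy_step, beam_width=1, max_steps=5))
    print(beam_search(toy_step, beam_width=2, max_steps=5))
```

On this toy model, greedy decoding commits to "A" (probability 0.6) and ends with joint probability 0.24, whereas a beam of width 2 also keeps "B" and finds "B </s>" with probability 0.36.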

Arguments

Data params

--source_vocabulary : Path to source vocabulary
--target_vocabulary : Path to target vocabulary
--source_train_data : Path to source training data
--target_train_data : Path to target training data
--source_valid_data : Path to source validation data
--target_valid_data : Path to target validation data

Network params

--cell_type : RNN cell to use for the encoder and decoder (default: lstm)
--attention_type : Attention mechanism (bahdanau, luong) (default: bahdanau)
--hidden_units : Number of hidden units in each layer
--depth : Number of layers in the encoder and decoder (default: 2)
--embedding_size : Embedding dimensions of encoder and decoder inputs (default: 500)
--num_encoder_symbols : Source vocabulary size to use (default: 30000)
--num_decoder_symbols : Target vocabulary size to use (default: 30000)
--use_residual : Use residual connections between layers (default: True)
--attn_input_feeding : Use the input-feeding method in the attentional decoder (Luong et al., 2015) (default: True)
--use_dropout : Apply dropout to RNN cell outputs (default: True)
--dropout_rate : Dropout probability for cell outputs (0.0: no dropout) (default: 0.3)
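
--dropout_rate is the probability of dropping a cell output, so the keep probability is 1 - dropout_rate. The sketch below shows inverted dropout on a vector of outputs: each unit is zeroed with probability dropout_rate and survivors are scaled by 1 / keep_prob, which keeps the expected value unchanged. That the package maps --dropout_rate to a keep probability in exactly this way is an assumption, and the function is a hypothetical helper, not its code.

```python
import random
from typing import List, Optional


def dropout(outputs: List[float], dropout_rate: float, training: bool = True,
            rng: Optional[random.Random] = None) -> List[float]:
    """Inverted dropout on cell outputs (hypothetical helper).

    At inference time (training=False) or with dropout_rate=0.0 the
    outputs pass through unchanged.
    """
    if not 0.0 <= dropout_rate < 1.0:
        raise ValueError("dropout_rate must be in [0, 1)")
    if not training or dropout_rate == 0.0:
        return list(outputs)
    rng = rng or random.Random()
    keep_prob = 1.0 - dropout_rate
    # Keep each unit with probability keep_prob, rescaling the survivors.
    return [x / keep_prob if rng.random() < keep_prob else 0.0 for x in outputs]


if __name__ == "__main__":
    print(dropout([1.0, 2.0, 3.0, 4.0], 0.3, rng=random.Random(0)))
```

Scaling at training time means nothing needs to change at inference; tf.nn.dropout uses the same inverted form.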

Training params

--learning_rate : Learning rate (default: 0.0002)
--max_gradient_norm : Clip gradients to this norm (default: 1.0)
--batch_size : Batch size
--max_epochs : Maximum number of training epochs
--max_load_batches : Maximum number of batches to prefetch at one time
--max_seq_length : Maximum sequence length
--display_freq : Display training status every this many iterations
--save_freq : Save a model checkpoint every this many iterations
--valid_freq : Evaluate the model every this many iterations (requires validation data)
--optimizer : Optimizer for training (adadelta, adam, rmsprop) (default: adam)
--model_dir : Path to save model checkpoints
--model_name : File name used for model checkpoints
--shuffle_each_epoch : Shuffle the training dataset every epoch (default: True)
--sort_by_length : Sort prefetched minibatches by target sequence length (default: True)
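
--max_gradient_norm bounds the size of each update. A common way to do this is global-norm clipping, as in tf.clip_by_global_norm: compute the L2 norm over all gradient tensors together and, if it exceeds the threshold, scale every tensor by max_norm / global_norm. Whether the package clips by global norm or per tensor is an assumption here; the sketch below is a hypothetical pure-Python version of the global-norm rule.

```python
import math
from typing import List, Tuple

Tensor = List[float]


def clip_by_global_norm(grads: List[Tensor], max_norm: float) -> Tuple[List[Tensor], float]:
    """Scale all gradients jointly so their global L2 norm is <= max_norm.

    Applies g_i * max_norm / max(global_norm, max_norm) to every element.
    Returns (clipped_grads, global_norm_before_clipping).
    """
    if max_norm <= 0:
        raise ValueError("max_norm must be positive")
    global_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    # scale == 1.0 when the gradients are already within the bound.
    scale = max_norm / max(global_norm, max_norm)
    clipped = [[g * scale for g in tensor] for tensor in grads]
    return clipped, global_norm


if __name__ == "__main__":
    print(clip_by_global_norm([[3.0], [4.0]], 1.0))
```

Using one shared factor for all tensors preserves the direction of the overall gradient, unlike clipping each tensor separately.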

Decoding params

--beam_width : Beam width used in beam search (default: 1)
--decode_batch_size : Batch size used in decoding
--max_decode_step : Maximum number of time steps in decoding (default: 500)
--write_n_best : Write the beam search n-best list (n = beam_width) (default: False)
--decode_input : Path of the input file to decode
--decode_output : Path of the decoding output file

Runtime params

--allow_soft_placement : Allow device soft placement
--log_device_placement : Log placement of ops on devices

Acknowledgements

The implementation is based on the following projects:

- nematus: Theano implementation of neural machine translation; the major reference for this project
- subword-nmt: subword-unit scripts, included to preprocess input data
- moses: preprocessing scripts, included to preprocess input data
- tf.seq2seq_legacy: legacy TensorFlow seq2seq tutorial
- tf_tutorial_plus: tutorials for the tf.contrib.seq2seq API

For comments and feedback, please email pjh0308@gmail.com or open an issue on the project's repository.
