A Seq2Seq Model Example: ScheduledEmbeddingTrainingHelper

For TensorFlow's latest Seq2Seq examples, see the official NMT tutorial at https://github.com/tensorflow/nmt; they are not repeated here.

A previous post (https://blog.csdn.net/duan_zhihua/article/details/87114665) discussed the mismatch between how a Seq2Seq model is fed during training and during prediction. TensorFlow addresses this with a scheduled-sampling mechanism: during training the decoder samples from its own predictions with a sampling probability that is gradually raised as the epochs progress. With a sampling probability of 0, ScheduledEmbeddingTrainingHelper is equivalent to TrainingHelper (pure teacher forcing); with a probability of 1, it is equivalent to GreedyEmbeddingHelper (the decoder always feeds back its own predictions); for values in between, it randomly mixes ground-truth tokens and predicted tokens according to that probability. In practice, ScheduledEmbeddingTrainingHelper tends to give better results than training without scheduled sampling.

# 0.0 ≤ sampling_probability ≤ 1.0
# 0.0: no sampling => `ScheduledEmbeddingTrainingHelper` is equivalent to `TrainingHelper` (may overfit!)
# 1.0: always sampling => `ScheduledEmbeddingTrainingHelper` is equivalent to `GreedyEmbeddingHelper`
# Increasing sampling over steps => Curriculum Learning
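As a concrete sketch of such a schedule (assuming NumPy is available): the names n_epoch and sampling_probability_list mirror the config fields used in the code further down, while the value 2000 and the linear ramp are only illustrative choices; an inverse-sigmoid or exponential schedule works just as well.

import numpy as np

n_epoch = 2000  # illustrative number of training epochs

# One sampling probability per epoch: epoch 0 is pure teacher forcing (0.0),
# the last epoch always feeds back the decoder's own samples (1.0).
sampling_probability_list = np.linspace(0.0, 1.0, num=n_epoch, dtype=np.float32)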

Seq2SeqModel code (full notebook):

 https://github.com/duanzhihua/tf_tutorial_plus/blob/master/RNN_seq2seq/contrib_seq2seq/02_ScheduledEmbeddingTrainingHelper.ipynb
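The class below also refers to several module-level names defined earlier in the notebook: the vocabulary sizes, the maximum sentence lengths, and the sent2idx/idx2sent helpers together with dec_vocab/dec_reverse_vocab. The placeholder values here are only illustrative, so the tensor shapes in the code are easier to follow; the real numbers come from the notebook's toy corpus.

# Illustrative placeholders -- the notebook derives these from its own data.
enc_vocab_size = 300        # source vocabulary size
dec_vocab_size = 300        # target vocabulary size
enc_sentence_length = 10    # maximum source sentence length (in tokens)
dec_sentence_length = 10    # maximum target sentence length (in tokens)
# sent2idx(), idx2sent(), dec_vocab and dec_reverse_vocab are assumed to be
# provided by the notebook's preprocessing cells.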

import os

import tensorflow as tf
from tensorflow.python.layers.core import Dense  # projection layer for decoder outputs
from tqdm import tqdm


class Seq2SeqModel(object):
    def __init__(self, config, mode='training'):
        assert mode in ['training', 'evaluation', 'inference']
        self.mode = mode

        # Model
        self.hidden_size = config.hidden_size
        self.enc_emb_size = config.enc_emb_size
        self.dec_emb_size = config.dec_emb_size
        self.cell = config.cell
        
        # Training
        self.optimizer = config.optimizer
        self.n_epoch = config.n_epoch
        self.learning_rate = config.learning_rate
        
        # Sampling Probability
        self.sampling_probability_list = config.sampling_probability_list
        
        # Checkpoint path
        self.ckpt_dir = config.ckpt_dir
        
    def add_placeholders(self):
        self.enc_inputs = tf.placeholder(
            tf.int32,
            shape=[None, enc_sentence_length],
            name='input_sentences')

        self.enc_sequence_length = tf.placeholder(
            tf.int32,
            shape=[None,],
            name='input_sequence_length')
        
        if self.mode == 'training':
            self.dec_inputs = tf.placeholder(
                tf.int32,
                shape=[None, dec_sentence_length+1],
                name='target_sentences')

            self.dec_sequence_length = tf.placeholder(
                tf.int32,
                shape=[None,],
                name='target_sequence_length')

            self.sampling_probability = tf.placeholder(
                tf.float32,
                shape=[],
                name='sampling_probability')
            # 0.0 ≤ sampling_probability ≤ 1.0
            # 0.0: no sampling => `ScheduledEmbeddingTrainingHelper` is equivalent to `TrainingHelper`
            # 1.0: always sampling => `ScheduledEmbeddingTrainingHelper` is equivalent to `GreedyEmbeddingHelper`
            # Increasing sampling over steps => Curriculum Learning
            
    def add_encoder(self):
        with tf.variable_scope('Encoder') as scope:
            with tf.device('/cpu:0'):
                self.enc_Wemb = tf.get_variable('embedding',
                    initializer=tf.random_uniform([enc_vocab_size+1, self.enc_emb_size]),
                    dtype=tf.float32)

            # [Batch_size x enc_sent_len x embedding_size]
            enc_emb_inputs = tf.nn.embedding_lookup(
                self.enc_Wemb, self.enc_inputs, name='emb_inputs')
            enc_cell = self.cell(self.hidden_size)

            # enc_outputs: [batch_size x enc_sent_len x hidden_size]
            # enc_last_state: final encoder state, hidden_size units (per state component)
            enc_outputs, self.enc_last_state = tf.nn.dynamic_rnn(
                cell=enc_cell,
                inputs=enc_emb_inputs,
                sequence_length=self.enc_sequence_length,
                time_major=False,
                dtype=tf.float32)
            
    def add_decoder(self):
        with tf.variable_scope('Decoder') as scope:
            with tf.device('/cpu:0'):
                self.dec_Wemb = tf.get_variable('embedding',
                    initializer=tf.random_uniform([dec_vocab_size+2, self.dec_emb_size]),
                    dtype=tf.float32)

            dec_cell = self.cell(self.hidden_size)

            # output projection (replacing `OutputProjectionWrapper`)
            output_layer = Dense(dec_vocab_size+2, name='output_projection')
            
            if self.mode == 'training':

                # maximum number of decoding steps in the current batch = max(dec_sequence_length) + 1 (GO symbol)
                max_dec_len = tf.reduce_max(self.dec_sequence_length+1, name='max_dec_len')

                dec_emb_inputs = tf.nn.embedding_lookup(
                    self.dec_Wemb, self.dec_inputs, name='emb_inputs')

                training_helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
                    inputs=dec_emb_inputs,
                    sequence_length=self.dec_sequence_length+1,
                    embedding=self.dec_Wemb,
                    sampling_probability=self.sampling_probability,
                    time_major=False,
                    name='training_helper')                

                training_decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell=dec_cell,
                    helper=training_helper,
                    initial_state=self.enc_last_state,
                    output_layer=output_layer) 

                # NOTE: in TF >= 1.3, dynamic_decode returns
                # (final_outputs, final_state, final_sequence_lengths)
                train_dec_outputs, train_dec_last_state, _ = tf.contrib.seq2seq.dynamic_decode(
                    training_decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=max_dec_len)
                
                # train_dec_outputs: BasicDecoderOutput(rnn_output, sample_id)
                # train_dec_outputs.rnn_output: [batch_size x max_dec_len x dec_vocab_size+2], tf.float32
                # train_dec_outputs.sample_id:  [batch_size x max_dec_len], tf.int32
                
                # logits: [batch_size x max_dec_len x dec_vocab_size+2]
                logits = tf.identity(train_dec_outputs.rnn_output, name='logits')
                
                # targets: [batch_size x max_dec_len], tf.int32 (token ids sliced from dec_inputs)
                targets = tf.slice(self.dec_inputs, [0, 0], [-1, max_dec_len], 'targets')
                
                # masks: [batch_size x max_dec_len]
                # => ignore outputs beyond `dec_sequence_length+1` when computing the loss
                masks = tf.sequence_mask(self.dec_sequence_length+1, max_dec_len, dtype=tf.float32, name='masks')
                
                # Control loss dimensions with `average_across_timesteps` and `average_across_batch`
                # internal: `tf.nn.sparse_softmax_cross_entropy_with_logits`
                self.batch_loss = tf.contrib.seq2seq.sequence_loss(
                    logits=logits,
                    targets=targets,
                    weights=masks,
                    name='batch_loss')
                
                # prediction sample for validation
                # (ScheduledEmbeddingTrainingHelper writes -1 into sample_id at
                #  non-sampled positions, so take the argmax over the logits instead)
                self.valid_predictions = tf.argmax(logits, axis=2, name='valid_predictions')
                
                # List of training variables
                # self.training_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            
            elif self.mode == 'inference':

                batch_size = tf.shape(self.enc_inputs)[0:1]
                start_tokens = tf.zeros(batch_size, dtype=tf.int32)

                inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                    embedding=self.dec_Wemb,
                    start_tokens=start_tokens,
                    end_token=1)
                
                inference_decoder = tf.contrib.seq2seq.BasicDecoder(
                    cell=dec_cell,
                    helper=inference_helper,
                    initial_state=self.enc_last_state,
                    output_layer=output_layer)
                
                # NOTE: in TF >= 1.3, dynamic_decode returns
                # (final_outputs, final_state, final_sequence_lengths)
                infer_dec_outputs, infer_dec_last_state, _ = tf.contrib.seq2seq.dynamic_decode(
                    inference_decoder,
                    output_time_major=False,
                    impute_finished=True,
                    maximum_iterations=dec_sentence_length)
                
                # [batch_size x dec_sentence_length], tf.int32
                self.predictions = tf.identity(infer_dec_outputs.sample_id, name='predictions')
                # equivalent to tf.argmax(infer_dec_outputs.rnn_output, axis=2, name='predictions')

                # List of training variables
                # self.training_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        
    def add_training_op(self):
        self.training_op = self.optimizer(self.learning_rate, name='training_op').minimize(self.batch_loss)
        
    def save(self, sess, var_list=None, save_path=None):
        print(f'Saving model at {save_path}')
        if hasattr(self, 'training_variables'):
            var_list = self.training_variables
        saver = tf.train.Saver(var_list)
        saver.save(sess, save_path, write_meta_graph=False)
        
    def restore(self, sess, var_list=None, ckpt_path=None):
        if hasattr(self, 'training_variables'):
            var_list = self.training_variables
        self.restorer = tf.train.Saver(var_list)
        self.restorer.restore(sess, ckpt_path)
        print('Restore Finished!')
        
    def summary(self):
        summary_writer = tf.summary.FileWriter(
            logdir=self.ckpt_dir,
            graph=tf.get_default_graph())
        
    def build(self):
        self.add_placeholders()
        self.add_encoder()
        self.add_decoder()
        
    def train(self, sess, data, from_scratch=False, load_ckpt=None, save_path=None):
        
        # Add Optimizer to current graph
        self.add_training_op()

        sess.run(tf.global_variables_initializer())

        # Restore checkpoint (after initialization, so the restored values
        # are not overwritten by the initializer)
        if not from_scratch and load_ckpt is not None and os.path.isfile(load_ckpt):
            self.restore(sess, ckpt_path=load_ckpt)
        
        input_batches, target_batches = data
        loss_history = []
        
        for epoch in tqdm(range(self.n_epoch)):

            all_preds = []
            epoch_loss = 0
            for input_batch, target_batch in zip(input_batches, target_batches):
                input_batch_tokens = []
                target_batch_tokens = []
                input_batch_sent_lens = []
                target_batch_sent_lens = []

                for input_sent in input_batch:
                    tokens, sent_len = sent2idx(input_sent)
                    input_batch_tokens.append(tokens)
                    input_batch_sent_lens.append(sent_len)

                for target_sent in target_batch:
                    tokens, sent_len = sent2idx(target_sent,
                                 vocab=dec_vocab,
                                 max_sentence_length=dec_sentence_length,
                                 is_target=True)
                    target_batch_tokens.append(tokens)
                    target_batch_sent_lens.append(sent_len)
       
                # Evaluate 3 ops in the graph
                # => valid_predictions, loss, training_op (optimizer)
                batch_valid_preds, batch_loss, _ = sess.run(
                    [self.valid_predictions, self.batch_loss, self.training_op],
                    feed_dict={
                        self.enc_inputs: input_batch_tokens,
                        self.enc_sequence_length: input_batch_sent_lens,
                        self.dec_inputs: target_batch_tokens,
                        self.dec_sequence_length: target_batch_sent_lens,
                        self.sampling_probability: self.sampling_probability_list[epoch]
                    }
                )
                # loss_history.append(batch_loss)
                epoch_loss += batch_loss
                all_preds.append(batch_valid_preds)
                
            loss_history.append(epoch_loss)
                        
            # Logging every 400 epochs
            if epoch % 400 == 0:
                print('Epoch', epoch)
                print(f'Sampling probability: {self.sampling_probability_list[epoch]:.3f}')
                for input_batch, target_batch, batch_preds in zip(input_batches, target_batches, all_preds):
                    for input_sent, target_sent, pred in zip(input_batch, target_batch, batch_preds):
                        print(f'\tInput: {input_sent}')
                        print('\tPrediction:', idx2sent(pred, reverse_vocab=dec_reverse_vocab))
                        print(f'\tTarget: {target_sent}\n')
                print(f'\tepoch loss: {epoch_loss:.2f}\n')
                
        if save_path:
            self.save(sess, save_path=save_path)

        return loss_history
    
    def inference(self, sess, data, load_ckpt):

        self.restore(sess, ckpt_path=load_ckpt)
                
        input_batch, target_batch = data
        
        batch_preds = []
        batch_tokens = []
        batch_sent_lens = []

        for input_sent in input_batch:
            tokens, sent_len = sent2idx(input_sent)
            batch_tokens.append(tokens)
            batch_sent_lens.append(sent_len)
            
        batch_preds = sess.run(
            self.predictions,
            feed_dict={
                self.enc_inputs: batch_tokens,
                self.enc_sequence_length: batch_sent_lens,
            })

        for input_sent, target_sent, pred in zip(input_batch, target_batch, batch_preds):
            print('Input:', input_sent)
            print('Prediction:', idx2sent(pred, reverse_vocab=dec_reverse_vocab))
            print('Target:', target_sent, '\n')
