实现一个真实的动态RNN

提示

如果代码中出现了你不懂的接口,请翻看本人博客中名为“tensorflow学习”的分类,里面一定有相应的说明。这里只列出了最主要的两个接口的说明,其他不懂的接口请自行查找。

tf.nn.dynamic_rnn详解

tf.gather,tf.range的详解

代码

代码来自这里,我对其中的一些接口进行了升级和改进

#!/usr/bin/env python
# coding: utf-8
from __future__ import print_function
import tensorflow as tf
import random    


class ToySequenceData:
    """Toy dataset of variable-length scalar sequences for a two-class task.

    Each sample is a list of ``max_seq_len`` one-element lists.  The first
    ``seqlen[k]`` steps carry values scaled by ``max_value``; the remaining
    steps are ``[0.0]`` padding so every sample has the same length.

    Two kinds of sample are generated with equal probability:

    * a linear sequence (consecutive integers), labelled ``[1.0, 0.0]``;
    * a sequence of independent random integers, labelled ``[0.0, 1.0]``.

    Attributes:
        data: list of padded sequences.
        labels: list of one-hot labels, parallel to ``data``.
        seqlen: list of the real (unpadded) length of each sequence.
        batch_id: index of the first sample of the next batch.
    """

    def __init__(self, n_samples=1000, max_seq_len=20, min_seq_len=3, max_value=1000):
        """Generate ``n_samples`` random samples.

        Args:
            n_samples: number of sequences to generate.
            max_seq_len: longest real length; also the padded length.
            min_seq_len: shortest real length.
            max_value: upper bound of the raw integer values and the divisor
                used to scale them.
        """
        self.data = []
        self.labels = []
        self.seqlen = []

        # Divide by a float so the scaling is a true division even under
        # Python 2, where int / int would truncate every feature to 0 or 1.
        scale = float(max_value)

        for _ in range(n_samples):
            length = random.randint(min_seq_len, max_seq_len)
            self.seqlen.append(length)
            if random.random() < 0.5:
                # Linear sequence: consecutive integers from a random start.
                start = random.randint(0, max_value - length)
                seq = [[value / scale] for value in range(start, start + length)]
                label = [1.0, 0.0]
            else:
                # Random sequence: every step is drawn independently.
                seq = [[random.randint(0, max_value) / scale] for _ in range(length)]
                label = [0.0, 1.0]
            # Pad with zeros up to the maximum length.
            seq += [[0.0] for _ in range(max_seq_len - length)]
            self.data.append(seq)
            self.labels.append(label)
        self.batch_id = 0

    def next(self, batch_size):
        """Return the next batch as ``(data, labels, seqlen)``.

        The cursor wraps to the start once the dataset is exhausted.  The
        last batch before a wrap may hold fewer than ``batch_size`` samples.
        """
        total = len(self.data)
        if self.batch_id >= total:
            self.batch_id = 0
        start = self.batch_id
        stop = min(start + batch_size, total)
        self.batch_id = stop
        return self.data[start:stop], self.labels[start:stop], self.seqlen[start:stop]
            


# Training hyper-parameters
learning_rate = 0.01  # step size handed to the gradient-descent optimizer
training_steps = 10000  # number of optimizer steps run by the training loop
batch_size = 128  # samples requested from the training set per step
display_step = 200  # print loss/accuracy every this many steps

# Network parameters
seq_max_len = 20 # padded length of every input sequence
n_hidden = 64 # number of units in the LSTM cell
n_classes = 2 # two classes: linear sequence or random sequence

# Build the training set first, then the test set; both draw from the
# module-level `random` generator, so the order affects the generated data.
trainset = ToySequenceData(n_samples=1000, max_seq_len=seq_max_len)
testset = ToySequenceData(n_samples=500, max_seq_len=seq_max_len)

# Graph inputs:
#   X      - padded sequences, one scalar feature per step
#   Y      - one-hot class labels
#   seqlen - real (unpadded) length of each sequence in the batch
X=tf.placeholder("float",[None,seq_max_len,1])
Y=tf.placeholder("float",[None,n_classes])
seqlen=tf.placeholder(tf.int32,[None])


def dynamicRNN(x, seqlen):
    """Build the LSTM classifier and return its logits.

    Runs an LSTM over ``x`` with ``tf.nn.dynamic_rnn``, passing ``seqlen`` as
    ``sequence_length``.  It then picks, for every sequence, the output at
    its last real step (index ``seqlen - 1``) and projects it to
    ``n_classes`` logits with a dense layer.

    Reads the module-level ``n_hidden``, ``seq_max_len`` and ``n_classes``.
    """
    cell = tf.nn.rnn_cell.LSTMCell(n_hidden)
    # rnn_outputs is treated as [batch, seq_max_len, n_hidden] below.
    rnn_outputs, _ = tf.nn.dynamic_rnn(
        cell, x, sequence_length=seqlen, dtype=tf.float32)
    n_rows = tf.shape(rnn_outputs)[0]

    # After flattening to [batch * seq_max_len, n_hidden], sequence k starts
    # at row k * seq_max_len, so its last real step is at that offset plus
    # (seqlen[k] - 1).
    row_offsets = tf.range(0, n_rows) * seq_max_len
    last_step_rows = row_offsets + (seqlen - 1)
    flat_outputs = tf.reshape(rnn_outputs, [-1, n_hidden])
    last_outputs = tf.gather(flat_outputs, last_step_rows)
    return tf.layers.dense(last_outputs, n_classes, use_bias=True)
    
    
  

# Logits for the current batch.
pred = dynamicRNN(X, seqlen)

# Loss: mean softmax cross-entropy between the logits and the one-hot labels,
# minimized with plain gradient descent.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)

# Accuracy: fraction of samples whose arg-max logit matches the arg-max label.
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Op that initializes all variables created above.
init = tf.global_variables_initializer()

# Train, then evaluate on the held-out set.
with tf.Session() as sess:

    sess.run(init)

    for step in range(1, training_steps + 1):
        batch_x, batch_y, batch_seqlen = trainset.next(batch_size)
        # The same feed is used for the update and for the progress report.
        train_feed = {X: batch_x, Y: batch_y, seqlen: batch_seqlen}
        # Run one optimization step (backprop).
        sess.run(optimizer, feed_dict=train_feed)
        if step % display_step == 0 or step == 1:
            # Report loss and accuracy on the batch just trained on.
            acc, loss = sess.run([accuracy, cost], feed_dict=train_feed)
            print("Step " + str(step) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))

    print("Optimization Finished!")

    # Accuracy on the whole test set in a single run.
    test_data = testset.data
    test_label = testset.labels
    test_seqlen = testset.seqlen
    # Feed the placeholders X and Y; the lowercase names x and y are not
    # defined at module level and raised NameError here.
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={X: test_data, Y: test_label,
                                        seqlen: test_seqlen}))

猜你喜欢

转载自blog.csdn.net/qq_32806793/article/details/85325198