TensorFlow(四)——MNIST分类之CNN

import input_data
import tensorflow as tf
import numpy as np

# Read the MNIST splits from 'data/' with one-hot encoded labels
# (loading details live in the local input_data module).
mnist = input_data.read_data_sets('data/', one_hot=True)

# Flat image rows are reshaped to NHWC: -1 lets the number of images be
# inferred, then 28x28 pixels with a single channel.
trX = mnist.train.images.reshape(-1, 28, 28, 1)
trY = mnist.train.labels
teX = mnist.test.images.reshape(-1, 28, 28, 1)
teY = mnist.test.labels

# Graph inputs: a batch of 28x28x1 images and their 10-way one-hot labels.
X = tf.placeholder("float", shape=[None, 28, 28, 1])
Y = tf.placeholder("float", shape=[None, 10])

#3卷积层,3池化层,1全连接层,1输出层
def init_weights(shape, stddev=0.01):
    """Create a trainable weight variable drawn from a zero-mean normal.

    Args:
        shape: shape of the weight tensor, e.g. [3, 3, 1, 32] for a conv
            kernel or [625, 10] for a dense layer.
        stddev: standard deviation of the normal initializer. The default
            0.01 is the value this script has always used, so existing
            callers are unaffected.

    Returns:
        A ``tf.Variable`` initialised with
        ``tf.random_normal(shape, stddev=stddev)``.
    """
    return tf.Variable(tf.random_normal(shape, stddev=stddev))
# Convolution kernels use the conv2d filter layout
# [height, width, in_channels, out_channels].
w = init_weights([3, 3, 1, 32])     # conv1: 1 -> 32 channels
w2 = init_weights([3, 3, 32, 64])   # conv2: 32 -> 64 channels
w3 = init_weights([3, 3, 64, 128])  # conv3: 64 -> 128 channels
# Dense layer input: three 2x2 / stride-2 SAME max-pools shrink the
# 28x28 image to 14x14 -> 7x7 -> 4x4, leaving 4 * 4 * 128 features.
w4 = init_weights([4 * 4 * 128, 625])
w_o = init_weights([625, 10])       # output: 625 hidden units -> 10 classes

def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
    """Build the CNN forward pass and return the unnormalised logits.

    Architecture: three stages of conv (stride 1, SAME) + ReLU +
    2x2 / stride-2 SAME max-pool, then a ReLU fully connected layer and a
    linear output layer. Dropout follows every stage and the hidden layer.

    Args:
        X: input image batch (this script feeds [None, 28, 28, 1]).
        w, w2, w3: convolution kernels for the three conv stages.
        w4: fully connected weights; its first dimension is the flattened
            size the third pooled feature map is reshaped to.
        w_o: output-layer weights.
        p_keep_conv: dropout keep probability after the conv stages.
        p_keep_hidden: dropout keep probability after the hidden layer.

    Returns:
        Logits tensor of shape [batch, w_o.shape[1]] (no softmax applied).
    """
    def conv_relu_pool(inputs, kernel):
        # conv (stride 1, SAME) -> ReLU -> 2x2 max-pool with stride 2.
        activated = tf.nn.relu(
            tf.nn.conv2d(inputs, kernel, strides=[1, 1, 1, 1], padding='SAME'))
        return tf.nn.max_pool(activated, ksize=[1, 2, 2, 1],
                              strides=[1, 2, 2, 1], padding='SAME')

    stage1 = tf.nn.dropout(conv_relu_pool(X, w), p_keep_conv)
    stage2 = tf.nn.dropout(conv_relu_pool(stage1, w2), p_keep_conv)

    # The third pooled map is flattened *before* dropout so it can feed
    # the dense layer directly.
    pooled3 = conv_relu_pool(stage2, w3)
    flat = tf.reshape(pooled3, [-1, w4.get_shape().as_list()[0]])
    flat = tf.nn.dropout(flat, p_keep_conv)

    # Fully connected hidden layer, then the linear output layer.
    hidden = tf.nn.dropout(tf.nn.relu(tf.matmul(flat, w4)), p_keep_hidden)
    return tf.matmul(hidden, w_o)

# Dropout keep probabilities are fed at run time: below 1.0 while
# training, 1.0 when evaluating.
p_keep_conv = tf.placeholder('float')
p_keep_hidden = tf.placeholder('float')
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden)

# Loss: mean softmax cross-entropy between the logits and one-hot labels.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)

batch_size = 128
test_size = 256
num_epochs = 100

with tf.Session() as sess:
    tf.global_variables_initializer().run()

    for i in range(num_epochs):
        # Visit the training set in a fresh random order each epoch and
        # include the final partial batch, so every example is used.
        order = np.random.permutation(len(trX))
        for start in range(0, len(trX), batch_size):
            idx = order[start:start + batch_size]
            sess.run(train_op, feed_dict={X: trX[idx], Y: trY[idx],
                                          p_keep_conv: 0.8, p_keep_hidden: 0.5})

        # Quick progress check on test_size randomly sampled test images
        # (dropout disabled). This is a noisy estimate, not full accuracy.
        test_indices = np.random.permutation(len(teX))[:test_size]
        print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                         sess.run(predict_op, feed_dict={X: teX[test_indices],
                                                         p_keep_conv: 1.0,
                                                         p_keep_hidden: 1.0})))

    # Final accuracy over the entire test set, evaluated in test_size
    # chunks to bound the memory used per sess.run call.
    correct = 0
    for start in range(0, len(teX), test_size):
        stop = start + test_size
        preds = sess.run(predict_op, feed_dict={X: teX[start:stop],
                                                p_keep_conv: 1.0,
                                                p_keep_hidden: 1.0})
        correct += np.sum(preds == np.argmax(teY[start:stop], axis=1))
    print('test accuracy:', correct / len(teX))

结果(每行为训练轮次编号,以及该轮结束后在随机抽取的 256 个测试样本上的准确率):

0 0.93359375
1 0.9765625
2 0.9765625
3 0.9921875
4 0.984375
5 0.9921875
6 1.0
7 0.98828125
8 0.984375
9 0.98046875
10 0.99609375
11 0.98828125
12 1.0
13 0.98828125
14 0.9921875
15 0.99609375
16 0.9921875
17 0.99609375
18 0.99609375
19 0.98828125
20 0.9921875
21 1.0
22 0.9921875
23 0.9921875
24 0.98828125
25 0.99609375
26 0.99609375
27 0.98046875
28 0.98828125
29 0.9921875
30 1.0
31 1.0
32 0.99609375
33 0.98828125
34 0.984375
35 1.0
36 0.984375
37 0.99609375
38 1.0
39 0.99609375
40 0.9921875
41 0.97265625
42 1.0
43 0.9921875
44 0.99609375
45 0.984375
46 1.0
47 1.0
48 0.98828125
49 0.9765625
50 0.9921875
51 1.0
52 0.98828125
53 0.98828125
54 0.9921875
55 0.99609375
56 1.0
57 0.99609375
58 1.0
59 0.9921875
60 0.99609375
61 0.98828125
62 0.9921875
63 0.9921875
64 0.9921875
65 0.98828125
66 0.99609375
67 0.99609375
68 0.984375
69 1.0
70 0.98828125
71 0.98828125
72 0.99609375
73 1.0
74 1.0
75 0.99609375
76 0.98828125
77 0.9921875
78 0.98828125
79 0.9921875
80 1.0
81 0.99609375
82 0.99609375
83 0.98828125
84 0.984375
85 0.98828125
86 0.99609375
87 0.99609375
88 0.9921875
89 0.99609375
90 0.98828125
91 0.99609375
92 1.0
93 0.9921875
94 0.98046875
95 1.0
96 0.99609375
97 0.984375
98 0.9921875
99 0.99609375

猜你喜欢

转载自blog.csdn.net/MRxjh/article/details/82660843