用 LeNet5 训练 CIFAR-10

 1 import os
 2 
 3 import tensorflow as tf
 4 from tensorflow.keras import datasets, layers, Sequential, optimizers
 5 
 6 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 7 tf.random.set_seed(2345)
 8 
 9 conv_layers = [
10     # unit1
11     layers.Conv2D(6, kernel_size=[5, 5], strides=1, padding='valid', activation=tf.nn.sigmoid),
12     layers.MaxPool2D(pool_size=[2, 2], padding='same', strides=2),
13 
14     # unit2
15     layers.Conv2D(16, kernel_size=[5, 5], strides=1, padding='valid', activation=tf.nn.sigmoid),
16     layers.MaxPool2D(pool_size=[2, 2], padding='same', strides=2),
17 ]
18 
19 
def preprocess(x, y):
    """Convert one (image, label) pair into model-ready tensors.

    The image is cast to float32 and divided by 255; the label is cast to int32
    so it can later be one-hot encoded and compared with int32 predictions.
    """
    image = tf.cast(x, dtype=tf.float32) / 255
    label = tf.cast(y, dtype=tf.int32)
    return image, label
26 
27 
def _forward(conv_net, fc_net, x):
    """Run an image batch through the conv feature extractor and the dense head.

    Args:
        conv_net: convolutional ``Sequential`` model (built for [None, 32, 32, 3]).
        fc_net: fully connected ``Sequential`` head (built for [None, 400]).
        x: image batch, expected shape [b, 32, 32, 3].

    Returns:
        Unnormalised class logits of shape [b, 10].
    """
    # [b, 32, 32, 3] -> [b, 5, 5, 16]
    features = conv_net(x)
    # [b, 5, 5, 16] -> [b, 400]
    features = tf.reshape(features, [-1, 400])
    # [b, 400] -> [b, 10]
    return fc_net(features)


def _evaluate(conv_net, fc_net, test_db):
    """Return the top-1 accuracy of the model over every batch of ``test_db``."""
    total_num = 0
    total_correct = 0
    for x, y in test_db:
        logits = _forward(conv_net, fc_net, x)
        # softmax is monotonic, so argmax over the raw logits gives the same
        # prediction without the extra op (and without saturation ties).
        pred = tf.cast(tf.argmax(logits, axis=1), dtype=tf.int32)
        correct = tf.reduce_sum(tf.cast(tf.equal(pred, y), dtype=tf.int32))

        total_num += x.shape[0]
        total_correct += int(correct)
    return total_correct / total_num


def main(epochs=2000):
    """Train a LeNet-5-style CNN on CIFAR-10, printing test accuracy every epoch.

    Args:
        epochs: number of passes over the training set (default 2000, as before).
    """
    (x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
    print("x_train.shape:", x_train.shape, "y_train.shape:", y_train.shape)
    print("x_test.shape:", x_test.shape, 'y_test.shape:', y_test.shape)
    # Labels come with a trailing size-1 axis; squeeze it so each label is a
    # scalar that compares element-wise with the [b] predictions.
    y_test = tf.squeeze(y_test, axis=1)
    y_train = tf.squeeze(y_train, axis=1)
    train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_db = train_db.shuffle(1000).map(preprocess).batch(256)

    test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    test_db = test_db.map(preprocess).batch(256)

    sample_train = next(iter(train_db))
    sample_test = next(iter(test_db))
    print("sample_train[0].shape:", sample_train[0].shape, "sample_train[1].shape:", sample_train[1].shape)
    print("sample_test[0].shape:", sample_test[0].shape, "sample_test[1].shape:", sample_test[1].shape)

    conv_net = Sequential(conv_layers)
    fc_net = Sequential([
        layers.Dense(120, activation=tf.nn.tanh),
        layers.Dense(84, activation=tf.nn.tanh),
        layers.Dense(10, activation=None),
    ])
    conv_net.build(input_shape=[None, 32, 32, 3])
    fc_net.build(input_shape=[None, 400])

    # `learning_rate` is the supported keyword; the legacy `lr` alias is
    # deprecated and rejected by newer Keras releases.
    optimizer = optimizers.Adam(learning_rate=1e-4)
    variables = conv_net.trainable_variables + fc_net.trainable_variables
    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_db):
            with tf.GradientTape() as tape:
                logits = _forward(conv_net, fc_net, x)
                y_onehot = tf.one_hot(y, depth=10)
                loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
                loss = tf.reduce_mean(loss)
            grads = tape.gradient(loss, variables)
            optimizer.apply_gradients(zip(grads, variables))

            if step % 100 == 0:
                print(epoch, step, 'loss:', float(loss))

        acc = _evaluate(conv_net, fc_net, test_db)
        print(epoch, 'acc:', acc)
90 
91 
if __name__ == "__main__":
    # Start the training/evaluation loop only when run as a script,
    # not when this module is imported.
    main()

利用 LeNet5 训练 CIFAR-10 数据集，跑了 2000 个 epoch，测试集准确率只有 0.63，不是很理想，主要是因为 LeNet5 的网络结构过于简单。

猜你喜欢

转载自www.cnblogs.com/bsyu/p/12173391.html
今日推荐