# Case study: recognizing MNIST handwritten digits (案例:识别Mnist手写数字)
import tensorflow as tf
import numpy as np
class data_fetch():
    """Load MNIST once and serve random mini-batches from the training split.

    Attributes:
        train_data, test_data: float32 images scaled to [0, 1], with a trailing
            channel axis appended by ``np.expand_dims(..., axis=-1)``.
        train_label, test_label: int32 class labels.
        train_nums, test_nums: number of samples in each split.
    """

    def __init__(self):
        # load_data() fetches (and caches) MNIST and returns (train, test) tuples.
        (self.train_data, self.train_label), (self.test_data, self.test_label) = (
            tf.keras.datasets.mnist.load_data()
        )
        # Scale pixels to [0, 1] and add a channel axis so the images can be fed
        # straight into Conv2D layers.
        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)
        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)
        self.train_label = self.train_label.astype(np.int32)
        self.test_label = self.test_label.astype(np.int32)
        self.train_nums, self.test_nums = self.train_data.shape[0], self.test_data.shape[0]

    def fetch(self, bach_size):
        """Return a random mini-batch ``(images, labels)`` from the training split.

        Indices are drawn uniformly *with replacement* via ``np.random.randint``,
        so a batch may contain duplicates, and results follow NumPy's global seed.

        Args:
            bach_size: number of samples to draw (spelling kept for existing callers).

        Returns:
            Tuple ``(images, labels)`` whose first axis has length ``bach_size``;
            ``labels[i]`` is the label of ``images[i]``.
        """
        index = np.random.randint(0, self.train_nums, bach_size)
        return self.train_data[index], self.train_label[index]
class Mmodel(tf.keras.Model):
    """Small CNN digit classifier.

    Architecture: two [Conv2D 5x5 + ReLU -> 2x2 max-pool] stages, a 500-unit
    ReLU dense layer and a 10-unit output layer. ``call`` applies softmax, so
    the model returns class probabilities rather than logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(filters=32, kernel_size=[5, 5], padding='same', activation=tf.nn.relu)
        self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        self.conv2 = tf.keras.layers.Conv2D(filters=64, kernel_size=[5, 5], padding='same', activation=tf.nn.relu)
        self.pool2 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        # Flatten infers the feature size; for 28x28 inputs it yields the same
        # 7*7*64 vector as a fixed Reshape, but does not hard-code the image size.
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units=500, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)

    def call(self, input):
        """Run the forward pass.

        Args:
            input: image batch, presumably (batch, H, W, 1) float32 in [0, 1]
                as produced by ``data_fetch`` -- TODO confirm for other callers.

        Returns:
            Tensor of shape (batch, 10) holding softmax probabilities.
        """
        x = self.pool1(self.conv1(input))
        x = self.pool2(self.conv2(x))
        x = self.dense1(self.flatten(x))
        # Softmax here means losses/metrics must use from_logits=False.
        return tf.nn.softmax(self.dense2(x))

    def train(self, *, batch_size=30, learning_rate=0.002, epochs=2, data=None):
        """Train *this* model with plain SGD, then print its test-set accuracy.

        Args:
            batch_size: samples per gradient step.
            learning_rate: SGD learning rate.
            epochs: approximate passes over the training set. The step count is
                ``int(train_nums // batch_size * epochs)``; batches are sampled
                randomly, so not every sample is guaranteed to be seen.
            data: object providing ``fetch(n)``, ``train_nums``, ``test_data``
                and ``test_label`` (e.g. a ``data_fetch``); a new
                ``data_fetch()`` is created when omitted.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        load_ft = data_fetch() if data is None else data
        nums = int(load_ft.train_nums // batch_size * epochs)
        optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
        for _ in range(nums):
            x, y = load_ft.fetch(batch_size)
            with tf.GradientTape() as tape:
                # Use self (not a fresh Mmodel) so the caller's instance learns.
                y_pre = self(x)
                # Outputs are probabilities, so from_logits stays False (default).
                loss = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pre)
                )
            print("损失", loss)
            grads = tape.gradient(loss, self.trainable_variables)
            optimizer.apply_gradients(zip(grads, self.trainable_variables))
        per = self.predict(load_ft.test_data)
        scc = tf.keras.metrics.SparseCategoricalAccuracy()
        scc.update_state(y_true=load_ft.test_label, y_pred=per)
        print("正确率{}".format(scc.result()))
if __name__ == "__main__":
    # Build the CNN and run its training + test-evaluation routine.
    classifier = Mmodel()
    classifier.train()
# Example output of one run: 正确率0.9387999773025513 (test accuracy ≈ 0.9388)