TensorFlow 2.0 in Practice: MNIST


Data Loading

import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, datasets

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

(x, y), (x_val, y_val) = datasets.mnist.load_data()    # train and test
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.   # normalize (broadcasting)
y = tf.convert_to_tensor(y, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
print(x.shape, y.shape)   # (60000, 28, 28) (60000, 10)
train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
train_dataset = train_dataset.batch(200)   # setting batch

# Iterating over the dataset yields 300 batches (60000 samples / batch size 200):
for step, (x, y) in enumerate(train_dataset):
    print(step, x.shape, y.shape)

0 (200, 28, 28) (200, 10)
1 (200, 28, 28) (200, 10)
2 (200, 28, 28) (200, 10)
3 (200, 28, 28) (200, 10)
4 (200, 28, 28) (200, 10)
5 (200, 28, 28) (200, 10)
6 (200, 28, 28) (200, 10)
......
295 (200, 28, 28) (200, 10)
296 (200, 28, 28) (200, 10)
297 (200, 28, 28) (200, 10)
298 (200, 28, 28) (200, 10)
299 (200, 28, 28) (200, 10)
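
The test split (x_val, y_val) is loaded above but not used in this section. A minimal sketch of preparing it the same way (the name val_dataset and reusing a batch size of 200 are illustrative choices):

x_val = tf.convert_to_tensor(x_val, dtype=tf.float32) / 255.   # same normalization as the training set
y_val = tf.one_hot(tf.convert_to_tensor(y_val, dtype=tf.int32), depth=10)
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(200)
print(x_val.shape, y_val.shape)   # (10000, 28, 28) (10000, 10)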

Building a Fully Connected Network

out = relu{ relu{ x @ w_1 + b_1 } @ w_2 + b_2 } @ w_3 + b_3

The last Dense(10) layer has no activation, so the model outputs raw 10-dimensional logits.

model = keras.Sequential([
    layers.Dense(512, activation='relu'),
    layers.Dense(256, activation='relu'),
    layers.Dense(10)])

optimizer = optimizers.SGD(learning_rate=0.001)
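
Each Dense layer corresponds to one w/b pair in the formula. For reference, a minimal sketch of the same forward pass written with explicit variables (w1, b1, ... are illustrative and separate from the Keras model above):

# weights and biases for the three layers: 784 -> 512 -> 256 -> 10
w1 = tf.Variable(tf.random.truncated_normal([784, 512], stddev=0.1))
b1 = tf.Variable(tf.zeros([512]))
w2 = tf.Variable(tf.random.truncated_normal([512, 256], stddev=0.1))
b2 = tf.Variable(tf.zeros([256]))
w3 = tf.Variable(tf.random.truncated_normal([256, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))

def forward(x):                        # x: [b, 784]
    h1 = tf.nn.relu(x @ w1 + b1)       # [b, 512]
    h2 = tf.nn.relu(h1 @ w2 + b2)      # [b, 256]
    return h2 @ w3 + b3                # [b, 10], linear output like Dense(10)

Keras creates and tracks the corresponding variables automatically the first time the Sequential model is called on a batch.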

Computing the Output and the Loss

with tf.GradientTape() as tape:
    # [b, 28, 28] => [b, 784]
    x = tf.reshape(x, (-1, 28 * 28))
    # Step1. compute output
    # [b, 784] => [b, 10]
    out = model(x)
    # Step2. compute loss
    loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0]
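
Dividing the total sum of squares by the batch size gives the per-sample sum of squared errors. For reference, an equivalent sketch using the built-in mean-squared-error helper (which also averages over the 10 output dimensions, hence the factor of 10; for training it would have to be computed inside the tape as well):

mse = tf.reduce_mean(tf.keras.losses.MSE(y, out))   # mean over batch and the 10 outputs
loss = mse * 10.                                    # same value as tf.reduce_sum(tf.square(out - y)) / x.shape[0]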

Computing Gradients and Updating the Weights

# Step3. optimize and update w1, w2, w3, b1, b2, b3
grads = tape.gradient(loss, model.trainable_variables)
# w' = w - lr * grad
optimizer.apply_gradients(zip(grads, model.trainable_variables))
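
apply_gradients performs exactly the update written in the comment. Done by hand for plain SGD, a minimal sketch would be:

lr = 0.001   # same learning rate as the optimizer above
for g, w in zip(grads, model.trainable_variables):
    w.assign_sub(lr * g)   # w' = w - lr * grad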

Training Loop

def train_epoch(epoch):
    # Step4.loop
    for step, (x, y) in enumerate(train_dataset):

        with tf.GradientTape() as tape:
            # [b, 28, 28] => [b, 784]
            x = tf.reshape(x, (-1, 28 * 28))
            # Step1. compute output
            # [b, 784] => [b, 10]
            out = model(x)
            # Step2. compute loss
            loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0]

        # Step3. optimize and update w1, w2, w3, b1, b2, b3
        grads = tape.gradient(loss, model.trainable_variables)
        # w' = w - lr * grad
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if step % 100 == 0:
            print(epoch, step, 'loss:', loss.numpy())

Complete Code

import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, datasets

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

(x, y), (x_val, y_val) = datasets.mnist.load_data()
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(y, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
print(x.shape, y.shape)
train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
train_dataset = train_dataset.batch(200)

model = keras.Sequential([
    layers.Dense(512, activation='relu'),
    layers.Dense(256, activation='relu'),
    layers.Dense(10)])

optimizer = optimizers.SGD(learning_rate=0.001)


def train_epoch(epoch):
    # Step4.loop
    for step, (x, y) in enumerate(train_dataset):

        with tf.GradientTape() as tape:
            # [b, 28, 28] => [b, 784]
            x = tf.reshape(x, (-1, 28 * 28))
            # Step1. compute output
            # [b, 784] => [b, 10]
            out = model(x)
            # Step2. compute loss
            loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0]

        # Step3. optimize and update w1, w2, w3, b1, b2, b3
        grads = tape.gradient(loss, model.trainable_variables)
        # w' = w - lr * grad
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if step % 100 == 0:
            print(epoch, step, 'loss:', loss.numpy())


def train():
    for epoch in range(30):
        train_epoch(epoch)


if __name__ == '__main__':
    train()
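
The script trains for 30 epochs but never evaluates on the test split returned by load_data(). A minimal sketch of measuring test accuracy after training (the evaluate helper is an illustrative addition and assumes x_val and y_val are the raw arrays returned by load_data above):

def evaluate():
    # preprocess the test split exactly like the training data
    xv = tf.convert_to_tensor(x_val, dtype=tf.float32) / 255.
    xv = tf.reshape(xv, (-1, 28 * 28))
    yv = tf.convert_to_tensor(y_val, dtype=tf.int64)   # raw labels are enough for accuracy
    pred = tf.argmax(model(xv), axis=1)                # [10000], predicted digit per image
    acc = tf.reduce_mean(tf.cast(tf.equal(pred, yv), tf.float32))
    print('test accuracy:', acc.numpy())

Calling evaluate() after train() in the __main__ block prints the fraction of correctly classified test digits.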
