tensorflow实现单层感知机对MNIST分类

版权声明:本文为博主原创文章,转载请注明出处 https://blog.csdn.net/shuzfan/article/details/78535758

所有代码数据可在百度云下载:

链接: https://pan.baidu.com/s/1c31hKLM 密码: 4tpm

所有涉及tensorflow API用法的,均可查看https://tensorflow.google.cn/api_docs/

下面的代码实现了一个单层感知机(Single Layer Perceptron) y=softmax(wx+b),来处理MNIST手写数字识别问题。

# mnist_slp.py
# -*- coding: utf-8 -*-
import tensorflow as tf
from input_data import read_data_sets
import os

# don't show INFO 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

# read mnist
mnist = read_data_sets('MNIST_data', one_hot=True)


# single layer perceptron: y = wx + b
# input
x = tf.placeholder(tf.float32, [None, 784])

# weights
W = tf.Variable(tf.random_normal([784,10], stddev=0.1))

# bias
b = tf.Variable(tf.zeros([10]))

# softmax 
y = tf.nn.softmax(tf.matmul(x,W) + b)

# output
y_ = tf.placeholder(tf.float32, [None, 10])

# cross_entropy loss
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

# optimization with gradient descend, the learning rate is set as 0.02
train_step = tf.train.GradientDescentOptimizer(0.02).minimize(cross_entropy)

# initalize all variables
init = tf.global_variables_initializer()

# start a new session
sess = tf.Session()
sess.run(init)

m_saver = tf.train.Saver()

# 2000 iterations
for i in range(2000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    if i % 100 == 0:
        m_saver.save(sess, './model/mnist_slp', global_step=i)

# computer the accuracy
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

# close the session
sess.close()

上面实现的单层感知机的准确率大概为0.891。

接下来解析一下上述代码中出现的一些tensorflow函数的用法。

tf.placeholder

为一个张量插入占位符,面向一些长期存在的需要被填充的张量,可以提高内存利用效率

# x = tf.placeholder(tf.float32, [None, 784])
placeholder(
    dtype,
    shape=None,
    name=None
)
  • dtype:被填充的张量的数据类型
  • shape:被填充张量的形状(可选参数)。如果没有指定具体值(比如设定为None),则表明可以填充任意形状张量
  • name:为该操作提供一个名字(可选参数)

注意事项:直接对placeholder返回的张量求值会产生错误。 它的值必须在Session.run(), Tensor.eval() 或 Operation.run() 中使用feed_dict来填充。

tf.Variable

声明一个变量,该变量必须被赋初值。

# W = tf.Variable(tf.random_normal([784,10], stddev=0.1))
# b = tf.Variable(tf.zeros([10]))
Variable(
    initializer,
    name=None
)
  • initializer: 初值,既可以是一个张量,也可以是一个返回张量的表达式
  • name:变量名,可以忽略

常见的一些生成张量的方法有:

tf.zeros(shape, dtype=tf.float32, name=None)
tf.zeros_like(tensor, dtype=None, name=None)
tf.constant(value, dtype=None, shape=None, name='Const')
tf.fill(dims, value, name=None)
tf.ones_like(tensor, dtype=None, name=None)
tf.ones(shape, dtype=tf.float32, name=None)

# 生成序列
tf.range(start, limit, delta=1, name='range')
tf.linspace(start, stop, num, name=None)

# 生成随机数
tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)
tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)
tf.random_uniform(shape, minval=0.0, maxval=1.0, dtype=tf.float32, seed=None, name=None)
tf.random_shuffle(value, seed=None, name=None)

tf.reduce_sum

类似的还有tf.reduce_mean,f.reduce_min,f.reduce_max,f.reduce_prod等

以tf.reduce_sum为例,计算张量某个维度上所有元素的和。

# cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
reduce_sum(
    input_tensor,
    axis=None,
    keep_dims=False,
    name=None,
    reduction_indices=None
)
  • input_tensor:输入张量
  • axis:表示在哪个维度进行sum操作,默认为None,表示同时处理所有维度
  • keep_dims:表示是否保留原始数据的维度,默认False,即执行完后数据少一个维度
  • name:为该操作提供一个名字(可选参数)
  • reduction_indices:同axis效果一样,以后会弃用

示例如下:

x = tf.constant([[1, 1, 1], [1, 1, 1]])
tf.reduce_sum(x)  # 6
tf.reduce_sum(x, 0)  # [2, 2, 2]
tf.reduce_sum(x, 1)  # [3, 3]
tf.reduce_sum(x, 1, keep_dims=True)  # [[3], [3]]
tf.reduce_sum(x, [0, 1])  # 6

tf.argmax

类似的还有tf.argmin

tf.argmax用于返回最大值的索引

# correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
argmax(
    input,
    axis=None,
    name=None,
    dimension=None,
    output_type=tf.int64
)

猜你喜欢

转载自blog.csdn.net/shuzfan/article/details/78535758