Fitting a sine function with a two-layer fully connected neural network in TensorFlow

Copyright notice: please credit the source when reposting: https://blog.csdn.net/The_lastest/article/details/82848257

You've been doing deep learning for this long, so isn't it about time you tried fitting sin(x)? After all, if a network can't even fit sin(x), what's the point of deep learning?

1. Network structure

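The network is very simple: a single input x, one fully connected hidden layer with ReLU activation, and a single linear output, as shown in the figure below: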

[Figure: network structure]
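Judging from the code in section 3, the network maps one input value through a ReLU hidden layer to one output value, i.e. y_hat = relu(x · W1 + b1) · W2 + b2. The NumPy sketch below traces that forward pass and its array shapes. It assumes a hidden width of 10 (the post does not state HIDDEN_NODE) and uses a plain normal initialisation with stddev 0.1 instead of TensorFlow's truncated normal; the function name `forward` is hypothetical.

```python
import numpy as np

# HIDDEN_NODE = 10 is an assumed value; the post does not give it.
INPUT_NODE, HIDDEN_NODE, OUTPUT_NODE = 1, 10, 1

def forward(x, w1, b1, w2, b2):
    """Hypothetical helper: forward pass of the 1 -> HIDDEN_NODE -> 1 network."""
    hidden = np.maximum(x @ w1 + b1, 0.0)  # layer 1: affine + ReLU, shape (N, HIDDEN_NODE)
    return hidden @ w2 + b2                # layer 2: affine only (linear regression output), shape (N, 1)

rng = np.random.default_rng(0)
w1 = rng.normal(0.0, 0.1, size=(INPUT_NODE, HIDDEN_NODE))   # plain normal, not truncated
b1 = np.zeros(HIDDEN_NODE)
w2 = rng.normal(0.0, 0.1, size=(HIDDEN_NODE, OUTPUT_NODE))
b2 = np.zeros(OUTPUT_NODE)

x = np.linspace(-np.pi, np.pi, 100).reshape(-1, 1)  # column vector, one feature per row
y_hat = forward(x, w1, b1, w2, b2)
print(y_hat.shape)  # (100, 1)
```

The hidden layer is the only nonlinearity; the output layer stays linear because sin(x) takes negative values that a ReLU output could not produce.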

2. Building the dataset

To build the dataset, take x values over a range (here 100 evenly spaced points on [-π, π]) and plug them into np.sin(x) to get the true values y.

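import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x graph API (placeholder, Session); works on TF 1.15 and TF 2.x
tf.disable_v2_behavior()
import matplotlib.pyplot as plt

def gen_data():
    x = np.linspace(-np.pi, np.pi, 100)  # 100 evenly spaced points on [-pi, pi]
    x = np.reshape(x, (len(x), 1))       # column vector, shape (100, 1)
    y = np.sin(x)                        # ground-truth targets
    return x, y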
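The gen_data function uses evenly spaced points from np.linspace rather than random ones. If randomly sampled training points are wanted instead, a minimal sketch could look like the following; the function name `gen_random_data` and its `n` and `seed` parameters are hypothetical.

```python
import numpy as np

def gen_random_data(n=100, seed=0):
    """Hypothetical variant of gen_data: n points drawn uniformly from [-pi, pi)."""
    rng = np.random.default_rng(seed)
    x = np.sort(rng.uniform(-np.pi, np.pi, size=n))  # sorted so ax.plot draws a clean curve
    x = x.reshape(-1, 1)                             # column vector, shape (n, 1), matching INPUT_NODE = 1
    y = np.sin(x)                                    # ground-truth targets
    return x, y

x, y = gen_random_data()
print(x.shape, y.shape)  # (100, 1) (100, 1)
```

Sorting does not matter for training, but it keeps the red reference curve and the black prediction curve readable when they are plotted as lines.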

3. Defining the network and visualization


# One input feature and one output value. HIDDEN_NODE and LEARNING_RATE
# are not given in the post; the values below are assumptions.
INPUT_NODE = 1
HIDDEN_NODE = 10
OUTPUT_NODE = 1
LEARNING_RATE = 0.1


def inference(input_tensor):
    # Layer 1: fully connected, INPUT_NODE -> HIDDEN_NODE, ReLU activation
    with tf.name_scope('Layer-1'):
        weight = tf.Variable(tf.truncated_normal(shape=[INPUT_NODE, HIDDEN_NODE], stddev=0.1, dtype=tf.float32),
                             name='weight')
        bias = tf.Variable(tf.constant(0, dtype=tf.float32, shape=[HIDDEN_NODE]), name='bias')
        l1 = tf.nn.relu(tf.nn.xw_plus_b(input_tensor, weight, bias))
    # Layer 2: fully connected, HIDDEN_NODE -> OUTPUT_NODE, no activation (linear regression output)
    with tf.name_scope('Layer-2'):
        weight = tf.Variable(tf.truncated_normal(shape=[HIDDEN_NODE, OUTPUT_NODE], stddev=0.1, dtype=tf.float32),
                             name='weight')
        bias = tf.Variable(tf.constant(0, dtype=tf.float32, shape=[OUTPUT_NODE]), name='bias')
        l2 = tf.nn.xw_plus_b(l1, weight, bias)
    return l2


def train():
    x = tf.placeholder(dtype=tf.float32, shape=[None, INPUT_NODE], name='x-input')
    y_ = tf.placeholder(dtype=tf.float32, shape=[None, OUTPUT_NODE], name='y-input')
    global_step = tf.Variable(0, trainable=False)
    logits = inference(x)
    loss = tf.reduce_mean(tf.square(y_ - logits))  # mean squared error
    train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss, global_step=global_step)
    train_x, train_y = gen_data()
    np.random.seed(200)
    # Shuffle the sample order (each step still feeds the full batch of 100 points)
    shuffle_index = np.random.permutation(train_x.shape[0])
    shuffled_X = train_x[shuffle_index]
    shuffled_y = train_y[shuffle_index]

    # Draw the ground truth sin(x) once in red; interactive mode lets the figure refresh during training
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(train_x, train_y, lw=5, c='r')
    plt.ion()
    plt.show()
    lines = None  # handle to the current prediction curve
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        feed_dict = {x: shuffled_X, y_: shuffled_y}
        for i in range(500000):
            _, train_loss = sess.run([train_step, loss], feed_dict=feed_dict)
            if (i + 1) % 80 == 0:
                print('loss at train data:  ', train_loss)
                # Remove the previous prediction curve (ax.lines.remove() fails on recent Matplotlib)
                if lines is not None:
                    lines[0].remove()
                y_pre = sess.run(logits, feed_dict={x: train_x})
                lines = ax.plot(train_x, y_pre, c='black')
                plt.pause(0.1)


if __name__ == '__main__':
    train()
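The training loop redraws the prediction every 80 steps by deleting the previous black curve and plotting a new one. Recent Matplotlib releases no longer allow removing artists through `ax.lines.remove(...)`; calling `.remove()` on the `Line2D` itself works, and another option is to create the prediction line once and update its y-data in place. The sketch below shows that in-place variant under the headless Agg backend; the noisy `y_pre` values are stand-ins for `sess.run(logits, ...)`, not real model output.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-np.pi, np.pi, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x), lw=5, c="r")                       # ground truth, drawn once
(pred_line,) = ax.plot(x, np.zeros_like(x), c="black")   # prediction curve, created once

rng = np.random.default_rng(0)
for step in range(3):
    # Stand-in for the model's predictions at this point in training
    y_pre = np.sin(x) + rng.normal(0.0, 0.1 / (step + 1), size=x.shape)
    pred_line.set_ydata(y_pre)   # update in place instead of remove + re-plot
    fig.canvas.draw_idle()       # in an interactive session, follow with plt.pause(0.1)

print(len(ax.lines))  # 2
```

Updating in place keeps exactly two lines on the axes for the whole run, so there is no handle to track between iterations.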

Results:

[Figures: three screenshots of the training results]
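For readers without a TensorFlow 1.x setup, roughly the same experiment can be reproduced in plain NumPy: the 1 → H → 1 ReLU network, mean squared error, and full-batch gradient descent with hand-written backpropagation, which is what `GradientDescentOptimizer(...).minimize(loss)` computes automatically. This is a sketch under assumptions: the hidden width, learning rate and step count below are chosen for illustration and are not the post's values, and the initialisation is a plain (not truncated) normal, so the losses will not match the screenshots.

```python
import numpy as np

# Assumed hyperparameters (not taken from the original post).
HIDDEN_NODE = 20
LEARNING_RATE = 0.05
STEPS = 20000

rng = np.random.default_rng(200)
x = np.linspace(-np.pi, np.pi, 100).reshape(-1, 1)  # same inputs as gen_data()
y = np.sin(x)

# Plain normal init with stddev 0.1 (the TF code uses a truncated normal); zero biases.
w1 = rng.normal(0.0, 0.1, size=(1, HIDDEN_NODE))
b1 = np.zeros(HIDDEN_NODE)
w2 = rng.normal(0.0, 0.1, size=(HIDDEN_NODE, 1))
b2 = np.zeros(1)

losses = []
for step in range(STEPS):
    # Forward pass: affine + ReLU, then affine.
    z = x @ w1 + b1
    h = np.maximum(z, 0.0)
    pred = h @ w2 + b2
    losses.append(float(np.mean((y - pred) ** 2)))  # mean squared error

    # Backward pass for loss = mean((y - pred)^2).
    g_pred = 2.0 * (pred - y) / y.size   # dL/dpred
    g_w2 = h.T @ g_pred
    g_b2 = g_pred.sum(axis=0)
    g_h = g_pred @ w2.T
    g_z = g_h * (z > 0)                  # ReLU derivative: 0 where z <= 0
    g_w1 = x.T @ g_z
    g_b1 = g_z.sum(axis=0)

    # Plain gradient-descent update with a fixed learning rate.
    w1 -= LEARNING_RATE * g_w1
    b1 -= LEARNING_RATE * g_b1
    w2 -= LEARNING_RATE * g_w2
    b2 -= LEARNING_RATE * g_b2

print(f"initial loss {losses[0]:.4f}, final loss {losses[-1]:.4f}")
```

Each hidden unit contributes one kink to a piecewise-linear curve, so the fit can only approximate sin(x) with straight segments; more hidden units allow more segments.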
