神经网络优化(二) - 搭建神经网络八股

为提高程序的可复用性,搭建模块化的神经网络八股

1 前向传播

前向传播就是设计、搭建从输入(参数 x ) 到输出(返回值为预测或分类结果 y )的完整网络结构,实现前向传播过程,一般将其放在 forward.py 文件中

前向传播需要定义三个函数(实际上第一个函数是框架,第二、三个函数是赋初值过程)

def forward(x, regularizer):
    w = 
    b = 
    y = 
    return y

函数功能:

  • 定义前向传播过程,返回值为y
  • 完成网络结构的设计,实现从输入到输出的数据通路
  • regularizer 为正则化权重

def get_weight(shape, regularizer):
    w = tf.Variable()
    tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w 

函数功能:

  • 为 w 赋初值,
  • 把每一个 w 的正则化损失加到总损失losses中
  • 返回 w
def get_bias(shape):
    b = tf.Variable()
    return b

函数功能

  • 为 b 赋初值
  • shape的形状实际上就是某层 b 的个数

2 反向传播

反向传播是神经网络训练过程,优化神经网络参数,一般将其放在 backward.py 文件中

def backward():
    x = tf.placeholder()
    y_ = tf.placeholder()
    y = forward.forward(x, REGULARIZER)
    global_step = tf.Variable(0, trainable=False)
    loss =

函数功能:

  • backward 函数用来描述反向传播过程
  • placeholder 给 x、y_ 占位
  • 调用forward.forward()模块复现前向传播的网络结构,用于计算求算 y
  • 定义轮数计数器global_step
  • 定义损失函数loss
# 方案1 梯度下降
loss_mse = tf.reduce_mean(tf.square(y - y_))
loss = loss_mse + tf.add_n(tf.get_collection('losses'))

# 方案2 交叉熵
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1))
cem = tf.reduce_mean(ce)
loss = cem + tf.add_n(tf.get_collection('losses'))

代码功能

  • 损失函数正则化
  • 首先定义损失函数 loss,也即 y 与 y_ 差距的描述方式
  • 方案1为梯度下降、方案2为交叉熵
  • 正则化损失函数  loss = loss_mse/ cem + tf.add_n(tf.get_collection('losses')) 
  • losses 的值在 w 赋初值时会有体现
    w = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
    tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
 
 
learning_rate = tf.train.exponential_decay(
    RATE_BASE,
    global_step,
    数据集总样本数/ BATCH_SIZE,
    RATE_DECAY,
    staircase=True
)
train_step=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

函数功能:

  • 利用指数衰减学习率,动态计算学习率
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
ema_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([train_step, ema_op]):
    train_op = tf.no_op(name='train')

函数功能:

  • 滑动平均

建立会话

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    for i in range(STEPS):
        sess.run(train_step, feed_dict={x:  , y_: })
        if i % 轮数 == 0:
            print()

with结构初始化所有参数,并调用训练函数,实现待优化参数训练过程。

运行文件是否为主文件

if name == '__main__':
    backward()

该部分用来判断 python 运行的文件是否为主文件。若是主文件,则执行 backword()函数。 

正则化 - 避免过拟合,提高泛化性

指数学习率 - 加快优化效率

3 代码示例

代码总共分三个文件:

  • 生成数据集 generateds.py
  • 前向传播 forward.py
  • 反向传播 backward.py

 3.1 生成数据集generateds.py

 1 # coding:utf-8
 2 # 导入模块 ,生成模拟数据集
 3 import numpy as np
 4 import matplotlib.pyplot as plt
 5 seed = 2
 6 def generateds():
 7     # 基于seed产生随机数
 8     rdm = np.random.RandomState(seed)
 9     # 随机数返回300行2列的矩阵,表示300组坐标点(x0,x1)作为输入数据集
10     X = rdm.randn(300,2)
11     # 从X这个300行2列的矩阵中取出一行,判断如果两个坐标的平方和小于2,给Y赋值1,其余赋值0
12     # 作为输入数据集的标签(正确答案)
13     Y_ = [int(x0*x0 + x1*x1 <2) for (x0,x1) in X]
14     # 遍历Y中的每个元素,1赋值'red'其余赋值'blue',这样可视化显示时人可以直观区分
15     Y_c = [['red' if y else 'blue'] for y in Y_]
16     # 对数据集X和标签Y进行形状整理,第一个元素为-1表示跟随第二列计算,第二个元素表示多少列,可见X为两列,Y为1列
17     X = np.vstack(X).reshape(-1,2)
18     Y_ = np.vstack(Y_).reshape(-1,1)
19        
20     # print(X)
21     # print(Y_)
22     # print(Y_c)
23     # # 用plt.scatter画出数据集X各行中第0列元素和第1列元素的点即各行的(x0,x1),用各行Y_c对应的值表示颜色(c是color的缩写) 
24     # plt.scatter(X[:,0], X[:,1], c=np.squeeze(Y_c)) 
25     # plt.show()
26     return X, Y_, Y_c
27 # generateds()

将20、21、22、24、25、27行代码 “ 解禁 ” 就可以看到该文件所能得到的数据集(可视化)

运行的结果代码有

[[-4.16757847e-01 -5.62668272e-02]
 [-2.13619610e+00  1.64027081e+00]
 [-1.79343559e+00 -8.41747366e-01]
 [ 5.02881417e-01 -1.24528809e+00]
 [-1.05795222e+00 -9.09007615e-01]
 [ 5.51454045e-01  2.29220801e+00]
 [ 4.15393930e-02 -1.11792545e+00]
 [ 5.39058321e-01 -5.96159700e-01]
 [-1.91304965e-02  1.17500122e+00]
 [-7.47870949e-01  9.02525097e-03]
 [-8.78107893e-01 -1.56434170e-01]
 [ 2.56570452e-01 -9.88779049e-01]
 [-3.38821966e-01 -2.36184031e-01]
 [-6.37655012e-01 -1.18761229e+00]
 [-1.42121723e+00 -1.53495196e-01]
 [-2.69056960e-01  2.23136679e+00]
 [-2.43476758e+00  1.12726505e-01]
 [ 3.70444537e-01  1.35963386e+00]
 [ 5.01857207e-01 -8.44213704e-01]
 [ 9.76147160e-06  5.42352572e-01]
 [-3.13508197e-01  7.71011738e-01]
 [-1.86809065e+00  1.73118467e+00]
 [ 1.46767801e+00 -3.35677339e-01]
 [ 6.11340780e-01  4.79705919e-02]
 [-8.29135289e-01  8.77102184e-02]
 [ 1.00036589e+00 -3.81092518e-01]
 [-3.75669423e-01 -7.44707629e-02]
 [ 4.33496330e-01  1.27837923e+00]
 [-6.34679305e-01  5.08396243e-01]
 [ 2.16116006e-01 -1.85861239e+00]
 [-4.19316482e-01 -1.32328898e-01]
 [-3.95702397e-02  3.26003433e-01]
 [-2.04032305e+00  4.62555231e-02]
 [-6.77675577e-01 -1.43943903e+00]
 [ 5.24296430e-01  7.35279576e-01]
 [-6.53250268e-01  8.42456282e-01]
 [-3.81516482e-01  6.64890091e-02]
 [-1.09873895e+00  1.58448706e+00]
 [-2.65944946e+00 -9.14526229e-02]
 [ 6.95119605e-01 -2.03346655e+00]
 [-1.89469265e-01 -7.72186654e-02]
 [ 8.24703005e-01  1.24821292e+00]
 [-4.03892269e-01 -1.38451867e+00]
 [ 1.36723542e+00  1.21788563e+00]
 [-4.62005348e-01  3.50888494e-01]
 [ 3.81866234e-01  5.66275441e-01]
 [ 2.04207979e-01  1.40669624e+00]
 [-1.73795950e+00  1.04082395e+00]
 [ 3.80471970e-01 -2.17135269e-01]
 [ 1.17353150e+00 -2.34360319e+00]
 [ 1.16152149e+00  3.86078048e-01]
 [-1.13313327e+00  4.33092555e-01]
 [-3.04086439e-01  2.58529487e+00]
 [ 1.83533272e+00  4.40689872e-01]
 [-7.19253841e-01 -5.83414595e-01]
 [-3.25049628e-01 -5.60234506e-01]
 [-9.02246068e-01 -5.90972275e-01]
 [-2.76179492e-01 -5.16883894e-01]
 [-6.98589950e-01 -9.28891925e-01]
 [ 2.55043824e+00 -1.47317325e+00]
 [-1.02141473e+00  4.32395701e-01]
 [-3.23580070e-01  4.23824708e-01]
 [ 7.99179995e-01  1.26261366e+00]
 [ 7.51964849e-01 -9.93760983e-01]
 [ 1.10914328e+00 -1.76491773e+00]
 [-1.14421297e-01 -4.98174194e-01]
 [-1.06079904e+00  5.91666521e-01]
 [-1.83256574e-01  1.01985473e+00]
 [-1.48246548e+00  8.46311892e-01]
 [ 4.97940148e-01  1.26504175e-01]
 [-1.41881055e+00 -2.51774118e-01]
 [-1.54667461e+00 -2.08265194e+00]
 [ 3.27974540e+00  9.70861320e-01]
 [ 1.79259285e+00 -4.29013319e-01]
 [ 6.96197980e-01  6.97416272e-01]
 [ 6.01515814e-01  3.65949071e-03]
 [-2.28247558e-01 -2.06961226e+00]
 [ 6.10144086e-01  4.23496900e-01]
 [ 1.11788673e+00 -2.74242089e-01]
 [ 1.74181219e+00 -4.47500876e-01]
 [-1.25542722e+00  9.38163671e-01]
 [-4.68346260e-01 -1.25472031e+00]
 [ 1.24823646e-01  7.56502143e-01]
 [ 2.41439629e-01  4.97425649e-01]
 [ 4.10869262e+00  8.21120877e-01]
 [ 1.53176032e+00 -1.98584577e+00]
 [ 3.65053516e-01  7.74082033e-01]
 [-3.64479092e-01 -8.75979478e-01]
 [ 3.96520159e-01 -3.14617436e-01]
 [-5.93755583e-01  1.14950057e+00]
 [ 1.33556617e+00  3.02629336e-01]
 [-4.54227855e-01  5.14370717e-01]
 [ 8.29458431e-01  6.30621967e-01]
 [-1.45336435e+00 -3.38017777e-01]
 [ 3.59133332e-01  6.22220414e-01]
 [ 9.60781945e-01  7.58370347e-01]
 [-1.13431848e+00 -7.07420888e-01]
 [-1.22142917e+00  1.80447664e+00]
 [ 1.80409807e-01  5.53164274e-01]
 [ 1.03302907e+00 -3.29002435e-01]
 [-1.15100294e+00 -4.26522471e-01]
 [-1.48147191e-01  1.50143692e+00]
 [ 8.69598198e-01 -1.08709057e+00]
 [ 6.64221413e-01  7.34884668e-01]
 [-1.06136574e+00 -1.08516824e-01]
 [-1.85040397e+00  3.30488064e-01]
 [-3.15693210e-01 -1.35000210e+00]
 [-6.98170998e-01  2.39951198e-01]
 [-5.52949440e-01  2.99526813e-01]
 [ 5.52663696e-01 -8.40443012e-01]
 [-3.12270670e-01  2.14467809e+00]
 [ 1.21105582e-01 -8.46828752e-01]
 [ 6.04624490e-02 -1.33858888e+00]
 [ 1.13274608e+00  3.70304843e-01]
 [ 1.08580640e+00  9.02179395e-01]
 [ 3.90296450e-01  9.75509412e-01]
 [ 1.91573647e-01 -6.62209012e-01]
 [-1.02351498e+00 -4.48174823e-01]
 [-2.50545813e+00  1.82599446e+00]
 [-1.71406741e+00 -7.66395640e-02]
 [-1.31756727e+00 -2.02559359e+00]
 [-8.22453750e-02 -3.04666585e-01]
 [-1.59724130e-01  5.48946560e-01]
 [-6.18375485e-01  3.78794466e-01]
 [ 5.13251444e-01 -3.34844125e-01]
 [-2.83519516e-01  5.38424263e-01]
 [ 5.72509465e-02  1.59088487e-01]
 [-2.37440268e+00  5.85199353e-02]
 [ 3.76545911e-01 -1.35479764e-01]
 [ 3.35908395e-01  1.90437591e+00]
 [ 8.53644334e-02  6.65334278e-01]
 [-8.49995503e-01 -8.52341797e-01]
 [-4.79985112e-01 -1.01964910e+00]
 [-7.60113841e-03 -9.33830661e-01]
 [-1.74996844e-01 -1.43714343e+00]
 [-1.65220029e+00 -6.75661789e-01]
 [-1.06706712e+00 -6.52931145e-01]
 [-6.12094750e-01 -3.51262461e-01]
 [ 1.04547799e+00  1.36901602e+00]
 [ 7.25353259e-01 -3.59474459e-01]
 [ 1.49695179e+00 -1.53111111e+00]
 [-2.02336394e+00  2.67972576e-01]
 [-2.20644541e-03 -1.39291883e-01]
 [ 3.25654693e-02 -1.64056022e+00]
 [-1.15669917e+00  1.23403468e+00]
 [ 1.02818490e+00 -7.21879726e-01]
 [ 1.93315697e+00 -1.07079633e+00]
 [-5.71381608e-01  2.92432067e-01]
 [-1.19499989e+00 -4.87930544e-01]
 [-1.73071165e-01 -3.95346401e-01]
 [ 8.70840765e-01  5.92806797e-01]
 [-1.09929731e+00 -6.81530644e-01]
 [ 1.80066685e-01 -6.69310440e-02]
 [-7.87749540e-01  4.24753672e-01]
 [ 8.19885117e-01 -6.31118683e-01]
 [ 7.89059649e-01 -1.62167380e+00]
 [-1.61049926e+00  4.99939764e-01]
 [-8.34515207e-01 -9.96959687e-01]
 [-2.63388077e-01 -6.77360492e-01]
 [ 3.27067038e-01 -1.45535944e+00]
 [-3.71519124e-01  3.16096597e+00]
 [ 1.09951013e-01 -1.91352322e+00]
 [ 5.99820429e-01  5.49384465e-01]
 [ 1.38378103e+00  1.48349243e-01]
 [-6.53541444e-01  1.40883398e+00]
 [ 7.12061227e-01 -1.80071604e+00]
 [ 7.47598942e-01 -2.32897001e-01]
 [ 1.11064528e+00 -3.73338813e-01]
 [ 7.86146070e-01  1.94168696e-01]
 [ 5.86204098e-01 -2.03872918e-02]
 [-4.14408598e-01  6.73134124e-02]
 [ 6.31798924e-01  4.17592731e-01]
 [ 1.61517627e+00  4.25606211e-01]
 [ 6.35363758e-01  2.10222927e+00]
 [ 6.61264168e-02  5.35558351e-01]
 [-6.03140792e-01  4.19576292e-02]
 [ 1.64191464e+00  3.11697707e-01]
 [ 1.45116990e+00 -1.06492788e+00]
 [-1.40084545e+00  3.07525527e-01]
 [-1.36963867e+00  2.67033724e+00]
 [ 1.24845030e+00 -1.24572655e+00]
 [-1.67168774e-01 -5.76610930e-01]
 [ 4.16021749e-01 -5.78472626e-02]
 [ 9.31887358e-01  1.46833213e+00]
 [-2.21320943e-01 -1.17315562e+00]
 [ 5.62669078e-01 -1.64515057e-01]
 [ 1.14485538e+00 -1.52117687e-01]
 [ 8.29789046e-01  3.36065952e-01]
 [-1.89044051e-01 -4.49328601e-01]
 [ 7.13524448e-01  2.52973487e+00]
 [ 8.37615794e-01 -1.31682403e-01]
 [ 7.07592866e-01  1.14053878e-01]
 [-1.28089518e+00  3.09846277e-01]
 [ 1.54829069e+00 -3.15828043e-01]
 [-1.12590378e+00  4.88496666e-01]
 [ 1.83094666e+00  9.40175993e-01]
 [ 1.01871705e+00  2.30237829e+00]
 [ 1.62109298e+00  7.12683273e-01]
 [-2.08703629e-01  1.37617991e-01]
 [-1.03352168e-01  8.48350567e-01]
 [-8.83125561e-01  1.54538683e+00]
 [ 1.45840073e-01 -4.00106056e-01]
 [ 8.15206041e-01 -2.07492237e+00]
 [-8.34437391e-01 -6.57718447e-01]
 [ 8.20564332e-01 -4.89157001e-01]
 [ 1.42496703e+00 -4.46857897e-01]
 [ 5.21109431e-01 -7.08194380e-01]
 [ 1.15553059e+00 -2.54530459e-01]
 [ 5.18924924e-01 -4.92994911e-01]
 [-1.08654815e+00 -2.30917497e-01]
 [ 1.09801004e+00 -1.01787805e+00]
 [-1.52939136e+00 -3.07987737e-01]
 [ 7.80754356e-01 -1.05583964e+00]
 [-5.43883381e-01  1.84301739e-01]
 [-3.30675843e-01  2.87208202e-01]
 [ 1.18952814e+00  2.12015479e-02]
 [-6.54096803e-02  7.66115904e-01]
 [-6.16350846e-02 -9.52897152e-01]
 [-1.01446306e+00 -1.11526396e+00]
 [ 1.91260068e+00 -4.52632031e-02]
 [ 5.76909718e-01  7.17805695e-01]
 [-9.38998998e-01  6.28775807e-01]
 [-5.64493432e-01 -2.08780746e+00]
 [-2.15050132e-01 -1.07502856e+00]
 [-3.37972149e-01  3.43212732e-01]
 [ 2.28253964e+00 -4.95778848e-01]
 [-1.63962832e-01  3.71622161e-01]
 [ 1.86521520e-01 -1.58429224e-01]
 [-1.08292956e+00 -9.56625520e-01]
 [-1.83376735e-01 -1.15980690e+00]
 [-6.57768362e-01 -1.25144841e+00]
 [ 1.12448286e+00 -1.49783981e+00]
 [ 1.90201722e+00 -5.80383038e-01]
 [-1.05491567e+00 -1.18275720e+00]
 [ 7.79480054e-01  1.02659795e+00]
 [-8.48666001e-01  3.31539648e-01]
 [-1.49591353e-01 -2.42440600e-01]
 [ 1.51197175e-01  7.65069481e-01]
 [-1.91663052e+00 -2.22734129e+00]
 [ 2.06689897e-01 -7.08763560e-02]
 [ 6.84759969e-01 -1.70753905e+00]
 [-9.86569665e-01  1.54353634e+00]
 [-1.31027053e+00  3.63433972e-01]
 [-7.94872445e-01 -4.05286267e-01]
 [-1.37775793e+00  1.18604868e+00]
 [-1.90382114e+00 -1.19814038e+00]
 [-9.10065643e-01  1.17645419e+00]
 [ 2.99210670e-01  6.79267178e-01]
 [-1.76606800e-02  2.36040923e-01]
 [ 4.94035871e-01  1.54627765e+00]
 [ 2.46857508e-01 -1.46877580e+00]
 [ 1.14709994e+00  9.55569845e-02]
 [-1.10743873e+00 -1.76286141e-01]
 [-9.82755667e-01  2.08668273e+00]
 [-3.44623671e-01 -2.00207923e+00]
 [ 3.03234433e-01 -8.29874845e-01]
 [ 1.28876941e+00  1.34925462e-01]
 [-1.77860064e+00 -5.00791490e-01]
 [-1.08816157e+00 -7.57855553e-01]
 [-6.43744900e-01 -2.00878453e+00]
 [ 1.96262894e-01 -8.75896370e-01]
 [-8.93609209e-01  7.51902355e-01]
 [ 1.89693224e+00 -6.29079151e-01]
 [ 1.81208553e+00 -2.05626574e+00]
 [ 5.62704887e-01 -5.82070757e-01]
 [-7.40029749e-02 -9.86496364e-01]
 [-5.94722499e-01 -3.14811843e-01]
 [-3.46940532e-01  4.11443516e-01]
 [ 2.32639090e+00 -6.34053128e-01]
 [-1.54409962e-01 -1.74928880e+00]
 [-2.51957930e+00  1.39116243e+00]
 [-1.32934644e+00 -7.45596414e-01]
 [ 2.12608498e-02  9.10917515e-01]
 [ 3.15276082e-01  1.86620821e+00]
 [-1.82497623e-01 -1.82826634e+00]
 [ 1.38955717e-01  1.19450165e-01]
 [-8.18899200e-01 -3.32639265e-01]
 [-5.86387955e-01  1.73451634e+00]
 [-6.12751558e-01 -1.39344202e+00]
 [ 2.79433757e-01 -1.82223127e+00]
 [ 4.27017458e-01  4.06987749e-01]
 [-8.44308241e-01 -5.59820113e-01]
 [-6.00520405e-01  1.61487324e+00]
 [ 3.94953220e-01 -1.20381347e+00]
 [-1.24747243e+00 -7.75462496e-02]
 [-1.33397514e-02 -7.68323250e-01]
 [ 2.91234010e-01 -1.97330948e-01]
 [ 1.07682965e+00  4.37410232e-01]
 [-9.31978663e-02  1.35631416e-01]
 [-8.82708822e-01  8.84744194e-01]
 [ 3.83204463e-01 -4.16994149e-01]
 [ 1.17796550e-01 -5.36685309e-01]
 [ 2.48718458e+00 -4.51361054e-01]
 [ 5.18836127e-01  3.64448005e-01]
 [-7.98348729e-01  5.65779713e-03]
 [-3.20934708e-01  2.49513550e-01]
 [ 2.56308392e-01  7.67625083e-01]
 [ 7.83020087e-01 -4.07063047e-01]
 [-5.24891667e-01 -5.89808683e-01]
 [-8.62531086e-01 -1.74287290e+00]]
[[1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [0]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [0]
 [0]
 [1]
 [0]
 [0]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [0]
 [1]
 [0]
 [0]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]]
[['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['blue'], ['red'], ['blue'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['blue'], ['red'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['blue'], ['blue'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue'], ['red'], ['red'], ['red'], ['red'], ['red'], ['red'], ['blue']]
数据集的数据内容

3.2 前向传播forward.py

# coding:utf-8
# 导入模块 ,生成模拟数据集
import tensorflow as tf

# 定义神经网络的输入、参数和输出,定义前向传播过程
def get_weight(shape, regularizer):
    w = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
    tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w

def get_bias(shape): 
    b = tf.Variable(tf.constant(0.01, shape=shape)) 
    return b

def forward(x, regularizer):
    w1 = get_weight([2, 11], regularizer)
    b1 = get_bias([11])
    y1 = tf.nn.relu(tf.matmul(x, w1) + b1)

    w2 = get_weight([11, 1], regularizer)
    b2 = get_bias([1])
    y = tf.matmul(y1, w2) + b2 

    return y

3.3 反向传播过程backward.py

# coding:utf-8
# 0导入模块 ,生成模拟数据集
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import generateds
import forward

STEPS = 40000
BATCH_SIZE = 30
LEARNING_RATE_BASE = 0.001
LEARNING_RATE_DECAY = 0.999
REGULARIZER = 0.01

def backward():
    x = tf.placeholder(tf.float32, shape=(None, 2))
    y_ = tf.placeholder(tf.float32, shape=(None, 1))

    X, Y_, Y_c = generateds.generateds()

    y = forward.forward(x, REGULARIZER)

    global_step = tf.Variable(0,trainable=False)

    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        300/BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)


    # 定义损失函数
    loss_mse = tf.reduce_mean(tf.square(y-y_))
    loss_total = loss_mse + tf.add_n(tf.get_collection('losses'))
    
    # 定义反向传播方法:包含正则化
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss_total)

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        for i in range(STEPS):
            start = (i*BATCH_SIZE) % 300
            end = start + BATCH_SIZE
            sess.run(train_step, feed_dict={x: X[start:end], y_:Y_[start:end]})
            if i % 2000 == 0:
                loss_v = sess.run(loss_total, feed_dict={x:X,y_:Y_})
                print("After %d steps, loss is: %f" %(i, loss_v))

        xx, yy = np.mgrid[-3:3:.01, -3:3:.01]
        grid = np.c_[xx.ravel(), yy.ravel()]
        probs = sess.run(y, feed_dict={x:grid})
        probs = probs.reshape(xx.shape)

    plt.scatter(X[:,0], X[:,1], c=np.squeeze(Y_c)) 
    plt.contour(xx, yy, probs, levels=[.5])
    plt.show()
    
if __name__=='__main__':
    backward()

运行

猜你喜欢

转载自www.cnblogs.com/gengyi/p/9905226.html