神经网络-损失函数:

神经网络中的损失函数:

1.神经网络的复杂度:多用神经网络的层数和神经网络的参数的个数表示。

2.层数=隐藏层的层数+1个输出层,总参数=总的权重w与总的偏置项b。

3.自定义损失函数:

通过每个需要预测的结果y与标准答案y_比较,计算其损失累计和,即loss(y_,y)=sum{f(y_,y)},其中y_是标准答案数据集的,y为预测答案计算出的。

函数f定义如下 当y<y_时,f(y_,y)=w1*(y_-y);y>=y_时, f(y_,y)=w1*(y_-y);其中w1和w2为两个相互的矛盾,通过这样权衡,以达到最优。

4.在tensorflow中通过tf.reduce_sum()实现:

loss=tf.reduce_sum(tf.where(greater(y,y_),w1*(y_-y),w1*(y_-y))),其中tf.greater(y,y_)判断两参数大小,然后根据判断结果选择计算式子。

5.交叉熵(Cross Entropy):表征两个概率分布之间的距离;

H(y_,y)=-Sum(y_*logy),y_表示标准答案的概率分布,y表示预测结果的概率分布。通过交叉熵可以判断那个预测结果与标准答案更接近。

6.在tensorflow中通过tf.reduce_mean(y_*tf.log(tf.clip_by_value(y,1e-12,1.0)))实现,当y小于1e-12时,y为1e-12,大于1.0时,y为1.0。


 

7.当n分类的n个输出(y1,y2,...,yn)通过softmax()函数,便满足了概率分布要求:

任意x,P(X=x)属于[0,1]且sum(P(X=x))=1

8.softmax(y_i)={e^{y_i}}/sum{^n,_j=1}{e^{y_i}},

ce=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1))

cem=tf.reduce_mean(ce)

实现了先使数据满足概率分布,在计算交叉熵的方法。

9.常用的激活函数:

f(x)=0,x小于等于0时, f(x)=x,x大于0时,即tf.relu();

f(x)=1/(1+e^{-x}),即tf.nn.sigmoid();

f(x)=(1-e^{-2x})/(1+e^{-2x}),即tf.nn.tanh();

#实例应用

#导入模块,生成数据集

import tensorflow as tf

import numpy as np

SIZE=8

seed=23455

COST = 1

PROFIT=9

rdm =np.random.RandomState(seed)

X=rdm.rand(32,2)

Y=[[x1+x2+(rdm.rand()/10.0-0.05)] for (x1,x2) in X]#合格为1,不合格为0


 

#定义神经网络的输入,参数和输出,定义前向传播的过程

x=tf.compat.v1.placeholder(tf.float32,shape=(None,2))

y_=tf.compat.v1.placeholder(tf.float32,shape=(None,1))

w1=tf.Variable(tf.random.normal([2,1],stddev=1,seed=1))

y=tf.matmul(x,w1)

#定义损失函数及反向传播方法

#loss_mse=tf.reduce_mean(tf.square(y_-y))

loss=tf.reduce_sum(tf.where(tf.greater(y,y_),(y-y_)*COST,(y_-y)*PROFIT))

train_step=tf.compat.v1.train.GradientDescentOptimizer(0.001).minimize(loss)



 

#生成会话,训练

with tf.compat.v1.Session() as sess:

init_op=tf.compat.v1.global_variables_initializer()

sess.run(init_op)

steps=2000

for i in range(steps):

start=(i * SIZE) % 32

end=start+SIZE

sess.run(train_step,feed_dict={x:X[start:end],y_:Y[start:end]})

if i % 500 == 0:

print(i)

print(sess.run(w1))

print('\n')

print("训练结束参数w1:\n",sess.run(w1))


 

发布了89 篇原创文章 · 获赞 8 · 访问量 8893

猜你喜欢

转载自blog.csdn.net/TxyITxs/article/details/101547337
今日推荐