机器学习第5章第1节(下) : 针对两类函数训练神经网络

机器学习第5章第1节(下) : 针对两类函数训练神经网络

思路

使用一系列二维数据对神经网络进行训练

代码

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
针对两类函数训练神经网络

将一组(x,y)值划分为以下两类函数之一
    第一类: y = 2x + 1
    第二类: y = 7x + 1

@author: Oscar
"""
import pylab as pl
import numpy as np

#学习速率
a = 0.3
#偏置
b = 1

#训练数据
x =np.array ([
      [1,1,3], #y = 2x + 1
      [1,2,5], #y = 2x + 1
      [1,1,8], #y = 7x + 1
      [1,2,15],#y = 7x + 1
      [1,3,7], #y = 2x + 1
      [1,4,29] #y = 7x + 1

      ])

#期望输出
d = np.array([
        1,  #训练数据的第一类
        1,  #训练数据的第一类
        -1, #训练数据的第二类
        -1, #训练数据的第二类
        1,  #训练数据的第一类
        -1  #训练数据的第二类

        ])

#权值
weight = np.array([b,0,0])

#验证用的数据
test_data1 = np.array([b,9,19]) #第一类
test_data2 = np.array([b,9,64]) #第二类

"""
感知器:
    +1 : 第一类
    -1 : 第二类
"""
def sgn(v):
    if v >= 0:
        return +1
    else:
        return -1


"""
计算感知器所需要的参数的值,
然后把值传递给感知器,
返回分类的值
"""
def comy(current_weight,current_x):
    return sgn(np.dot(current_weight.T,current_x))

"""
更新权值
"""
def update_weight(old_weight,current_d,current_x,a):
    return old_weight + a*(current_d - comy(old_weight,current_x))*current_x

#开始训练
print("开始训练,当前权值:",weight)
i = 0
for xn in x :
    weight = update_weight(weight,d[i],xn,a)
    i += 1
    print("第",i,"轮训练,当前使用的数据:",xn," 当前权值:",weight)

#训练完毕,查看数据分类情况
print("----------------------------------------------------")
for xn in x:
    print("x=",xn[1],", y=",xn[2],"=>",comy(weight,xn))
print("----------------------------------------------------")

#训练完毕,开始验证
print("当前测试数据:",test_data1," ,分类结果:",comy(weight,test_data1))
print("当前测试数据:",test_data2," ,分类结果:",comy(weight,test_data2))

"""
运行结果:

开始训练,当前权值: [1 0 0]
第 1 轮训练,当前使用的数据: [1 1 3]  当前权值: [ 1.  0.  0.]
第 2 轮训练,当前使用的数据: [1 2 5]  当前权值: [ 1.  0.  0.]
第 3 轮训练,当前使用的数据: [1 1 8]  当前权值: [ 0.4 -0.6 -4.8]
第 4 轮训练,当前使用的数据: [ 1  2 15]  当前权值: [ 0.4 -0.6 -4.8]
第 5 轮训练,当前使用的数据: [1 3 7]  当前权值: [ 1.   1.2 -0.6]
第 6 轮训练,当前使用的数据: [ 1  4 29]  当前权值: [ 1.   1.2 -0.6]
----------------------------------------------------
x= 1 , y= 3 => 1
x= 2 , y= 5 => 1
x= 1 , y= 8 => -1
x= 2 , y= 15 => -1
x= 3 , y= 7 => 1
x= 4 , y= 29 => -1
----------------------------------------------------
当前测试数据: [ 1  9 19]  ,分类结果: 1
当前测试数据: [ 1  9 64]  ,分类结果: -1
"""

#------------------------根据训练结果绘制可视化图-----------------------------------#
#数据点的x轴和y轴坐标
point_x = x[:,1] 
point_y = x[:,2]
#准备绘制
pl.subplot(111)

#坐标轴的最大最小值
x_max = np.max(point_x) + 15
x_min = np.min(point_x) - 5
y_max = np.max(point_y) + 50
y_min = np.min(point_y) - 5

#设置x轴的标签和最大最小值
pl.xlabel(u"x")
pl.xlim(x_min,x_max)

#设置y轴的标签和最大最小值
pl.ylabel(u"y")
pl.ylim(y_min,y_max)

#绘制训练数据
for i in range(0,len(d)):
    if d[i] > 0:
        #绘制成红色星号
        pl.plot(point_x[i],point_y[i],"r*")
    else:
        #绘制成红色圆点
        pl.plot(point_x[i],point_y[i],"ro")

#绘制测试点1
if comy(weight,test_data1) > 0:
    #绘制成蓝色的小点
    pl.plot(test_data1[1],test_data1[2],"b.")
else:
    #绘制成蓝色的小叉
    pl.plot(test_data1[1],test_data1[2],"bx")


#绘制测试点2
if comy(weight,test_data2) > 0:
    #绘制成蓝色的小点
    pl.plot(test_data2[1],test_data2[2],"b.")
else:
    #绘制成蓝色的小叉
    pl.plot(test_data2[1],test_data2[2],"bx")


#绘制测试点3
test_data3 = [b,9,60]
if comy(weight,test_data3) > 0:
    #绘制成蓝色的小点
    pl.plot(test_data3[1],test_data3[2],"b.")
else:
    #绘制成蓝色的小叉
    pl.plot(test_data3[1],test_data3[2],"bx")

#绘制分类线
line_x = np.array(range(0,20))
"""
因为权值为[ 1.   1.2 -0.6]
所以可以得出:
    1.2x - 0.6y + 1 = 0
整理一下:
    1.2x + 1 = 0.6y
可以得出神经网络的分类线的大致方程为:
    y = 2x + 1.68

"""
line_y = 2 * line_x + 1.68

#绘制成绿色线条
pl.plot(line_x,line_y,"g--")

#显示图像
pl.show()

运行结果

开始训练,当前权值: [1 0 0]
第 1 轮训练,当前使用的数据: [1 1 3]  当前权值: [ 1.  0.  0.]
第 2 轮训练,当前使用的数据: [1 2 5]  当前权值: [ 1.  0.  0.]
第 3 轮训练,当前使用的数据: [1 1 8]  当前权值: [ 0.4 -0.6 -4.8]
第 4 轮训练,当前使用的数据: [ 1  2 15]  当前权值: [ 0.4 -0.6 -4.8]
第 5 轮训练,当前使用的数据: [1 3 7]  当前权值: [ 1.   1.2 -0.6]
第 6 轮训练,当前使用的数据: [ 1  4 29]  当前权值: [ 1.   1.2 -0.6]
----------------------------------------------------
x= 1 , y= 3 => 1
x= 2 , y= 5 => 1
x= 1 , y= 8 => -1
x= 2 , y= 15 => -1
x= 3 , y= 7 => 1
x= 4 , y= 29 => -1
----------------------------------------------------
当前测试数据: [ 1  9 19]  ,分类结果: 1
当前测试数据: [ 1  9 64]  ,分类结果: -1

plot

发布了57 篇原创文章 · 获赞 1690 · 访问量 76万+

猜你喜欢

转载自blog.csdn.net/u013733326/article/details/78667247