《统计学习方法》第九章,EM算法

▶ EM 算法的引入,三硬币问题,体验一下不同初始值对收敛点的影响

● 代码

 1 import numpy as np
 2 import matplotlib.pyplot as plt
 3 from matplotlib.patches import Rectangle
 4 
 5 dataSize = 1000
 6 trainDataRatio = 0.3
 7 defaultTurn = 20
 8 epsilon = 1E-10
 9 randomSeed = 103
10 
def dataSplit(dataY, part):
    """Split a sequence into a training part and a test part.

    Args:
        dataY: any sliceable sequence (list, ndarray, ...).
        part: index at which to cut.

    Returns:
        Tuple ``(dataY[:part], dataY[part:])``.
    """
    trainPart = dataY[:part]
    testPart = dataY[part:]
    return trainPart, testPart
13 
def createData(realA, realB, realC, count = dataSize):
    """Simulate `count` observations of the three-coin model.

    For every observation coin A is tossed first; when it shows 1
    (probability realA) the outcome of coin B is reported, otherwise the
    outcome of coin C is reported.

    Args:
        realA: probability that coin A selects coin B.
        realB: probability that coin B shows 1.
        realC: probability that coin C shows 1.
        count: number of observations to draw.

    Returns:
        Integer ndarray of length `count` containing 0/1 outcomes.
    """
    # Reseed on every call so the same arguments always give the same sample.
    np.random.seed(randomSeed)
    # `rand() < p` is 1 with probability p.  The previous `rand() > p` was 1
    # with probability 1 - p, so coins B and C were simulated with the
    # complements of realB / realC, although em() estimates P(outcome == 1).
    a = (np.random.rand(count) < realA).astype(int)
    b = (np.random.rand(count) < realB).astype(int)
    c = (np.random.rand(count) < realC).astype(int)
    # a == 1 reports coin B, a == 0 reports coin C.
    return b * a + c * (1 - a)
20 
def em(dataY, initialA, initialB, initialC, turn = defaultTurn):
    """Run EM for the three-coin (two-component Bernoulli mixture) model.

    Args:
        dataY: 1-D sequence of 0/1 observations.
        initialA: initial weight of component B (probability of picking coin B).
        initialB: initial probability that coin B shows 1.
        initialC: initial probability that coin C shows 1.
        turn: number of EM iterations to perform.

    Returns:
        Tuple ``(a, b, c)`` with the parameter estimates after `turn`
        iterations.  With no observations the initial values are returned
        unchanged.
    """
    dataY = np.asarray(dataY)       # accept plain lists as well as ndarrays
    count = len(dataY)
    if count == 0:
        # No data means no update; avoids the 0/0 the formulas would hit.
        return initialA, initialB, initialC
    sumY = np.sum(dataY)
    a, b, c = initialA, initialB, initialC
    for _ in range(turn):
        # E-step: responsibility of coin B for every observation.
        likelihoodB = a * b ** dataY * (1 - b) ** (1 - dataY)
        likelihoodC = (1 - a) * c ** dataY * (1 - c) ** (1 - dataY)
        total = likelihoodB + likelihoodC
        # Where the current model gives an observation zero likelihood the
        # posterior is 0/0; fall back to the prior weight `a` there.
        p = np.full(total.shape, a, dtype=float)
        np.divide(likelihoodB, total, out=p, where=total > 0)

        # M-step.
        sumP = np.sum(p)
        sumPY = np.sum(p * dataY)
        a = sumP / count
        # A component with zero total responsibility has no defined update;
        # keep its previous parameter instead of producing NaN.
        if sumP > 0:
            b = sumPY / sumP
        if count - sumP > 0:
            c = (sumY - sumPY) / (count - sumP)
    return a, b, c
34 
def test(realA, realB, realC, initialA, initialB, initialC):
    """Run one experiment and print a one-line summary.

    Generates data with createData(realA, realB, realC), runs em() from the
    given initial values and prints the true, initial and estimated
    parameters with three decimals.
    """
    observations = createData(realA, realB, realC)
    estimatedA, estimatedB, estimatedC = em(observations, initialA, initialB, initialC)

    template = "real=(%.3f, %.3f, %.3f),initial=(%.3f,%.3f,%.3f),train=(%.3f,%.3f,%.3f)"
    values = (
        realA, realB, realC,
        initialA, initialB, initialC,
        estimatedA, estimatedB, estimatedC,
    )
    print(template % values)
41 
if __name__ == '__main__':
    nearOne = 1.0 - epsilon

    # Each row: (realA, realB, realC, initialA, initialB, initialC).
    cases = [
        # True parameters (0.5, 0.5, 0.5), various starting points.
        (0.5, 0.5, 0.5, 0.5, 0.5, 0.5),
        (0.5, 0.5, 0.5, epsilon, epsilon, epsilon),
        (0.5, 0.5, 0.5, 0.5, epsilon, epsilon),
        (0.5, 0.5, 0.5, epsilon, 0.5, epsilon),
        (0.5, 0.5, 0.5, epsilon, epsilon, 0.5),
        (0.5, 0.5, 0.5, nearOne, epsilon, epsilon),
        (0.5, 0.5, 0.5, epsilon, nearOne, epsilon),
        (0.5, 0.5, 0.5, epsilon, epsilon, nearOne),
        (0.5, 0.5, 0.5, nearOne, nearOne, nearOne),
        # True parameters (0.4, 0.5, 0.6).
        (0.4, 0.5, 0.6, 0.4, 0.5, 0.6),
        (0.4, 0.5, 0.6, epsilon, epsilon, epsilon),
        # NOTE(review): the next three rows repeat rows 3-5 above verbatim
        # (true parameters 0.5); possibly (0.4, 0.5, 0.6) was intended - confirm.
        (0.5, 0.5, 0.5, 0.5, epsilon, epsilon),
        (0.5, 0.5, 0.5, epsilon, 0.5, epsilon),
        (0.5, 0.5, 0.5, epsilon, epsilon, 0.5),
    ]
    for case in cases:
        test(*case)

● 输出结果,从不同的真实值和初始值得到不同的收敛点

real=(0.500, 0.500, 0.500),initial=(0.500,0.500,0.500),train=(0.500,0.516,0.516)
real=(0.500, 0.500, 0.500),initial=(0.000,0.000,0.000),train=(0.000,0.516,0.516)
real=(0.500, 0.500, 0.500),initial=(0.500,0.000,0.000),train=(0.500,0.516,0.516)
real=(0.500, 0.500, 0.500),initial=(0.000,0.500,0.000),train=(0.172,1.000,0.415)
real=(0.500, 0.500, 0.500),initial=(0.000,0.000,0.500),train=(0.000,0.000,0.516)
real=(0.500, 0.500, 0.500),initial=(1.000,0.000,0.000),train=(1.000,0.516,0.516)
real=(0.500, 0.500, 0.500),initial=(0.000,1.000,0.000),train=(0.258,1.000,0.348)
real=(0.500, 0.500, 0.500),initial=(0.000,0.000,1.000),train=(0.242,0.000,0.681)
real=(0.500, 0.500, 0.500),initial=(1.000,1.000,1.000),train=(1.000,0.516,0.516)
real=(0.400, 0.500, 0.600),initial=(0.400,0.500,0.600),train=(0.409,0.406,0.506)
real=(0.400, 0.500, 0.600),initial=(0.000,0.000,0.000),train=(0.000,0.465,0.465)
real=(0.500, 0.500, 0.500),initial=(0.500,0.000,0.000),train=(0.500,0.516,0.516)
real=(0.500, 0.500, 0.500),initial=(0.000,0.500,0.000),train=(0.172,1.000,0.415)
real=(0.500, 0.500, 0.500),initial=(0.000,0.000,0.500),train=(0.000,0.000,0.516)

● 画图,散点位置表示初始取值,散点颜色 RGB 值表示收敛点取值。各图依次为:(真实值 ( 0.5,0.5,0.5 ),初始间隔 0.1,迭代 20 次),(真实值 ( 0.5,0.5,0.5 ),初始间隔 0.1,迭代 100 次),(真实值 ( 0.5,0.5,0.5 ),初始间隔 0.05,迭代 20 次),(真实值 ( 0.3,0.6,0.8 ),初始间隔 0.1,迭代 20 次)。可见:① 迭代 20 次以后就基本稳定了,更多次数迭代没有明显影响;② 随着初始点的连续移动,收敛点的取值也连续漂移,没有出现明显断层;③ 图中色彩饱和度较高的散点存在,说明收敛点并不能向真实值点明显靠拢,甚至有可能保持极端取值;④ 真实值点对收敛点在整个空间上的取值有影响(废话)

● 画图脚本

 1 import numpy as np
 2 import matplotlib.pyplot as plt
 3 from matplotlib.patches import Rectangle
 4 from mpl_toolkits.mplot3d import Axes3D
 5 from mpl_toolkits.mplot3d.art3d import Poly3DCollection
 6 
 7 dataSize = 1000
 8 trainDataRatio = 0.3
 9 defaultTurn = 20
10 epsilon = 1E-10
11 randomSeed = 103
12 
def dataSplit(dataY, part):
    """Cut `dataY` at index `part` and return the two pieces as a tuple.

    The first element holds everything before `part`, the second everything
    from `part` onwards.
    """
    leading, trailing = dataY[:part], dataY[part:]
    return leading, trailing
15 
def myColor(x):
    """Map `x` to an ``[r, g, b]`` list on a piecewise-linear colour ramp.

    The ramp runs blue (x = 0) -> cyan (1/4) -> green (1/2) -> yellow (3/4)
    -> red (1).  Values above 1 give all zeros.  Negative values are not
    clamped: the green channel is then ``4 * x``.

    Args:
        x: scalar or ndarray.

    Returns:
        List of three ndarrays (0-d for scalar input): red, green, blue.
    """
    belowQuarter = x < 1/4
    belowHalf = x < 1/2
    belowThreeQuarters = x < 3/4
    atMostOne = x <= 1

    # np.select takes the choice of the first condition that holds.
    red = np.select(
        [belowHalf, belowThreeQuarters, atMostOne, True],
        [0, 4 * x - 2, 1, 0],
    )
    green = np.select(
        [belowQuarter, belowThreeQuarters, atMostOne, True],
        [4 * x, 1, 4 - 4 * x, 0],
    )
    blue = np.select(
        [belowQuarter, belowHalf, atMostOne, True],
        [1, 2 - 4 * x, 0, 0],
    )
    return [red, green, blue]
21 
def createData(realA, realB, realC, count = dataSize):
    """Simulate `count` observations of the three-coin model.

    For every observation coin A is tossed first; when it shows 1
    (probability realA) the outcome of coin B is reported, otherwise the
    outcome of coin C is reported.

    Args:
        realA: probability that coin A selects coin B.
        realB: probability that coin B shows 1.
        realC: probability that coin C shows 1.
        count: number of observations to draw.

    Returns:
        Integer ndarray of length `count` containing 0/1 outcomes.
    """
    # Reseed on every call so the same arguments always give the same sample.
    np.random.seed(randomSeed)
    # `rand() < p` is 1 with probability p.  The previous `rand() > p` was 1
    # with probability 1 - p, so coins B and C were simulated with the
    # complements of realB / realC, although em() estimates P(outcome == 1).
    a = (np.random.rand(count) < realA).astype(int)
    b = (np.random.rand(count) < realB).astype(int)
    c = (np.random.rand(count) < realC).astype(int)
    # a == 1 reports coin B, a == 0 reports coin C.
    return b * a + c * (1 - a)
28 
def em(dataY, initialA, initialB, initialC, turn = defaultTurn):
    """Run EM for the three-coin (two-component Bernoulli mixture) model.

    Args:
        dataY: 1-D sequence of 0/1 observations.
        initialA: initial weight of component B (probability of picking coin B).
        initialB: initial probability that coin B shows 1.
        initialC: initial probability that coin C shows 1.
        turn: number of EM iterations to perform.

    Returns:
        Tuple ``(a, b, c)`` with the parameter estimates after `turn`
        iterations.  With no observations the initial values are returned
        unchanged.
    """
    dataY = np.asarray(dataY)       # accept plain lists as well as ndarrays
    count = len(dataY)
    if count == 0:
        # No data means no update; avoids the 0/0 the formulas would hit.
        return initialA, initialB, initialC
    sumY = np.sum(dataY)
    a, b, c = initialA, initialB, initialC
    for _ in range(turn):
        # E-step: responsibility of coin B for every observation.
        likelihoodB = a * b ** dataY * (1 - b) ** (1 - dataY)
        likelihoodC = (1 - a) * c ** dataY * (1 - c) ** (1 - dataY)
        total = likelihoodB + likelihoodC
        # Where the current model gives an observation zero likelihood the
        # posterior is 0/0; fall back to the prior weight `a` there.
        p = np.full(total.shape, a, dtype=float)
        np.divide(likelihoodB, total, out=p, where=total > 0)

        # M-step.
        sumP = np.sum(p)
        sumPY = np.sum(p * dataY)
        a = sumP / count
        # A component with zero total responsibility has no defined update;
        # keep its previous parameter instead of producing NaN.
        if sumP > 0:
            b = sumPY / sumP
        if count - sumP > 0:
            c = (sumY - sumPY) / (count - sumP)
    return a, b, c
42 
def test(realA, realB, realC):
    """Plot where EM ends up for a cubic grid of starting points.

    Data are generated with createData(realA, realB, realC).  em() is then
    started from every point of a 0.1-spaced grid over (0, 1)^3.  Each
    starting point is drawn at its own coordinates (X, Y, Z = initial a, b,
    c) and coloured with the resulting estimate (a, b, c), clipped to
    [0, 1], as an RGB triple.  The figure is saved as a PNG and closed.
    """
    dataY = createData(realA, realB, realC)
    gridAxis = np.arange(0.1, 1.00, 0.1)
    XX, YY, ZZ = np.meshgrid(gridAxis, gridAxis, gridAxis)
    # Alternative set of starting points: a slanted plane instead of the cube.
    # XX, YY = np.meshgrid(np.arange(0.05, 1.00, 0.05), np.arange(0.05, 1.00, 0.05))
    # ZZ = (9 - 5 * XX - 4 * YY) / 12

    fig = plt.figure(figsize=(10, 8))
    # Axes3D(fig) no longer attaches itself to the figure on current
    # Matplotlib, which left the saved image empty.  add_axes with the
    # full-figure rectangle attaches the axes with the same layout.  The '3d'
    # projection is registered by the mpl_toolkits.mplot3d import at file top.
    ax = fig.add_axes((0.0, 0.0, 1.0, 1.0), projection='3d')
    ax.set_xlim3d(0.0, 1.0)
    ax.set_ylim3d(0.0, 1.0)
    ax.set_zlim3d(0.0, 1.0)
    ax.set_xlabel('X', fontdict={'size': 15, 'color': 'r'})
    ax.set_ylabel('Y', fontdict={'size': 15, 'color': 'g'})
    ax.set_zlabel('Z', fontdict={'size': 15, 'color': 'b'})

    starts = np.column_stack((XX.flatten(), YY.flatten(), ZZ.flatten()))
    # One RGB row per starting point: the estimate clipped into [0, 1].
    colors = np.array([np.clip(em(dataY, a0, b0, c0), 0, 1) for a0, b0, c0 in starts])
    # Alternative colouring: squared distance between estimate and true values.
    # colors = np.array([myColor(np.sum((np.array(em(dataY, a0, b0, c0)) - np.array([realA, realB, realC])) ** 2)) for a0, b0, c0 in starts])

    # A single scatter call replaces one call (and one artist) per point.
    # depthshade=False keeps the colours flat; they encode the result.
    ax.scatter(starts[:, 0], starts[:, 1], starts[:, 2], c = colors, s = 20, label = "P", depthshade = False)

    # NOTE(review): the output location is hard-coded to the root of drive R:.
    tag = ",".join(str(round(value, 3)) for value in (realA, realB, realC))
    fig.savefig("R:\\(" + tag + ").png")
    plt.close(fig)
66 
if __name__ == "__main__":
    # Render the convergence map for true parameters (0.5, 0.5, 0.5).
    trueParameters = (0.5, 0.5, 0.5)
    test(*trueParameters)

猜你喜欢

转载自www.cnblogs.com/cuancuancuanhao/p/11305578.html