▶ EM 算法的引入,三硬币问题,体验一下不同初始值对收敛点的影响
● 代码
1 import numpy as np 2 import matplotlib.pyplot as plt 3 from matplotlib.patches import Rectangle 4 5 dataSize = 1000 6 trainDataRatio = 0.3 7 defaultTurn = 20 8 epsilon = 1E-10 9 randomSeed = 103 10 11 def dataSplit(dataY, part): # 将数据集分割为训练集和测试集 12 return dataY[:part], dataY[part:] 13 14 def createData(realA, realB, realC, count = dataSize): # 创建数据 15 np.random.seed(randomSeed) 16 a = (np.random.rand(count) > realA).astype(int) 17 b = (np.random.rand(count) > realB).astype(int) 18 c = (np.random.rand(count) > realC).astype(int) 19 return b * (1 - a) + c * a 20 21 def em(dataY, initialA, initialB, initialC, turn = defaultTurn):# 迭代计算 22 count = len(dataY) 23 sumY = np.sum(dataY) 24 a = initialA 25 b = initialB 26 c = initialC 27 for i in range(turn): 28 p = a * b ** dataY *(1 - b) ** (1 - dataY) / ( a * b ** dataY *(1 - b) ** (1 - dataY) + (1 - a) * c ** dataY *(1 - c) ** (1 - dataY) ) 29 sumP = np.sum(p) 30 a = sumP / count 31 b = np.sum(p * dataY) / sumP 32 c = (sumY - np.sum(p * dataY)) / (count - sumP) 33 return a, b, c 34 35 def test(realA, realB, realC, initialA, initialB, initialC): # 单次测试 36 Y = createData(realA, realB, realC) 37 38 para = em(Y, initialA, initialB, initialC) 39 40 print( "real=(%.3f, %.3f, %.3f),initial=(%.3f,%.3f,%.3f),train=(%.3f,%.3f,%.3f)"%(realA, realB, realC, initialA,initialB,initialC,para[0],para[1],para[2]) ) 41 42 if __name__ == '__main__': 43 test(0.5, 0.5, 0.5, 0.5, 0.5, 0.5) 44 test(0.5, 0.5, 0.5, epsilon, epsilon, epsilon) 45 test(0.5, 0.5, 0.5, 0.5, epsilon, epsilon) 46 test(0.5, 0.5, 0.5, epsilon, 0.5, epsilon) 47 test(0.5, 0.5, 0.5, epsilon, epsilon, 0.5) 48 test(0.5, 0.5, 0.5, 1.0 - epsilon, epsilon, epsilon) 49 test(0.5, 0.5, 0.5, epsilon, 1.0 - epsilon, epsilon) 50 test(0.5, 0.5, 0.5, epsilon, epsilon, 1.0 - epsilon) 51 test(0.5, 0.5, 0.5, 1.0 - epsilon, 1.0 - epsilon, 1.0 - epsilon) 52 53 test(0.4, 0.5, 0.6, 0.4, 0.5, 0.6) 54 test(0.4, 0.5, 0.6, epsilon, epsilon, epsilon) 55 test(0.5, 0.5, 0.5, 0.5, epsilon, epsilon) 56 test(0.5, 0.5, 0.5, epsilon, 0.5, epsilon) 57 test(0.5, 0.5, 0.5, epsilon, epsilon, 0.5)
● 输出结果,从不同的真实值和初始值得到不同的收敛点
real=(0.500, 0.500, 0.500),initial=(0.500,0.500,0.500),train=(0.500,0.516,0.516) real=(0.500, 0.500, 0.500),initial=(0.000,0.000,0.000),train=(0.000,0.516,0.516) real=(0.500, 0.500, 0.500),initial=(0.500,0.000,0.000),train=(0.500,0.516,0.516) real=(0.500, 0.500, 0.500),initial=(0.000,0.500,0.000),train=(0.172,1.000,0.415) real=(0.500, 0.500, 0.500),initial=(0.000,0.000,0.500),train=(0.000,0.000,0.516) real=(0.500, 0.500, 0.500),initial=(1.000,0.000,0.000),train=(1.000,0.516,0.516) real=(0.500, 0.500, 0.500),initial=(0.000,1.000,0.000),train=(0.258,1.000,0.348) real=(0.500, 0.500, 0.500),initial=(0.000,0.000,1.000),train=(0.242,0.000,0.681) real=(0.500, 0.500, 0.500),initial=(1.000,1.000,1.000),train=(1.000,0.516,0.516) real=(0.400, 0.500, 0.600),initial=(0.400,0.500,0.600),train=(0.409,0.406,0.506) real=(0.400, 0.500, 0.600),initial=(0.000,0.000,0.000),train=(0.000,0.465,0.465) real=(0.500, 0.500, 0.500),initial=(0.500,0.000,0.000),train=(0.500,0.516,0.516) real=(0.500, 0.500, 0.500),initial=(0.000,0.500,0.000),train=(0.172,1.000,0.415) real=(0.500, 0.500, 0.500),initial=(0.000,0.000,0.500),train=(0.000,0.000,0.516)
● 画图,散点位置表示初始取值,散点颜色 RGB 值表示收敛点取值。各图依次为:(真实值 ( 0.5,0.5,0.5 ),初始间隔 0.1,迭代 20 次),(真实值 ( 0.5,0.5,0.5 ),初始间隔 0.1,迭代 100 次),(真实值 ( 0.5,0.5,0.5 ),初始间隔 0.05,迭代 20 次),(真实值 ( 0.3,0.6,0.8 ),初始间隔 0.1,迭代 20 次)。可见:① 迭代 20 次以后就基本稳定了,更多次数迭代没有明显影响;② 随着初始点的连续移动,收敛点的取值耶连续漂移,没有出现明显断层;③ 图中色彩饱和度较高的散点存在,说明收敛点并不能向真实值点明显靠拢,甚至有可能保持极端取值;④ 真实值点对收敛点在整个空间上的取值有影响(废话)
● 画图脚本
1 import numpy as np 2 import matplotlib.pyplot as plt 3 from matplotlib.patches import Rectangle 4 from mpl_toolkits.mplot3d import Axes3D 5 from mpl_toolkits.mplot3d.art3d import Poly3DCollection 6 7 dataSize = 1000 8 trainDataRatio = 0.3 9 defaultTurn = 20 10 epsilon = 1E-10 11 randomSeed = 103 12 13 def dataSplit(dataY, part): 14 return dataY[:part], dataY[part:] 15 16 def myColor(x): 17 r = np.select([x < 1/2, x < 3/4, x <= 1, True],[0, 4 * x - 2, 1, 0]) 18 g = np.select([x < 1/4, x < 3/4, x <= 1, True],[4 * x, 1, 4 - 4 * x, 0]) 19 b = np.select([x < 1/4, x < 1/2, x <= 1, True],[1, 2 - 4 * x, 0, 0]) 20 return [r,g,b] 21 22 def createData(realA, realB, realC, count = dataSize): 23 np.random.seed(randomSeed) 24 a = (np.random.rand(count) > realA).astype(int) 25 b = (np.random.rand(count) > realB).astype(int) 26 c = (np.random.rand(count) > realC).astype(int) 27 return b * (1 - a) + c * a 28 29 def em(dataY, initialA, initialB, initialC, turn = defaultTurn): 30 count = len(dataY) 31 sumY = np.sum(dataY) 32 a = initialA 33 b = initialB 34 c = initialC 35 for i in range(turn): 36 p = a * b ** dataY *(1 - b) ** (1 - dataY) / ( a * b ** dataY *(1 - b) ** (1 - dataY) + (1 - a) * c ** dataY *(1 - c) ** (1 - dataY) ) 37 sumP = np.sum(p) 38 a = sumP / count 39 b = np.sum(p * dataY) / sumP 40 c = (sumY - np.sum(p * dataY)) / (count - sumP) 41 return a, b, c 42 43 def test(realA, realB, realC): 44 dataY = createData(realA, realB, realC) 45 XX, YY, ZZ = np.meshgrid(np.arange(0.1,1.00,0.1), np.arange(0.1,1.00,0.1), np.arange(0.1,1.00,0.1)) 46 #XX, YY = np.meshgrid(np.arange(0.05,1.00,0.05), np.arange(0.05,1.00,0.05)) # 一个斜截平面 47 #ZZ = ( 9 - 5 * XX - 4 * YY ) / 12 48 49 fig = plt.figure(figsize=(10, 8)) 50 ax = Axes3D(fig) 51 ax.set_xlim3d(0.0, 1.0) 52 ax.set_ylim3d(0.0, 1.0) 53 ax.set_zlim3d(0.0, 1.0) 54 ax.set_xlabel('X', fontdict={'size': 15, 'color': 'r'}) 55 ax.set_ylabel('Y', fontdict={'size': 15, 'color': 'g'}) 56 ax.set_zlabel('Z', fontdict={'size': 15, 'color': 'b'}) 57 58 for xyz in zip(XX.flatten(),YY.flatten(),ZZ.flatten()): 59 para = em(dataY, xyz[0], xyz[1], xyz[2]) 60 para = np.minimum(np.maximum(np.array(para),0),1) 61 ax.scatter([xyz[0]], [xyz[1]], [xyz[2]], color = list(para), s = 20, label = "P") 62 #ax.scatter([xyz[0]], [xyz[1]], [xyz[2]], color = myColor( np.sum((np.array(para) - np.array([realA,realB,realC]))**2)), s = 20, label = "P") 63 64 fig.savefig("R:\\(" + str(round(realA,3)) + "," + str(round(realB,3)) + "," + str(round(realC,3)) + ").png") 65 plt.close() 66 67 if __name__ == '__main__': 68 test(0.5, 0.5, 0.5)