附上实现的ID3算法Python代码。
参考《机器学习实战》一书编写。
# -*- coding: UTF-8 -*-
"""ID3 decision tree: training, matplotlib visualisation and classification.

A tree is represented as nested dicts of the form
``{feature_label: {feature_value: subtree_or_class_label, ...}}``.
Every data set is a list of rows; the last element of each row is the class
label and all preceding elements are feature values.
"""
from math import log
import copy
import operator
import os

try:
    import matplotlib.pyplot as plt
    from matplotlib.font_manager import FontProperties
except ImportError:
    # Plotting is optional: tree construction and classification only need
    # the standard library. createPlot() reports the missing dependency.
    plt = None
    FontProperties = None

# Font able to render the Chinese node labels; only present on Windows.
_CJK_FONT_PATH = r"c:\windows\fonts\simsun.ttc"


def createDataSet():
    """Return the sample loan data set and its feature labels.

    Returns:
        (dataSet, labels): ``dataSet`` is a list of rows whose last element
        is the class ('no' = loan refused, 'yes' = loan granted); ``labels``
        names the four feature columns in order.
    """
    dataSet = [
        ['young', 0, 0, 0, 'no'],
        ['young', 0, 0, 1, 'no'],
        ['young', 1, 0, 1, 'yes'],
        ['young', 1, 1, 0, 'yes'],
        ['young', 0, 0, 0, 'no'],
        ['middle', 0, 0, 0, 'no'],
        ['middle', 0, 0, 1, 'no'],
        ['middle', 1, 1, 1, 'yes'],
        ['middle', 0, 1, 2, 'yes'],
        ['middle', 0, 1, 2, 'yes'],
        ['old', 0, 1, 2, 'yes'],
        ['old', 0, 1, 1, 'yes'],
        ['old', 1, 0, 1, 'yes'],
        ['old', 1, 0, 2, 'yes'],
        ['old', 0, 0, 0, 'no'],
    ]
    # Last feature (credit standing): 0, 1, 2 mean fair, good, very good.
    labels = ['年龄', '有工作', '有房子', '贷款情况']
    return dataSet, labels


def calShannonEnt(dataSet):
    """Return the Shannon entropy (base 2) of the class column of dataSet.

    An empty data set yields 0.0.
    """
    labelCounts = {}
    for item in dataSet:
        label = item[-1]
        labelCounts[label] = labelCounts.get(label, 0) + 1
    length = len(dataSet)
    shannonEnt = 0.0
    for count in labelCounts.values():
        p = count / length
        shannonEnt -= p * log(p, 2)
    return shannonEnt


def splitDataSet(dataSet, index, value):
    """Return the rows whose feature ``index`` equals ``value``.

    The matched column is removed from every returned row; the input rows
    are not modified.
    """
    return [item[:index] + item[index + 1:]
            for item in dataSet
            if item[index] == value]


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain.

    Prints the gain of every feature. Returns -1 when no feature has a
    strictly positive gain. ``dataSet`` must be non-empty.
    """
    featureNum = len(dataSet[0]) - 1
    baseEnt = calShannonEnt(dataSet)
    maxGain = 0.0
    bestFeature = -1
    for i in range(featureNum):
        # Distinct values taken by feature i.
        featureValues = {item[i] for item in dataSet}
        # Conditional entropy: size-weighted entropy of each partition.
        currentEnt = 0.0
        for value in featureValues:
            splitData = splitDataSet(dataSet, i, value)
            p = len(splitData) / len(dataSet)
            currentEnt += p * calShannonEnt(splitData)
        currentGain = baseEnt - currentEnt
        print("第%d个特征的增益为%.3f" % (i, currentGain))
        if maxGain < currentGain:
            maxGain = currentGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent element of classList.

    Prints the (element, count) pairs sorted by descending count. Ties are
    resolved in favour of the element that appears first in classList.
    """
    classCount = {}
    for item in classList:
        classCount[item] = classCount.get(item, 0) + 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    print(sortedClassCount)
    return sortedClassCount[0][0]


def createTree(dataSet, labels, featLabels):
    """Build an ID3 decision tree.

    Parameters:
        dataSet - non-empty training rows (class label in the last column)
        labels - names of the feature columns of dataSet; left unmodified
        featLabels - output list; the label of every feature chosen for a
            split is appended to it, in the order the splits are made

    Returns:
        myTree - nested-dict tree, or a bare class label when the node is
            a leaf
    """
    classList = [example[-1] for example in dataSet]
    # All samples share one class: pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No features left to split on: majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # No feature yields a positive gain, so splitting cannot help. -1 must
    # not be used as an index (it would address the class column).
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel: {}}
    # Build the reduced label list as a new object so the caller's list and
    # the lists seen by sibling branches stay aligned with their columns.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in {example[bestFeat] for example in dataSet}:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels, featLabels)
    return myTree


def getNumLeafs(myTree):
    """Return the number of leaf nodes of a nested-dict decision tree."""
    numLeafs = 0
    firstStr = next(iter(myTree))
    for child in myTree[firstStr].values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of decision levels of a nested-dict decision tree."""
    maxDepth = 0
    firstStr = next(iter(myTree))
    for child in myTree[firstStr].values():
        thisDepth = 1 + getTreeDepth(child) if isinstance(child, dict) else 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def _nodeFont():
    """Return the font for node text, preferring the CJK font if installed."""
    if os.path.exists(_CJK_FONT_PATH):
        return FontProperties(fname=_CJK_FONT_PATH, size=14)
    # Fall back to matplotlib's default font instead of failing at render
    # time on systems without the Windows font file.
    return FontProperties(size=14)


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one node with an arrow coming from its parent.

    Parameters:
        nodeTxt - text shown in the node
        centerPt - position of the node text (axes fraction)
        parentPt - position the arrow starts from (axes fraction)
        nodeType - bbox style dict of the node

    Draws on ``createPlot.ax1``, which createPlot() sets up.
    """
    arrow_args = dict(arrowstyle="<-")
    # Text property keywords are case-sensitive: it is ``fontproperties``.
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args,
                            fontproperties=_nodeFont())


def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString at the midpoint of the edge cntrPt-parentPt."""
    xMid = (parentPt[0] - cntrPt[0]) / 2 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString,
                        va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw a decision tree.

    Parameters:
        myTree - decision tree (nested dict)
        parentPt - position of the parent node
        nodeTxt - text written on the edge leading to this subtree

    Uses the layout state stored on the function itself (``totalW``,
    ``totalD``, ``xOff``, ``yOff``), which createPlot() initialises.
    """
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    leafNode = dict(boxstyle="round4", fc="0.8")
    numLeafs = getNumLeafs(myTree)  # determines the width of this subtree
    firstStr = next(iter(myTree))
    # Centre the decision node above the leaves it spans.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step one level down for the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            # Inner node: recurse.
            plotTree(child, cntrPt, str(key))
        else:
            # Leaf: advance one leaf slot to the right and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the level for the caller.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Create the figure, draw inTree (nested dict) and show it.

    Raises:
        ImportError - if matplotlib is not installed
    """
    if plt is None:
        raise ImportError("matplotlib is required to plot the decision tree")
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    # Axes without frame and ticks.
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot left of the axes so the first leaf is centred
    # in its slot.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def classify(inputTree, featLabels, testVec):
    """Classify testVec with a decision tree.

    Parameters:
        inputTree - decision tree built by createTree
        featLabels - feature labels, in the column order of testVec
        testVec - feature values of the sample to classify

    Returns:
        classLabel - the predicted class

    Raises:
        ValueError - if a feature of the tree is missing from featLabels or
            the sample has a feature value the tree has no branch for
    """
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    if featValue not in secondDict:
        raise ValueError("no branch for value %r of feature %r"
                         % (featValue, firstStr))
    branch = secondDict[featValue]
    if isinstance(branch, dict):
        return classify(branch, featLabels, testVec)
    return branch


def main():
    """Train on the sample data, plot the tree and classify one sample."""
    dataSet, labels = createDataSet()
    labelTemp = copy.copy(labels)
    print(dataSet)
    print(calShannonEnt(dataSet))
    print("最优特征索引值:" + str(chooseBestFeatureToSplit(dataSet)))
    featLabels = []
    myTree = createTree(dataSet, labels, featLabels)
    print(myTree)
    createPlot(myTree)
    testVec = [0, 1, 0, 1]  # sample to classify
    result = classify(myTree, labelTemp, testVec)
    if result == 'yes':
        print('放贷')
    if result == 'no':
        print('不放贷')


if __name__ == '__main__':
    main()