1. 概念：决策树经常用于处理分类问题，也是最常用的数据挖掘算法之一。决策树的一个重要任务是理解数据中所蕴含的知识信息，并从中提取一系列规则，而创建这些规则的过程就是机器学习的过程。
2. 优缺点：
优点：计算复杂度不高，输出结果易于理解，对中间值的缺失不敏感，可以处理不相关特征数据；
缺点：可能会产生过度匹配（过拟合）问题；
适用数据类型：数值型和标称型（注意：下面实现的 ID3 算法按特征取值做精确匹配来划分数据，只能直接处理标称型数据，数值型数据需要先离散化）。
# -*- coding: utf-8 -*-
# ID3 decision tree: build a tree from a small nominal dataset and plot it.
# Trees are nested dicts: {feature_label: {feature_value: subtree_or_class_label}}.
from math import log
import operator

try:
    import matplotlib.pyplot as plt
except ImportError:  # plotting is optional; building a tree needs only the stdlib
    plt = None

# Box and arrow styles used when drawing nodes. Defined at module level (not
# under __main__) so createPlot also works when this file is imported.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def calcShannoEnt(dataSet):
    """Return the Shannon entropy, in bits, of the class labels in dataSet.

    Each row of dataSet is a sequence whose last element is the class label.
    An empty dataset yields 0.0.
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannoEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannoEnt -= prob * log(prob, 2)
    return shannoEnt


def splitDataSet(dataSet, axis, value):
    """Return the rows whose column ``axis`` equals ``value``, with that column removed.

    New row lists are built, so dataSet itself is not modified.
    """
    return [featVec[:axis] + featVec[axis + 1:]
            for featVec in dataSet if featVec[axis] == value]


def createDataSet():
    """Return a small sample dataset and the labels of its two feature columns.

    Each row holds two 0/1 feature values followed by a 'yes'/'no' class label.
    """
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature column with the highest information gain.

    Returns -1 when no feature gives a strictly positive gain (for example when
    every remaining feature is constant); callers must handle that case.
    """
    baseEntropy = calcShannoEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    numFeatures = len(dataSet[0]) - 1  # last column is the class label
    for i in range(numFeatures):
        uniqueVals = {example[i] for example in dataSet}
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannoEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent label in classList.

    Ties go to the label that appears first in classList.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # max() returns the first maximal item, i.e. the earliest-inserted label on a tie.
    return max(classCount.items(), key=operator.itemgetter(1))[0]


def createTree(dataSet, labels):
    """Build an ID3 decision tree for dataSet.

    labels names the feature columns of dataSet (class column excluded). It is
    not modified. Returns a class label when no split is needed, otherwise a
    nested dict {feature_label: {feature_value: subtree}}.
    """
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # all rows share one class: stop splitting
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)  # no features left: majority vote
    bestFeat = chooseBestFeatureToSplit(dataSet)
    if bestFeat < 0:
        # No feature separates the rows; indexing with -1 would split on the
        # class column itself, so fall back to a majority vote instead.
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Build the reduced label list without mutating the caller's list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in {example[bestFeat] for example in dataSet}:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw a boxed node at centerPt, connected by an arrow to parentPt.

    Coordinates are axes fractions; nodeType is the bbox style dict.
    Requires createPlot to have set createPlot.ax1.
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType,
                            arrowprops=arrow_args)


def getNumLeafs(myTree):
    """Return the number of leaves (non-dict values) in myTree."""
    firstStr = next(iter(myTree))
    numLeafs = 0
    for child in myTree[firstStr].values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of decision levels on the longest path of myTree.

    A tree whose root branches all end directly in leaves has depth 1.
    """
    firstStr = next(iter(myTree))
    maxDepth = 0
    for child in myTree[firstStr].values():
        # Each branch is one level plus whatever hangs below it.
        thisDepth = 1 + getTreeDepth(child) if isinstance(child, dict) else 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth


def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString at the midpoint between parentPt and cntrPt."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw myTree below parentPt and label the incoming edge nodeTxt.

    Layout state is kept in attributes set by createPlot: totalW (leaf count of
    the whole tree), totalD (its depth), and the running cursor xOff/yOff,
    all in axes-fraction coordinates.
    """
    numLeafs = getNumLeafs(myTree)  # width of this subtree, in leaf slots
    firstStr = next(iter(myTree))
    # Centre this decision node horizontally over the leaf slots it spans.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # descend one level
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Leaf: advance the cursor by one slot and draw it at this level.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # back up after children


def createPlot(inTree):
    """Draw inTree in matplotlib figure 1 and show it.

    Raises ImportError if matplotlib is not installed.
    """
    if plt is None:
        raise ImportError("matplotlib is required for createPlot")
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    # Must be the true depth, or the deepest leaves are drawn below the axes.
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a slot to the left so the first leaf lands at 0.5 / totalW.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def retrieveTree(i):
    """Return hard-coded sample tree number i (0 or 1), e.g. for testing the plot code."""
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTrees[i]


if __name__ == "__main__":
    myTree = retrieveTree(1)
    createPlot(myTree)