决策树

1.概念:决策树经常用于处理分类问题,也是最经常使用的数据挖掘算法。决策树的一个重要任务是理解数据中所蕴含的知识信息,并从中提取出一系列的规则,而创建这些规则的过程就是机器学习的过程。

2.优缺点:

 

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据;

 

缺点:可能会产生过度匹配问题;

 

适用数据类型:数值型和标称型

  1 #coding='utf-8'
  2 from math import log
  3 import operator
  4 import matplotlib.pyplot as plt
  5 # import treePlotter
  6 #计算给定数据集的香农熵
  7 def calcShannoEnt(dataSet):
  8     numEntries=len(dataSet)
  9     labelCounts = {}
 10     for featVec in dataSet:
 11         currentLabel = featVec[-1]
 12         if currentLabel not in labelCounts.keys():
 13             labelCounts[currentLabel] = 0
 14         labelCounts[currentLabel] += 1
 15     shannoEnt = 0.0
 16     for key in labelCounts:
 17         prob = float(labelCounts[key])/numEntries
 18         shannoEnt -= prob * log(prob,2)
 19     return shannoEnt
 20 
 21 
 22 #按照给定特征划分数据集
 23 def splitDataSet(dataSet, axis, value):
 24     retDataSet = []
 25     for featVec in dataSet:
 26         if featVec[axis] == value:
 27             reducedFeatVec = featVec[:axis]
 28             reducedFeatVec.extend(featVec[axis+1:])
 29             retDataSet.append(reducedFeatVec)
 30     return retDataSet
 31 
 32 #创建数据集
 33 def createDataSet():
 34     dataSet = [[1, 1, 'yes'],
 35                [1, 1, 'yes'],
 36                [1, 0, 'no'],
 37                [0, 1, 'no'],
 38                [0, 1, 'no']]
 39     # labels = ['no surfacing','flippers']
 40     labels=['no surfacing','flippers']
 41     return dataSet,labels
 42 
 43 #选择最好的数据集划分方式
 44 def chooseBestFeatureToSplit(dataSet):
 45     baseEntropy = calcShannoEnt(dataSet)
 46     bestInfoGain = 0.0
 47     bestFeature = -1
 48     numFeatures = len(dataSet[0])-1
 49     numEntries = len(dataSet)# 5
 50     for i in range(numFeatures):
 51         featList = [example[i] for example in dataSet]
 52         uniqueVals = set(featList)
 53         newEntropy = 0.0
 54         for value in uniqueVals:
 55             subDataSet = splitDataSet(dataSet,i,value)
 56             print("subDataSet",subDataSet)
 57             prob = len(subDataSet)/float(len(dataSet))
 58             newEntropy += prob* calcShannoEnt(subDataSet)
 59         infoGain = baseEntropy - newEntropy
 60         if (infoGain > bestInfoGain):
 61             bestInfoGain = infoGain
 62             bestFeature = i
 63     return bestFeature
 64 
 65 #投票表决
 66 def majorityCnt(classList):
 67     classCount = {}
 68     for vote in classList:
 69         if vote not in classCount.keys():
 70             classCount[vote] = 0
 71         classCount[vote] += 1
 72     sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
 73     return sortedClassCount[0][0]
 74 
 75 #创建树的函数
 76 def createTree(dataSet,labels):
 77     classList = [example[-1] for example in dataSet]
 78     if classList.count(classList[0]) == len(classList):#若类别完全相同则停止继续划分
 79         return classList[0]
 80     if len(dataSet[0]) == 1:
 81         return majorityCnt(classList)#遍历完所有特征时返回出现次数最多的
 82     bestFeat = chooseBestFeatureToSplit(dataSet)#0
 83     bestFeatLabel = labels[bestFeat]
 84     myTree = {bestFeatLabel:{}}
 85     del(labels[bestFeat])
 86     featValues = [example[bestFeat] for example in dataSet]
 87     uniqueVals = set(featValues)
 88     for value in uniqueVals:
 89         subLabels = labels[:]
 90         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
 91     return myTree
 92 
 93 #绘制带箭头的注解
 94 def plotNode(nodeTxt,centerPt,parentPt,nodeType):
 95     createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',va='center',
 96                             ha='center',bbox=nodeType,arrowprops=arrow_args)
 97 
 98 # def createPlot():
 99 #     fig = plt.figure(1,facecolor='white')
100 #     fig.clf()
101 #     createPlot.ax1 = plt.subplot(111,frameon = 'white')
102 #     plotNode(u'决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
103 #     plotNode(u'叶节点',(0.8,0.1),(0.3,0.8),leafNode)
104 #     plt.show()
105 
106 #获取叶节点的数目
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    A child that is itself a dict is a subtree; anything else is a leaf.
    """
    root = list(myTree.keys())[0]
    leaves = 0
    for child in myTree[root].values():
        leaves += getNumLeafs(child) if type(child) is dict else 1
    return leaves
117 
118 
119 #获取树的层数
def getTreeDepth(myTree):
    """Return the depth of a nested-dict decision tree.

    Depth counts decision levels: a tree whose children are all leaves
    has depth 1.

    BUG FIX: the original did ``TreeDepth += getTreeDepth(...)`` for a
    dict child, which both accumulates across sibling branches and
    omits the +1 for the current level — so the answer depended on
    dict key iteration order. Each branch's depth must be computed
    independently as 1 + depth(subtree).
    """
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    maxDepth = 0
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
133 
134 #在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
    """Write *txtString* halfway between the child (*cntrPt*) and
    parent (*parentPt*) positions on ``createPlot.ax1``."""
    midX = cntrPt[0] + (parentPt[0] - cntrPt[0]) / 2.0
    midY = cntrPt[1] + (parentPt[1] - cntrPt[1]) / 2.0
    createPlot.ax1.text(midX, midY, txtString, va="center", ha="center", rotation=30)
139 
# Recursively plot a nested-dict tree. Layout state lives in function
# attributes (plotTree.totalW/totalD/xOff/yOff) that createPlot()
# initialises; the order in which xOff/yOff are mutated below is
# load-bearing, so the statement order must not be changed.
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # leaf count determines this subtree's x width
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]     # the text label for this node
    # Centre this node above its leaves, in axes-fraction coordinates.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # edge label between parent and this node
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  # descend one level
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':  # dict child => subtree
            plotTree(secondDict[key],cntrPt,str(key))        # recursion
        else:   # leaf child: advance x cursor and draw the leaf
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  # restore level on the way back up
158 
159 #绘树形图
def createPlot(inTree):
    """Render the decision tree *inTree* in a matplotlib window.

    Sets up the shared axes (``createPlot.ax1``) and the layout state
    that plotTree() consumes, then draws the tree and shows the figure.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    axprops = {'xticks': [], 'yticks': []}
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # Layout bookkeeping used by plotTree() during recursion.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
171 
def retrieveTree(i):
    """Return one of two canned decision trees (index 0 or 1), useful
    for exercising the plotting code without rebuilding a tree."""
    canned = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no',
                          1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},
                                           1: 'no'}}}},
    ]
    return canned[i]
177 
if __name__ == "__main__":
    # Box and arrow styles; read as module globals by plotNode() and
    # plotTree(), so they must be defined before createPlot() runs.
    # Fixes: removed the unused `sum = 0` (which shadowed the builtin),
    # an unused local data set, and large swathes of commented-out code.
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    leafNode = dict(boxstyle="round4", fc="0.8")
    arrow_args = dict(arrowstyle="<-")

    # Draw the second canned tree.
    myTree = retrieveTree(1)
    createPlot(myTree)

 

结果:

猜你喜欢

转载自www.cnblogs.com/nxf-rabbit75/p/8908963.html