之前我们已经了解了从数据集构造决策树的各种子功能模块,原理:从原始数据中基于最好的特征值进行划分数据集,由于特征值可能多余两个,所以可能存在大于两个分支的数据集划分。第一次划分之后数据将被传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。因此我们可以用递归的原则处理数据。
递归结束的条件是:程序遍历完所有划分数据集的属性,或则每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或终止模块。任何到达叶子节点的数据必然属于叶子节点的分类。
第一个条件可以使算法终止,我们在算法开始之前计算列的数目,查看算法是否使用了所有属性即可。如果数据集已经处理了所有属性,但是类标签不唯一,此时我们需要决定如何定义该叶子节点,这种情况下通常使用多数表决的方法决定叶子节点的分类。
def majorityCnt(classList): #存储每类标签出现的频率,按照从小到大的顺序进行排列 classCount = {} for vote in classList: if vote not in classList.keys(): classCount[vote] += 1 classCount[vote] += 1 sortedclassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) return sortedclassCount
创建树的函数代码:
def createTree(dataSet, labels):
'''
输入参数:数据集和标签列表
算法本身并不需要这个变量,为了给出数据明确的含义,我们将他作为一个输入参数提供给
'''
#所有数据集的类标签
classList = [example[-1] for example in dataSet]
#第一个停止条件:所有类标签完全相同,直接返回该类标签
if classList.count(classList[0]) == len(classList):
return classList[0]
#使用完了所有特征仍不能将数据集划分成一类,返回出现次数最多的类别作为返回值
if len(dataSet[0]) == 1:
return majorityCnt(classList)
#当前数据集中选取的最好的特征值 返回值为特征值的索引
bestFeat = chooseBestFeatureToSplit(dataSet)
print(bestFeat)
#最好特征的特征名称
bestFeatLabel = labels[bestFeat]
print(bestFeatLabel)
myTree = {bestFeatLabel:{}}
print(myTree)
#分类结束后删除当前特征
del(labels[bestFeat])
#遍历所有样本集(数据集)中的最好特征对应的特征值
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
#特征对应的特征值
for value in uniqueVals:
'''
当函数参数是列表类型时,参数是按照引用方式传递的。
为了保证每次调用函数createTree()时不改变原始列表的内容,使用新变量代替原始列表
'''
#复制了所有的特征名称(这里是我觉得和书不一样的地方)
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
print(myTree)
return myTree
执行上述语句,我们可以的到如下的结果:
变量myTree包含了很多代表树结构的嵌套字典。从左边开始,第一个关键字:no surfacing是第一个划分数据集的特征名字,该关键字的值也是另一个数据字典。第二个关键字是no surfacing特征划分的数据集,这些关键字的值都是no surfacing节点的子节点。这些值可以是类标签,也可以是另一个数据字典。如果值是类标签,则该节点是叶子节点,如果值是另一个数据字典,那么子节点是一个判断节点。这种不断重复的结构就构成了整棵树。
下面我们将使用matplotlib注解来绘制树形图:
1、matplotlib注解
import matplotlib.pyplot as plt #定义树节点格式的常量 #文本框和箭头格式 decisionNode = dict(boxstyle="sawtooth", fc="0.8") #boxstyle="sawtooth"边框线是波浪线 fc注解框的颜色深度 leafNode = dict(boxstyle="round4", fc="0.8") arrow_args = dict(arrowstyle="<-") #执行绘图功能 #绘制带箭头的注解 def plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) #xy=parentPt起点位置, #xytext=centerPt注解框位置 #创建一个新的绘图区,在上面绘制两个不同类型的树节点 def createPlot(): fig = plt.figure(1, facecolor='white') fig.clf() createPlot.ax1 = plt.subplot(111, frameon=False) plotNode('a decision node', (0.5,0.1), (0.1,0.5), decisionNode) plotNode('a leaf node', (0.8,0.1), (0.3,0.8), leafNode) plt.show()
我们掌握了绘制绘制树节点的方法后,下面将学习如何绘制整棵树。
2、构造注解树
我们需要知道有多少个叶节点,以便正确的确定x轴的长度;我们还需要知道有多少层,以便确定y的高度。我们通过定义两个函数来确定叶节点的数目和树的层数。
我们使用如下两个函数来获取叶节点的数目和树的层数。
def getNumLeafs(myTree): numLeafs = 0 firstStr = list(myTree.keys())[0]
#python3与python2的区别,python3要把键变成一个列表#print(firstStr) secondDict = myTree[firstStr] #print(secondDict) for key in secondDict.keys(): #print(key) if type(secondDict[key]).__name__ == 'dict': numLeafs += getNumleafs(secondDict[key]) else: numLeafs += 1 return numLeafs
def getTreeDepth(myTree): maxDepth = 0 firstStr = list(myTree.keys())[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': thisDepth = 1+getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth:maxDepth = thisDepth return maxDepth
我们用一个函数预先输出存储树的信息:
def retrieveTree(i): listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}, {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}] return listOfTrees[i]
我们现在将之前所学的方法组合在一起,绘制一棵完整的树:
def plotMidText(cntrPt, parentPt, txtString): #在父子节点间填充文本信息 xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString) def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on ''' cntrPt用来记录当前要画的树的树根的结点位置 ''' #计算宽高 numLeafs = getNumLeafs(myTree) depth = getTreeDepth(myTree) firstStr = list(myTree.keys())[0] #python3与python2的区别,python3要把键变成一个列表 ''' 我们希望树根在这些所有叶子节点的中间位置 这里的 1.0 + numLeafs 需要拆开来理解,也就是 plotTree.xOff + float(numLeafs)/2.0/plotTree.totalW +1.0/2.0/plotTree.totalW plotTree.xOff + 1/2 * float(numLeafs)/plotTree.totalW + 0.5/plotTree.totalW 因为xOff的初始值是-0.5/plotTree.totalW ,是往左偏了0.5/plotTree.tatalW 的,这里正好加回去。 这样cntrPt记录的x坐标正好是所有叶子结点的中心点 ''' cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) #标记子节点的属性 plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] #减少y偏移 plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes plotTree(secondDict[key],cntrPt,str(key)) #recursion else: #it's a leaf node print the leaf node plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) #yOff的初始值为1,每向下递归一次,这个值减去 1 / totalD plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD def createPlot(inTree): fig = plt.figure(1, facecolor='white') fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses plotTree.totalW = float(getNumLeafs(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) ''' xOff和yOff用来记录当前要画的叶子结点的位置。 画布的范围x轴和y轴都是0到1,我们希望所有的叶子结点平均分布在x轴上。 totalW记录叶子结点的个数,那么 1/totalW 正好是每个叶子结点的宽度。 如果叶子结点的坐标是 1/totalW , 2/totalW, 3/totalW, …, 1 的话,就正好在宽度的最右边, 为了让坐标在宽度的中间,需要减去0.5 / totalW 。初始化 plotTree.xOff 的值为-0.5/plotTree.totalW。 这样每次 xOff + 1/totalW ,正好是下1个结点的准确位置 ''' plotTree.xOff = -0.5/plotTree.totalW plotTree.yOff = 1.0; plotTree(inTree, (0.5,1.0), '') plt.show()输出效果如下:
。
以上就是关于从原始数据中创建决策树并用python库来绘制树形图。