《机器学习实战笔记--第一部分 分类算法:决策树 2》

    之前我们已经了解了从数据集构造决策树的各种子功能模块,原理:从原始数据中基于最好的特征值进行划分数据集,由于特征值可能多余两个,所以可能存在大于两个分支的数据集划分。第一次划分之后数据将被传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。因此我们可以用递归的原则处理数据。

    递归结束的条件是:程序遍历完所有划分数据集的属性,或则每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或终止模块。任何到达叶子节点的数据必然属于叶子节点的分类。

    第一个条件可以使算法终止,我们在算法开始之前计算列的数目,查看算法是否使用了所有属性即可。如果数据集已经处理了所有属性,但是类标签不唯一,此时我们需要决定如何定义该叶子节点,这种情况下通常使用多数表决的方法决定叶子节点的分类。

def majorityCnt(classList):
    #存储每类标签出现的频率,按照从小到大的顺序进行排列
    classCount = {}
    for vote in classList:
        if vote not in classList.keys():
            classCount[vote] += 1
        classCount[vote] += 1
    sortedclassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedclassCount

    创建树的函数代码:

    

def createTree(dataSet, labels):
    '''
    输入参数:数据集和标签列表
    算法本身并不需要这个变量,为了给出数据明确的含义,我们将他作为一个输入参数提供给
    '''
    #所有数据集的类标签
    classList = [example[-1] for example in dataSet]
    
    #第一个停止条件:所有类标签完全相同,直接返回该类标签
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    
    #使用完了所有特征仍不能将数据集划分成一类,返回出现次数最多的类别作为返回值
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    
    #当前数据集中选取的最好的特征值 返回值为特征值的索引
    bestFeat = chooseBestFeatureToSplit(dataSet)
    print(bestFeat)
    #最好特征的特征名称
    bestFeatLabel = labels[bestFeat]
    print(bestFeatLabel)
    
    myTree = {bestFeatLabel:{}}
    print(myTree)
    #分类结束后删除当前特征
    del(labels[bestFeat])
    #遍历所有样本集(数据集)中的最好特征对应的特征值
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    #特征对应的特征值
    for value in uniqueVals:
        '''
        当函数参数是列表类型时,参数是按照引用方式传递的。
        为了保证每次调用函数createTree()时不改变原始列表的内容,使用新变量代替原始列表
        '''
        #复制了所有的特征名称(这里是我觉得和书不一样的地方)
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
        print(myTree)
    return myTree

    执行上述语句,我们可以的到如下的结果:

    

    变量myTree包含了很多代表树结构的嵌套字典。从左边开始,第一个关键字:no surfacing是第一个划分数据集的特征名字,该关键字的值也是另一个数据字典。第二个关键字是no surfacing特征划分的数据集,这些关键字的值都是no surfacing节点的子节点。这些值可以是类标签,也可以是另一个数据字典。如果值是类标签,则该节点是叶子节点,如果值是另一个数据字典,那么子节点是一个判断节点。这种不断重复的结构就构成了整棵树。

    下面我们将使用matplotlib注解来绘制树形图:

    1、matplotlib注解

    注解工具:annotations,可以在数据图形上添加文本注释。

    

import matplotlib.pyplot as plt
#定义树节点格式的常量
#文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8") #boxstyle="sawtooth"边框线是波浪线 fc注解框的颜色深度
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

#执行绘图功能
#绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
    #xy=parentPt起点位置, #xytext=centerPt注解框位置

#创建一个新的绘图区,在上面绘制两个不同类型的树节点    
def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('a decision node', (0.5,0.1), (0.1,0.5), decisionNode)
    plotNode('a leaf node', (0.8,0.1), (0.3,0.8), leafNode)
    plt.show()
    
    

    我们掌握了绘制绘制树节点的方法后,下面将学习如何绘制整棵树。

 2、构造注解树

    我们需要知道有多少个叶节点,以便正确的确定x轴的长度;我们还需要知道有多少层,以便确定y的高度。我们通过定义两个函数来确定叶节点的数目和树的层数。    

    我们使用如下两个函数来获取叶节点的数目和树的层数。

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0] 
  #python3与python2的区别,python3要把键变成一个列表
 #print(firstStr) secondDict = myTree[firstStr] #print(secondDict) for key in secondDict.keys(): #print(key) if type(secondDict[key]).__name__ == 'dict': numLeafs += getNumleafs(secondDict[key]) else: numLeafs += 1 return numLeafs
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1+getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:maxDepth = thisDepth
    return maxDepth   

我们用一个函数预先输出存储树的信息:

def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': 
                        {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

我们现在将之前所学的方法组合在一起,绘制一棵完整的树:

def plotMidText(cntrPt, parentPt, txtString):
    #在父子节点间填充文本信息
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)
    
def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    '''
    cntrPt用来记录当前要画的树的树根的结点位置
    '''
    #计算宽高
    numLeafs = getNumLeafs(myTree)  
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]   #python3与python2的区别,python3要把键变成一个列表
    '''
    我们希望树根在这些所有叶子节点的中间位置
    这里的 1.0 + numLeafs 需要拆开来理解,也就是
    plotTree.xOff +  float(numLeafs)/2.0/plotTree.totalW +1.0/2.0/plotTree.totalW
    plotTree.xOff +  1/2 * float(numLeafs)/plotTree.totalW + 0.5/plotTree.totalW
    因为xOff的初始值是-0.5/plotTree.totalW ,是往左偏了0.5/plotTree.tatalW 的,这里正好加回去。
    这样cntrPt记录的x坐标正好是所有叶子结点的中心点
    '''
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    #标记子节点的属性
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    #减少y偏移
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes   
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    #yOff的初始值为1,每向下递归一次,这个值减去 1 / totalD
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
    


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    '''
    xOff和yOff用来记录当前要画的叶子结点的位置。
    画布的范围x轴和y轴都是0到1,我们希望所有的叶子结点平均分布在x轴上。
    totalW记录叶子结点的个数,那么 1/totalW 正好是每个叶子结点的宽度。
    如果叶子结点的坐标是 1/totalW , 2/totalW, 3/totalW, …, 1 的话,就正好在宽度的最右边,
    为了让坐标在宽度的中间,需要减去0.5 / totalW 。初始化 plotTree.xOff 的值为-0.5/plotTree.totalW。
    这样每次 xOff + 1/totalW ,正好是下1个结点的准确位置
    '''
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()
    输出效果如下:

    

    以上就是关于从原始数据中创建决策树并用python库来绘制树形图。


    

猜你喜欢

转载自blog.csdn.net/qq_41635352/article/details/80004635