决策树—基本原理与实战

概念

        决策树(Decision Tree)是在已知各种情况发生概率的情况下,通过构成决策树来求取净现值的期望值大于0的概率,是直观运用概率分析的一种图解法。通俗的讲,决策树就是带有特殊含义的数据结构中的树结构,其每个根结点(非叶子结点)代表数据的特征标签,根据该特征不同的特征值将数据划分成几个子集,每个子集都是这个根结点的子树,然后对每个子树递归划分下去,而决策树的每个叶子结点则是数据的最终类别标签。对于一个样本特征向量,则从决策树的顶端往下进行分类,直到根结点,得到的类别标签就是这个样本向量的类别。

如:

特点

这里写图片描述

决策树的构造

        要构造决策树,就需要根据样本数据集的数据特征对数据集进行划分,直到针对所有特征都划分过,或者划分的数据子集的所有数据的类别标签相同。然而要构造决策树,面临的第一个问题是先对哪个特征进行划分,即当前数据集上哪个特征在划分数据分类时起决定性作用。因此,为了找到决定性的特征,划分出最好的结果,我们必须评估每个特征。

决策树的一般流程:

(1)收集数据:可以使用任何方法。

(2)准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。

(3)分析数据:可以使用任何方法,构造树完成后,我们应该检查图形是否符合预期。

(4)训练算法:构造树的数据结构

(5)测试算法:使用经验树计算错误率

(6)使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义

信息增益

        划分数据集的最大原则是:将无序的数据变得更加有序。我们采用量化的方法来度量数据的内容,组织杂乱无章数据的一种方法就是使用信息论度量信息,信息论是量化处理信息的分支科学。我们可以在划分数据前后使用信息论量化度量信息的内容。

        在划分数据集前后信息发生的变化成为信息增益,知道如何计算信息增益,就可以计算根据每个特征划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。

        集合信息的度量方式成为香农熵、信息熵或者简称为熵,熵在热力学中的物理意义是体系混乱程度的度量,而在信息科学中也可看成是信息混乱程度的度量,熵越大,信息越混乱。实际上熵定义为信息的期望值,要明确这个概念,就要知道信息的定义。如果待分类的事务可能划分在多个分类之中,则符号Xi的信息定义为

    ——其中p(xi)是选择该分类的概率

为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:

    ——其中n是分类的数目

计算熵:

# 计算给定数据集的香农熵(根据dataSet中所有的特征向量的类别计算熵)
# dataSet:给定数据集
# 返回shannonEnt:香农熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}            # 创建一个空字典
    # for循环:使labelCounts字典保存多个键值对,并且以dataSet中数据的类别(标签)为键,该类别数据的条数为对应的值
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():      # keys()方法返回字典中的键
            labelCounts[currentLabel] = 0               # 如果labelCounts中没有currentLabel,则添加一个以currentLabel为键的键值对,值为0
        labelCounts[currentLabel] += 1                  # 将labelCounts中类型为currentLabel值加1
    shannonEnt = 0.0
    for key in labelCounts:                             # 根据熵的公式进行累加
        prob = float(labelCounts[key])/numEntries       # 计算每种数据类别出现的概率
        shannonEnt -= prob * log(prob, 2)               # 根据定义的公式计算信息
    return shannonEnt

得到熵之后,我们就可以按照获取最大信息增益(熵减最多)的方法划分数据集。

划分数据集

        知道如何度量数据的无序程度,即如何计算信息熵之后,还需要进行的步骤是划分数据集,通过度量划分数据集的熵,以便判断当前是否正确地划分了数据集,因此需要对每个特征进行划分数据集,并且计算划分后的数据集信息熵,熵减最多时划分数据集所根据的特征就是划分数据集的最优特征。

划分数据集:

# 按照给定特征划分数据集
# dataSet:给定数据集
# axis:给定特征所在特征向量的列
# value:给定特征的特征值
# 返回retDataSet:划分后的数据集
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:                  # 若当前特征向量指定特征列(第axis列,列从0开始)的特征值与给定的特征值(value)相等
                                                    # 下面两行代码相当于将axis列去掉
            reducedFeatVec = featVec[:axis]         # 取当前特征向量axis列之前的列的特征
            reducedFeatVec.extend(featVec[axis+1:]) # 将上一句代码取得的特征向量又加上axis列后的特征
            retDataSet.append(reducedFeatVec)       # 将划分后的特征向量添加到retDataSet中
    return retDataSet

选择最好的划分方式:

# 选择最好的数据划分方式
# dataSet:要进行划分的数据集
# 返回bestFeature:在分类时起决定性作用的特征(下标)
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1       # 特征的数量
    baseEntropy = calcShannonEnt(dataSet)   # 计算数据集的香农熵
    bestInfoGain = 0.0                      # bestInfoGain=0:最好的信息增益为0,表示再怎么划分,
                                            # 香农熵(信息熵)都不会再变化,这就是划分的最优情况
    bestFeature = -1
    for i in range(numFeatures):            # 根据数据的每个特征进行划分,并计算熵,
                                            # 熵减少最多的情况为最优,此时对数据进行划分的特征作为划分的最优特征
        featList = [example[i] for example in dataSet] # featList为第i列数据(即第i个特征的所有特征值的列表(有重复))
        uniqueVals = set(featList)          # uniqueVals为第i列特征的特征值(不重复,例如有特征值1,1,0,0,uniqueVals为[0, 1])
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))  #
            newEntropy += prob * calcShannonEnt(subDataSet) # newEntropy为将数据集根据第i列特征进行划分的
                                                            # 所有子集的熵乘以该子集占总数据集比例的和
        infoGain = baseEntropy - newEntropy                 # 计算信息增益,即熵减
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

递归构建决策树

        到这里我们已经学习了从数据集构造决策树算法所需要的子功能模块,其工作原理如下:得到原始数据集,然后基于最好的特征划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据。因此我们可以采用递归的原则处理数据集。

        在这里递归的的结束条件是:程序遍历完所有划分数据集的属性(特征),即对数据的每个特征都进行了划分,或者每个分支下的所有数据都具有相同的类别标签。如果所有数据具有相同的类别标签,则得到一个叶子节点或者终止块。

        如果程序已经处理了数据集的所有特征,但是数据的类别标签依然不是唯一的,此时就需要决定如何定义该叶子节点,在这种情况下,通常会采用多数表决的方法决定该叶子节点的类别标签,即数据中那种类型的数据多,叶子结点就采用这个类别标签。

# 当数据集已经处理了所有属性,但是分类标签依然不唯一时,采用多数表决的方法决定叶子结点的分类
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

递归创建决策树:

# 利用函数递归创建决策树
# dataSet:数据集
# labels:标签列表,包含了数据集中所有特征的标签
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]    # 取出dataSet最后一列的数据

    if classList.count(classList[0]) == len(classList): # classList中classList[0]出现的次数=classList长度,表示类别完全相同,停止继续划分
        return classList[0]
    if len(dataSet[0]) == 1:                            # 遍历完所有特征时返回出现次数最多的类别
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)        # 计算划分的最优特征(下标)
    bestFeatLabel = labels[bestFeat]                    # 数据划分的最优特征的标签(即是什么特征)
    myTree = {bestFeatLabel:{}}                         # 创建一个树(字典),bestFeatLabel为根结点
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)# 利用递归构造决策树
    return myTree

使用Python的Matplotlib注解绘制决策树

        从数据集中构造决策树之后,因为得到的结果是Python中的字典的表示形式,这中表示方式不直观,因此需要利用Matplotlib库来创建树形图,绘制出创建的决策树。

Matplotlib注解

使用Matplotlib绘制树形图的示例:

import matplotlib.pyplot as plt

# 定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

# 使用文本注解绘制树节点

def plotNode(nodeText, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeText, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

调用createPlot()函数:

构造注解树

        要绘制决策树就要知道决策树有多宽和有多深:

# 获取叶节点的数目,以便确定x轴的长度
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]     #根结点
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# 获取决策树的深度
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

绘制决策树:

# 在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, textString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, textString)

def plotTree(myTree, parentPt, nodeText):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeText)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW;
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show();

绘制决策树时,调用createPlot()函数,传入要绘制的树结构。

测试和存储分类器

测试算法

# 使用决策树的分类函数
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

决策树的存储

# 使用pickle模块存储决策树
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'w')
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)

示例:使用决策树预测隐形眼镜类型

源码及相关数据下载

步骤:

(1)收集数据:提供的文本文件。

(2)准备数据:解析TAB键分隔的数据行。

(3)分析数据:快速检查数据,确保正确地解析数据内容,使用createPlot()函数绘制最终的树形图。

(4)训练算法:使用上面的createTree()函数。

(5)测试算法:编写测试函数验证决策树可以正确分类给定的数据实例。

(6)使用算法:存储树的数据结构,以便下次使用时无需重新构造树。

# 使用lenses.txt中的数据构造决策树
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = trees.createTree(lenses, lensesLabels)

# 画决策树
treePlotter.createPlot(lensesTree)

结果:

猜你喜欢

转载自blog.csdn.net/qq_32651225/article/details/72809551