机器学习随笔3--决策树

摘要

  本文主要介绍决策树,以及决策树在数据集上的应用。

目录

一、问题描述

  配隐形眼镜是一件技术活,隐形眼镜类型包括硬材质、软材质以及不适合佩戴隐形眼镜。经验丰富的眼科医生根据许多患者的眼球状况的观察条件和患者本身的身体条件,包括:年龄、处诊结果、散光情况、流泪量等,来推荐隐形眼镜类型。

  但对于新人医生来说,其经验未及老医生经验丰富。幸运的是,现在医院积累到了以前患者配隐形眼镜的一定量的数据。现在能否根据这些数据,基于患者情况来建立一个分类模型来帮助医生们来推荐患者应该佩戴的隐形眼镜类型呢?

二、数据获取与预处理

2.1 数据的获取

  在此从UCI数据库中选取的简化的隐形眼镜数据集lenses.txt,该隐形眼镜数据集是非常著名的数据集,它包含了很多患者眼部状况的观察条件以及医生推荐的隐形眼镜类型。
  数据集中包含有患者数据,每个患者的数据中包括患者的情况有:年龄、处诊结果、散光情况、流泪量等,以及对应的医生推荐的隐形眼镜类型有:硬材质、软材质以及不适合佩戴隐形眼镜。

2.2 数据预处理

  由于原数据中数据量较大,不利于本文决策树绘图的简便,所以现今只选择原数据集中的含年龄、处诊结果、散光情况、流泪量这四个指标的24个数据进行决策树分类。部分数据截图如图1所示。



图1

2.3 数据可视化

  我们见隐形数据集中的4个指标进行统计并用图表显示出来如图2所示。



图2

  可以看出,每个指标含有的属性至多为3个,此时使用决策树分类后,并绘制的决策树可以明显理清决策树的枝叶,同时决策树也就一目了然。

三、决策树分类算法原理

  
  决策树是处理分类问题的算法,将所要处理的数据看做是树的根,相应的选取数据的特征作为一个个节点(决策点),每次选取一个节点将数据集分为不同的数据子集,可以看成对树进行分支,这里体现出了决策,直到最后无法可分停止,也就是分支上的数据为同一类型,可以想象一次次划分之后由根延伸出了许多分支,形象的说就是一棵树。

  在机器学习中,决策树是一个预测模型,它代表的是对象属性与对象值之间的一种映射关系,我们可以利用决策树发现数据内部所蕴含的知识。

3.1 构造决策树

  基于之前的了解,在构造决策树首先需要选取特征将原始数据划分为几个数据集,那么第一个问题就是当前数据的哪个特征在划分数据分类时起决定性作用,所以必须评估每个特征。进而通过特征将原始数据就被划分为几个数据子集,这些数据子集分布在第一个决策点的所有分支上,如果分支上的所有数据为同一类型,则划分停止,若分支上的所有数据不是同一类型,则还需要继续划分,直到所有具有相同类型的数据均在一个数据子集中。
  在用决策树进行划分时,关键是每次划分时选取哪个特征进行划分,在划分数据时,我们必须采用量化的方法判断如何划分数据。
(1)信息增益
  划分数据时是根据某一原则进行划分,使得划分在同一集合中的数据具有共同的特征,据此,我们可以理解为划分数据的原则就是是无序的数据变得有序。当然划分数据有很多种方法,在此选用信息论度量信息,划分组织杂乱无章的数据。
  信息论是量化处理信息的分支科学,可以在数据划分之前或之后使用信息论量化度量信息的内容。其中在划分数据集之前之后信息发生的变化称为信息增益,计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。
  首先我们需要知道怎么计算信息增益,集合信息的度量方式称为香农熵或者简称为熵,熵定义为信息的期望值,即:

H ( X ) = i = 1 n p ( x i ) log 2 1 p ( x i )

  熵越高,表明混合的数据越多,则可以在数据集中添加更多的分类。得到熵之后,我们就可以按照获取最大增益的办法划分数据集。
(2)划分数据集
  基于之前的分析,信息增益表示的是信息的变化,而信息可以用熵来度量,所以我们可以用熵的变化来表示信息增益。而获得最高信息增益的特征就是最好的选择,故此,我们可以对所有特征遍历,得到最高信息增益的特征加以选择。
  首先,我们按照给定特征划分数据集并进行简单的测试,接下来我们遍历整个数据集,循环计算香农熵,找到最好的划分方式并简单测试。
(3)递归构建决策树
  由于特征很可能多与两个,因此可能存在大于两个分支的数据集划分,第一次划分之后,可以将划分的数据继续向下传递,如果将每一个划分的数据看成是原数据集,那么之后的每一次划分都可以看成是和第一次划分相同的过程,据此我们可以采用递归的原则处理数据集。
  递归结束的条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都有相同的分类。

3.2 使用matplotlib注解绘制树形图

  之前已经从数据集中成功的创建了决策树,但是字典的形式非常的不易于理解,因此本文采用Matplotlib库创建树形图。
(1)使用文本注解绘制树节点;
(2)获得叶节点的数目和树的层数;
(3)使用节点、层数等数据进行画图。

四、问题的计算

  基于之前的分析,我们知道可以根据决策树学习到眼科医生是如何判断患者需要佩戴的眼镜片,据此我们可以帮助人们判断需要佩戴的镜片类型。
  从隐形眼镜数据集中的很多患者眼部状况的观察条件以及医生推荐的隐形眼镜类型。我们选取此数据集,结合Matplotlib绘制树形图,进一步观察决策树是如何工作的。程序输出结果如图3所示:



图3

  同时,Matplotlib绘制树形图如图4所示:



图4

  可以看出,沿着决策树的不同分支,得到不同患者需要佩戴的隐形眼镜类型。同时,从该图中我们可以得到,只需要问四个问题就可以确定出患者需要佩戴何种隐形眼镜。

五、总结

5.1 算法的优点与缺点

(1)算法的优点:计算的复杂度不高,输出的结果易于理解,对中间值的确实不敏感,可以处理不相关的特征数据
(2)算法的缺点:可能会产生过度匹配的问题。

5.2 算法的拓展

  决策树其本质的思想是通过寻找区分度最好的特征(属性),用于支持分类规则的制定。本文主要使用的是ID3算法,自身也存在着很多不足。当然还有其它的决策树构造算法,比如C4.5和CART。

六、参考文献

[1]周志华.机器学习[M].北京:清华大学出版社,2016.
[2]Peter Harrington.机器学习实战[M].北京:人民邮电出版社,2013.
[3]韩家炜等.数据挖掘概念与技术[M].北京:机械工业出版社,2012.

七、附录

《机器学习实战》的代码,其代码的资源网址为:
  https://www.manning.com/books/machine-learning-in-action

  其中,trees.py文件为:

from math import log
import operator

def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing','flippers']
    #change to discrete values
    return dataSet, labels

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet: #the the number of unique elements and their occurance
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2) #log base 2
    return shannonEnt

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]     #chop out axis used for splitting
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):        #iterate over all the features
        featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
        uniqueVals = set(featList)       #get a set of unique values
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)     
        infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy
        if (infoGain > bestInfoGain):       #compare this to the best gain so far
            bestInfoGain = infoGain         #if better than current best, set to best
            bestFeature = i
    return bestFeature                      #returns an integer

def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList): 
        return classList[0]#stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree                            

def classify(inputTree,featLabels,testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict): 
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat
    return classLabel

def storeTree(inputTree,filename):
    import pickle
    fw = open(filename,'w')
    pickle.dump(inputTree,fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)

treePlotter.py文件为:

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = ( list(myTree.keys()) )[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = ( list(myTree.keys()) )[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = ( list(myTree.keys()) )[0]     #the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes   
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()

#def createPlot():
#    fig = plt.figure(1, facecolor='white')
#    fig.clf()
#    createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
#    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#    plt.show()

def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]

#createPlot(thisTree)

mian文件为:

from trees import *
from treePlotter import *
from numpy import *
import matplotlib
import matplotlib.pyplot as plt

"""
data,labels=createDataSet()
myTree=createTree( data , labels)
print(myTree)
createPlot(myTree)
"""

fr=open('lenses.txt')
data=[inst.strip().split('\t') for inst in fr.readlines()]
fr=open('lensesLabels.txt')
labels=fr.readlines()

myTree=createTree( data , labels)
print(myTree)
createPlot(myTree)

猜你喜欢

转载自blog.csdn.net/D_i_k_y/article/details/80948136