# 机器学习实战 决策树算法 — Machine Learning in Action: decision tree (ID3) algorithm


# -- coding: utf-8 --

#from knn import*
from treePlotter import*
import matplotlib.pyplot as plt

# Demo: build a decision tree from the toy dataset and print intermediate results.
myDat, labels = creatDataSet()
label = labels[:]               #copy labels with a slice so the original survives; `label = labels` would only alias it, and creatTree mutates the list it receives
print(myDat)
print(labels)
myTree = creatTree(myDat, labels)   #NOTE: creatTree deletes entries from the labels list passed in
print(myTree)

filepath = r'E:\file\python\test\test\Tree_data\classifierstorage.txt'  #raw string so the Windows backslashes are not treated as escapes


#Demo: decision tree for fitting contact lenses (lenses.txt dataset)
glasspath = r'E:\file\python\test\test\Tree_data\lenses.txt'
fr = open(glasspath)
lenses = [inst.strip().split('\t') for inst in fr.readlines()]  #strip() removes leading/trailing whitespace (newline included) before splitting on tabs
lenseslables = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = creatTree(lenses, lenseslables)
print(lensesTree)
creatPlot(lensesTree)



# -- coding: utf-8 --
#treePlotter

from numpy import*
from math import log
import operator
import matplotlib.pyplot as plt
import pickle

def creatDataSet():
    """Return the toy marine-animal dataset and its feature labels.

    Each row is [no surfacing, flippers, class]; the last column is the
    class label ('yes' / 'no').
    """
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    labels = ['no surfacing', 'flippers']  # feature names, one per data column
    return dataSet, labels

#Shannon entropy of the class column
def clacShannonEnt(dataSet):
    """Compute the Shannon entropy of the class labels (last column of each row)."""
    total = len(dataSet)
    counts = {}                          # class label -> number of occurrences
    for row in dataSet:
        cls = row[-1]
        counts[cls] = counts.get(cls, 0) + 1
    entropy = 0.0
    for cnt in counts.values():
        p = cnt / float(total)           # empirical probability of this class
        entropy -= p * log(p, 2)         # accumulate -p*log2(p)
    return entropy

#Filter + project the dataset for one feature value
def splitDataSet(dataSet, axis, value):
    """Return the rows whose feature at index `axis` equals `value`,
    with that feature column removed from each returned row."""
    return [row[:axis] + row[axis + 1:] for row in dataSet if row[axis] == value]

def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain (ID3).

    Returns -1 when no feature yields a positive gain.
    """
    featureCount = len(dataSet[0]) - 1      # last column is the class label
    baseEntropy = clacShannonEnt(dataSet)   # entropy before any split
    bestGain, bestIndex = 0.0, -1
    for idx in range(featureCount):
        distinctVals = {row[idx] for row in dataSet}   # unique values of feature idx
        condEntropy = 0.0
        for val in distinctVals:
            subset = splitDataSet(dataSet, idx, val)
            weight = len(subset) / float(len(dataSet))
            condEntropy += weight * clacShannonEnt(subset)   # conditional entropy H(D|feature)
        gain = baseEntropy - condEntropy    # information gain for this feature
        if gain > bestGain:
            bestGain, bestIndex = gain, idx
    return bestIndex

def majorityCnt(classList):
    """Return the most common class label in classList (majority vote).

    Bug fixes vs. the original: the increment `classCount[vote] += 1` was
    indented outside the for-loop, so only the final vote was ever counted;
    and dict.iteritems() does not exist in Python 3 (use dict.items()).
    """
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1              # count every vote, inside the loop
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]          # label with the highest count

def creatTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    NOTE: entries are deleted from `labels` in place; pass a copy if the
    caller needs the original list afterwards.
    """
    classList = [row[-1] for row in dataSet]
    # Base case 1: every sample has the same class — return that class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Base case 2: only the class column is left — fall back to majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)   # column index of the best feature
    bestFeatLabel = labels[bestFeat]               # its human-readable name
    tree = {bestFeatLabel: {}}
    del labels[bestFeat]                           # feature consumed; mutates caller's list
    for value in set(row[bestFeat] for row in dataSet):
        subLabels = labels[:]                      # copy so sibling branches share one label list
        tree[bestFeatLabel][value] = creatTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return tree


#Count the leaf nodes of a nested-dict decision tree
def getNumleafs(myTree):
    """Return the number of leaf nodes in a nested-dict decision tree.

    Bug fix: in Python 3, dict.keys() returns a view that cannot be
    indexed with [0]; use next(iter(...)) to fetch the tree's single
    root key instead.
    """
    numLeafs = 0
    firstStr = next(iter(myTree))    # root feature name (the tree's only key)
    secondDict = myTree[firstStr]    # value: dict of feature-value -> subtree or leaf label
    for key in secondDict:
        if isinstance(secondDict[key], dict):         # a nested dict is a subtree
            numLeafs += getNumleafs(secondDict[key])  # recurse into it
        else:
            numLeafs += 1            # any non-dict value is a leaf
    return numLeafs

def getTreeDepth(myTree):
    """Return the depth of a nested-dict decision tree (number of decision
    levels; a single decision node with only leaves has depth 1).

    Bug fix: in Python 3, dict.keys() returns a view that cannot be
    indexed with [0]; use next(iter(...)) for the single root key.
    """
    maxDepth = 0
    firstStr = next(iter(myTree))    # root feature name
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):   # subtree: one level plus its depth
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1                       # leaf directly under this node
        if thisDepth > maxDepth:                # keep the deepest branch
            maxDepth = thisDepth
    return maxDepth


#Node/edge styles used by the matplotlib tree rendering
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}   # internal (decision) node box
leafNode = {"boxstyle": "round4", "fc": "0.8"}         # leaf (class label) node box
arrow_args = {"arrowstyle": "<-"}                      # edge arrow style

#Draw a single annotated node with an arrow from its parent
def plotNode(nodeTxt, centerPt, parentPt, nodeType): #nodeTxt = node text, parentPt = arrow start, centerPt = node (arrow end) position, nodeType = box style dict
    creatPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',   #xy = arrow start point, xytext = where the text/box is drawn
                           xytext = centerPt, textcoords = 'axes fraction',  #bbox = the node's box style
                           va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)

#Draw the edge label (the feature value, e.g. 0 or 1) between parent and child
def  plotMidText(cntrPt, parentPt, txtString):
    """Place txtString at the midpoint of the segment parentPt -> cntrPt."""
    midX = cntrPt[0] + (parentPt[0] - cntrPt[0]) / 2.0
    midY = cntrPt[1] + (parentPt[1] - cntrPt[1]) / 2.0
    creatPlot.ax1.text(midX, midY, txtString)

#Recursively lay out and draw the tree
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw `myTree` below `parentPt`, labelling the incoming
    edge with `nodeTxt`.

    Relies on function attributes set by creatPlot: totalW/totalD (leaf
    count / depth of the whole tree, used as horizontal/vertical scales)
    and x0ff/y0ff (the current drawing cursor).

    Bug fix: dict.keys() cannot be indexed with [0] in Python 3; use
    next(iter(...)) for the subtree's single root key.
    """
    numLeafs = getNumleafs(myTree)         # width of this subtree, in leaf slots
    depth = getTreeDepth(myTree)           # kept for parity with the original (unused)
    firstStr = next(iter(myTree))          # feature name at this subtree's root
    # Centre this decision node horizontally above its leaves.
    cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff)
    plotMidText(cntrPt, parentPt, nodeTxt)              # edge label halfway along the arrow
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # the decision node itself
    secondDict = myTree[firstStr]
    plotTree.y0ff = plotTree.y0ff - 1.0/plotTree.totalD   # descend one level
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))   # subtree: recurse
        else:
            plotTree.x0ff = plotTree.x0ff + 1.0/plotTree.totalW   # leaf: advance one slot right
            plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
            plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
    plotTree.y0ff = plotTree.y0ff + 1.0/plotTree.totalD   # climb back up before returning


def creatPlot(inTree):
    """Render the nested-dict decision tree `inTree` in a matplotlib figure."""
    fig = plt.figure(1, facecolor='white')      # figure 1 with a white background
    fig.clf()                                   # wipe anything drawn previously
    axprops = {'xticks': [], 'yticks': []}      # hide both axes' tick marks
    # A single borderless axes filling the figure; plotNode/plotMidText draw into it.
    creatPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumleafs(inTree))   # total leaves -> horizontal scale
    plotTree.totalD = float(getTreeDepth(inTree))  # total depth  -> vertical scale
    plotTree.x0ff = -0.5/plotTree.totalW           # cursor starts half a slot left of the axes
    plotTree.y0ff = 1.0                            # root row sits at the top
    plotTree(inTree, (0.5,1.0), '')                # root anchored at top centre, no edge label
    plt.show()


#Use a built decision tree to classify a sample
def classify(inputTree, featlabels, testVec):
    """Walk the decision tree and return the predicted class for testVec.

    inputTree: nested dict produced by creatTree.
    featlabels: feature names in the same order as testVec's entries.
    testVec: feature values of the sample to classify.
    Returns the leaf label reached, or None if testVec's value for the
    tested feature matches no branch.

    Bug fixes: dict.keys() is not indexable in Python 3, and the original
    wrote `secondDict[key] == dict` — comparing a *value* to the dict type
    with == (always False here), so subtrees were never entered and
    classlabel could be unbound. Use isinstance instead.
    """
    firstStr = next(iter(inputTree))         # feature tested at this node
    secondDict = inputTree[firstStr]         # branches: feature value -> subtree or leaf
    featIndex = featlabels.index(firstStr)   # which entry of testVec this node examines
    classlabel = None
    for key in secondDict:
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):     # subtree: keep walking down
                classlabel = classify(secondDict[key], featlabels, testVec)
            else:                                     # leaf: this is the answer
                classlabel = secondDict[key]
    return classlabel


#pickle模块存储决策树

#Persist a decision tree to disk with pickle
def storeTree(inputTree, filename):
    """Serialize inputTree to `filename` with pickle.

    Bug fix: pickle writes bytes, so the file must be opened in binary
    mode ('wb'); text mode ('w') raises TypeError in Python 3. The `with`
    block also guarantees the file is closed even if dump() fails.
    """
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)

#Load a pickled decision tree from disk
def grabTree(filename):
    """Deserialize and return the tree previously saved by storeTree.

    Bug fixes: pickle data is binary, so the file must be opened with 'rb'
    (text mode fails in Python 3), and the `with` block closes the file
    (the original leaked the handle). NOTE: pickle.load on untrusted files
    can execute arbitrary code — only load files this program wrote.
    """
    with open(filename, 'rb') as fr:
        return pickle.load(fr)


# 猜你喜欢 ("you may also like" — blog navigation text from the original page)
# 转载自 (reposted from) blog.csdn.net/fm904813255/article/details/80311913