机器学习---决策树分类代码

1. 计算数据集的香农熵

from numpy import *  
import numpy as np  
import pandas as pd  
from math import log  
import operator  
  
  
#计算数据集的香农熵  
def calcShannonEnt(dataSet):  
    numEntries=len(dataSet)  
    labelCounts={}  
    #给所有可能分类创建字典  
    for featVec in dataSet:  
        currentLabel=featVec[-1]  
        if currentLabel not in labelCounts.keys():  
            labelCounts[currentLabel]=0  
        labelCounts[currentLabel]+=1  
    shannonEnt=0.0  
    #以2为底数计算香农熵  
    for key in labelCounts:  
        prob = float(labelCounts[key])/numEntries  
        shannonEnt-=prob*log(prob,2)  
    return shannonEnt  

香农熵公式: 

数据集: 

2. 对离散变量划分数据集 

#对离散变量划分数据集,取出该特征取值为value的所有样本  
def splitDataSet(dataSet,axis,value):  
    retDataSet=[]  
    for featVec in dataSet:  
        if featVec[axis]==value:  
            reducedFeatVec=featVec[:axis]  
            reducedFeatVec.extend(featVec[axis+1:])  
            retDataSet.append(reducedFeatVec)  
    return retDataSet  

这个函数用于划分数据集。它的作用是从给定的数据集中,根据指定的特征和取值,提取出符合条

件的样本集合。函数的输入参数包括数据集(dataSet)、特征的索引(axis)和特征取值

(value)。在函数内部,通过遍历数据集中的每个样本(featVec),判断该样本在指定特征上的

取值是否与给定的取值相等。如果相等,则将该样本添加到结果集合(retDataSet)中。为了将样

本添加到结果集合中,需要先创建一个新的样本(reducedFeatVec),它是将原样本中指定特征

的取值去除后的结果。具体做法是通过切片操作将特征索引之前和之后的部分合并起来,形成新的

样本。最后,将新样本添加到结果集合中。最后,函数返回结果集合(retDataSet),其中包含了

所有符合条件的样本。

3. 对连续变量划分数据集

#对连续变量划分数据集,direction规定划分的方向,  
#决定是划分出小于value的数据样本还是大于value的数据样本集  
def splitContinuousDataSet(dataSet,axis,value,direction):  
    retDataSet=[]  
    for featVec in dataSet:  
        if direction==0:  
            if featVec[axis]>value:  
                reducedFeatVec=featVec[:axis]  
                reducedFeatVec.extend(featVec[axis+1:])  
                retDataSet.append(reducedFeatVec)  
        else:  
            if featVec[axis]<=value:  
                reducedFeatVec=featVec[:axis]  
                reducedFeatVec.extend(featVec[axis+1:])  
                retDataSet.append(reducedFeatVec)  
    return retDataSet  

这是一个用于划分连续变量数据集的函数。它接受四个参数:dataSet(数据集),axis(要划分

的特征的索引),value(划分的阈值),direction(划分的方向)。函数的作用是根据给定的方

向和阈值,将数据集划分为两个子集。如果direction为0,则将大于阈值的样本划分到一个子集

中;如果direction不为0,则将小于等于阈值的样本划分到一个子集中。

在函数的实现中,通过遍历数据集中的每个样本,根据给定的方向和阈值进行划分。如果样本的特

征值大于阈值且方向为0,将该样本的特征值从划分特征的位置上移除,并将剩余的特征值组成一

个新的样本,添加到划分后的子集中。如果样本的特征值小于等于阈值且方向不为0,同样进行相

同的操作。最后,返回划分后的子集。

4. 选择划分方式

#选择最好的数据集划分方式  
def chooseBestFeatureToSplit(dataSet,labels):  
    numFeatures=len(dataSet[0])-1  
    baseEntropy=calcShannonEnt(dataSet)  
    bestInfoGain=0.0  
    bestFeature=-1  
    bestSplitDict={}  
    for i in range(numFeatures):  
        featList=[example[i] for example in dataSet]  
      #  print(featList)
        #对连续型特征进行处理  
        if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':  
            #产生n-1个候选划分点  
            sortfeatList=sorted(featList)  
            splitList=[]  
            for j in range(len(sortfeatList)-1):  
                splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)  
              
            bestSplitEntropy=10000  
            slen=len(splitList)  
            #求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点  
            for j in range(slen):  
                value=splitList[j]  
                newEntropy=0.0  
                subDataSet0=splitContinuousDataSet(dataSet,i,value,0)  
                subDataSet1=splitContinuousDataSet(dataSet,i,value,1)  
                prob0=len(subDataSet0)/float(len(dataSet))  
                newEntropy+=prob0*calcShannonEnt(subDataSet0)  
                prob1=len(subDataSet1)/float(len(dataSet))  
                newEntropy+=prob1*calcShannonEnt(subDataSet1)  
                if newEntropy<bestSplitEntropy:  
                    bestSplitEntropy=newEntropy  
                    bestSplit=j  
            #用字典记录当前特征的最佳划分点  
            bestSplitDict[labels[i]]=splitList[bestSplit]  
            infoGain=baseEntropy-bestSplitEntropy  
        #对离散型特征进行处理  
        else:  
            uniqueVals=set(featList)  
            newEntropy=0.0  
            #计算该特征下每种划分的信息熵  
            for value in uniqueVals:  
                subDataSet=splitDataSet(dataSet,i,value)  
                prob=len(subDataSet)/float(len(dataSet))  
                print(prob)
                newEntropy+=prob*calcShannonEnt(subDataSet)  
            infoGain=baseEntropy-newEntropy  
        if infoGain>bestInfoGain:  
            bestInfoGain=infoGain  
            bestFeature=i  
    #若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理  
    #即是否小于等于bestSplitValue  
    if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':        
        bestSplitValue=bestSplitDict[labels[bestFeature]]          
        labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)  
        for i in range(shape(dataSet)[0]):  
            if dataSet[i][bestFeature]<=bestSplitValue:  
                dataSet[i][bestFeature]=1  
            else:  
                dataSet[i][bestFeature]=0  
    return bestFeature  

numFeatures=len(dataSet[0])-1:计算数据集中特征数量,减去1是因为最后一列通常是标签列。

baseEntropy=calcShannonEnt(dataSet):计算整个数据集的基本熵。

bestInfoGain=0.0:初始化最佳信息增益为0。bestFeature=-1:初始化最佳划分特征的索引为-1。

bestSplitDict={}:创建一个空字典,用于记录连续特征的最佳划分点。

遍历每个特征,featList=[example[i] for example in dataSet]:获取数据集中第i个特征所有取值。

if type(featList[0]).__name__=='float' or ... :判断特征是否为连续型特征。

sortfeatList=sorted(featList):对连续型特征的取值进行排序。

splitList=[]:创建一个空列表,用于存储候选划分点。

for j in range(len(sortfeatList)-1):遍历排序后的特征取值列表,生成n-1个候选划分点。

splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0):将相邻特征值的平均值作为候选划分点。

bestSplitEntropy=10000:初始化最佳划分点的信息熵为一个较大的值。

slen=len(splitList):获取候选划分点的数量。for j in range(slen):遍历每个候选划分点。

value=splitList[j]:获取当前候选划分点的值。newEntropy=0.0:初始化划分后的信息熵为0。

         subDataSet0=splitContinuousDataSet(dataSet,i,value,0):根据当前候选划分点将数据集划

分为小于等于该值的子集。subDataSet1=splitContinuousDataSet(dataSet,i,value,1):根据当前候

选划分点将数据集划分为大于该值的子集。

        prob0=len(subDataSet0)/float(len(dataSet)):计算小于等于划分点的子集在整个数据集中的

概率。newEntropy+=prob0*calcShannonEnt(subDataSet0):计算小于等于划分点的子集的信息

熵,并加权求和。prob1=len(subDataSet1)/float(len(dataSet)):计算大于划分点的子集在整个数

据集中的概率。newEntropy+=prob1*calcShannonEnt(subDataSet1):计算大于划分点的子集的

信息熵,并加权求和。

if newEntropy<bestSplitEntropy:如果划分后的信息熵小于当前最佳划分点的信息熵。

bestSplitEntropy=newEntropy:更新最佳划分点的信息熵。

bestSplit=j:记录当前最佳划分点的索引。

bestSplitDict[labels[i]]=splitList[bestSplit]:用字典记录当前特征的最佳划分点。

infoGain=baseEntropy-bestSplitEntropy:计算当前特征的信息增益。

       如果特征是离散型特征,uniqueVals=set(featList):获取特征的唯一取值。newEntropy=0.0:

初始化划分后的信息熵为0。遍历每个离散特征取值。subDataSet=splitDataSet(dataSet,i,value):

根据当前特征取值将数据集划分为子集。prob=len(subDataSet)/float(len(dataSet)):计算当前特征

取值的概率。newEntropy+=prob*calcShannonEnt(subDataSet):计算当前特征取值的信息熵,并

加权求和。infoGain=baseEntropy-newEntropy:计算当前特征的信息增益if infoGain >

bestInfoGain:如果当前特征的信息增益大于当前最佳信息增益。bestInfoGain=infoGain:更新最

佳信息增益。bestFeature=i:记录当前最佳划分特征的索引。

       如果当前最佳划分特征是连续型特征。bestSplitValue=bestSplitDict[labels[bestFeature]]:获

取当前最佳划分特征的最佳划分点labels[bestFeature] = labels[bestFeature] + '<=' + str

(bestSplitValue):将当前最佳划分特征的标签更新为带有最佳划分点的条件。遍历数据集中的每个

样本。if dataSet[i][bestFeature]<=bestSplitValue:如果当前样本的最佳划分特征的取值小于等于

最佳划分点。dataSet[i][bestFeature]=1:将当前样本的最佳划分特征的取值设置为1。如果当前样

本的最佳划分特征的取值大于最佳划分点。dataSet[i][bestFeature]=0:将当前样本的最佳划分特

征的取值设置为0。返回最佳划分特征的索引。

5. 递归构造决策树

#特征若已经划分完,节点下的样本还没有统一取值,则需要进行投票  
def majorityCnt(classList):  
    classCount={}  
    for vote in classList:  
        if vote not in classCount.keys():  
            classCount[vote]=0  
        classCount[vote]+=1  
    return max(classCount)  
  
#主程序,递归产生决策树  
def createTree(dataSet,labels,data_full,labels_full):  
    classList=[example[-1] for example in dataSet]  
    if classList.count(classList[0])==len(classList):  
        return classList[0]  
    if len(dataSet[0])==1:  
        return majorityCnt(classList)  
    bestFeat=chooseBestFeatureToSplit(dataSet,labels)  
    bestFeatLabel=labels[bestFeat]  
    myTree={bestFeatLabel:{}}  
    featValues=[example[bestFeat] for example in dataSet]  
    uniqueVals=set(featValues)  
    if type(dataSet[0][bestFeat]).__name__=='str':  
        currentlabel=labels_full.index(labels[bestFeat])  
        featValuesFull=[example[currentlabel] for example in data_full]  
        uniqueValsFull=set(featValuesFull)  
    del(labels[bestFeat])  
    #针对bestFeat的每个取值,划分出一个子树。  
    for value in uniqueVals:  
        subLabels=labels[:]  
        if type(dataSet[0][bestFeat]).__name__=='str':  
            uniqueValsFull.remove(value)  
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels,data_full,labels_full)  
    if type(dataSet[0][bestFeat]).__name__=='str':  
        for value in uniqueValsFull:  
            myTree[bestFeatLabel][value]=majorityCnt(classList)  
    return myTree 

classList=[example[-1] for example in dataSet]:创建一个列表classList,其中包含数据集dataSet

中每个样本的类别标签。

if classList.count(classList[0])==len(classList):检查classList中的类别标签是否都相同。如果是,

则返回该类别标签作为叶子节点的类别。

if len(dataSet[0])==1:检查数据集dataSet是否只剩下一个特征。如果是,则返回classList中出现

次数最多的类别标签作为叶子节点的类别。

bestFeat=chooseBestFeatureToSplit(dataSet,labels):调用函数chooseBestFeatureToSplit,选择

最佳的特征进行划分,并将其索引保存在bestFeat中。

bestFeatLabel=labels[bestFeat]:根据bestFeat的索引,获取特征标签labels中对应的特征名称。

myTree={bestFeatLabel:{}}:创建一个字典myTree,以bestFeatLabel作为键,空字典作为值。这

个字典将用于构建决策树。

featValues=[example[bestFeat] for example in dataSet]:创建一个列表featValues,其中包含数据

集dataSet中每个样本在bestFeat特征上的取值。

uniqueVals=set(featValues):将featValues转换为集合uniqueVals,以获取bestFeat特征的唯一取

值。

if type(dataSet[0][bestFeat]).__name__=='str':检查bestFeat特征的数据类型是否为字符串。

如果是,则执行以下操作:

      currentlabel=labels_full.index(labels[bestFeat]):获取完整特征标签列表labels_full中labels

[bestFeat]的索引,并将其保存在currentlabel中;          

      featValuesFull=[example[currentlabel] for example in data_full]:创建一个列表

featValuesFull,其中包含完整数据集data_full中每个样本在currentlabel特征上的取值;         

      uniqueValsFull=set(featValuesFull):将featValuesFull转换为集合uniqueValsFull,以获取

currentlabel特征的唯一取值。

del(labels[bestFeat]):删除labels中索引为bestFeat的特征标签,因为该特征已经被用于划分。

for value in uniqueVals:对于uniqueVals中的每个取值,执行以下操作:

       subLabels=labels[:]:创建一个新的特征标签列表subLabels,并将labels的值复制给它。

       if type(dataSet[0][bestFeat]).__name__=='str':如果bestFeat特征的数据类型为字符串,执行

以下操作:uniqueValsFull.remove(value):从uniqueValsFull中移除当前取值value。 

        myTree[bestFeatLabel[value] =createTree(splitDataSet(dataSet,bestFeat,value),subLabels,

data_ full,labels_full):递归调用createTree函数,传入划分后的子数据集、子特征标签列表以及完

整数据集和特征标签列表,并将返回的子树存储在myTree中。

        if type(dataSet[0][bestFeat]).__name__=='str':如果bestFeat特征的数据类型为字符串,执行

以下操作:for value in uniqueValsFull::对于uniqueValsFull中的每个取值,执行以下操作:

myTree[bestFeatLabel][value]=majorityCnt(classList):将叶子节点的类别标签设置为classList中

出现次数最多的类别标签。

最后,返回构建好的决策树。

df=pd.read_csv('watermelon_3a.csv')  
data=df.values[:,1:].tolist()  
data_full=data[:]  
labels=df.columns.values[1:-1].tolist()  
labels_full=labels[:]  
myTree=createTree(data,labels,data_full,labels_full) 

6. 画树 

import matplotlib.pyplot as plt
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")


#计算树的叶子节点数量
def getNumLeafs(myTree):
    numLeafs=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs+=getNumLeafs(secondDict[key])
        else: numLeafs+=1
    return numLeafs

#计算树的最大深度
def getTreeDepth(myTree):
    maxDepth=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else: thisDepth=1
        if thisDepth>maxDepth:
            maxDepth=thisDepth
    return maxDepth

#画节点
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
    xytext=centerPt,textcoords='axes fraction',va="center", ha="center",\
    bbox=nodeType,arrowprops=arrow_args)

#画箭头上的文字
def plotMidText(cntrPt,parentPt,txtString):
    lens=len(txtString)
    xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002
    yMid=(parentPt[1]+cntrPt[1])/2.0
    createPlot.ax1.text(xMid,yMid,txtString)
    
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=list(myTree.keys())[0]
    cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
            plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
    plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD

def createPlot(inTree):
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.x0ff=-0.5/plotTree.totalW
    plotTree.y0ff=1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

plotNode函数用于绘制节点。它接受节点文本(nodeTxt)、中心点(centerPt)、父节点(parentPt)和节

点类型(nodeType)作为参数。在函数内部,它使用createPlot.ax1.annotate()函数来绘制节点文

本。

createPlot函数用于创建并显示一个图形。它接受一个树对象(inTree)作为参数。在函数内部,它创

建了一个图形对象(fig),清除了图形对象中的内容,然后创建了一个子图对象(createPlot.ax1)。接

下来,它调用了plotTree函数来绘制树的节点,并使用plt.show()显示图形。

plotMidText函数用于在箭头上绘制文字。它接受三个参数:cntrPt表示箭头的中心点坐标,

parentPt表示箭头的起始点坐标,txtString表示要绘制的文字。在函数内部,它计算了文字的位置

坐标,并使用createPlot.ax1.text()函数在图形上绘制文字。

plotTree函数用于绘制树的节点和箭头。它接受三个参数:myTree表示树的字典表示,parentPt表

示父节点的坐标,nodeTxt表示节点的文本。在函数内部,它首先获取树的叶子节点数和深度,然

后计算当前节点的位置坐标。接下来,它调用plotMidText函数在箭头上绘制文字,调用plotNode函

数绘制节点。然后,它遍历树的子节点,如果子节点是字典类型,则递归调用plotTree函数绘制子

树;如果子节点是叶子节点,则调用plotNode函数绘制叶子节点,并使用plotMidText函数在箭头上

绘制文字。最后,它更新plotTree.y0ff的值,以便绘制下一层的节点。

遇到的问题:createPlot.ax1 是什么意思?

在这句代码中,createPlot是函数类型(function),而createPlot.ax1是一个

matplotlib.axes._axes.Axes。createPlot.ax1是一个有效的变量名,而将其替换为

createPlot_ax1会导致报错。在代码中,createPlot.ax1是一个全局变量,用于引用子图对象。

功能有点类似于类的成员变量,为了共享createPlot.ax1。函数也是对象,给一个对象绑定一个属

性就是这样的:函数对象本身就有很多属性,__name____doc__等等。自己绑定的要有意义,没

意义的就不需要。

def f():
    pass

f.a = 1
print(f.a) 

# 1
createPlot(myTree)

猜你喜欢

转载自blog.csdn.net/weixin_43961909/article/details/132787698