An ID3 decision tree for Example 5.3 of Li Hang's Statistical Learning Methods


Most of this post follows the code from Machine Learning in Action; only a few functions are changed, because whenever I rewrote larger chunks I couldn't reproduce the book's answer.

Here is the code:

from math import log
def loadDataSet():
    # Loan-application sample data used in the book's Example 5.3.
    # Columns: 年龄 (age), 有工作 (has a job), 有自己的房子 (owns a house),
    # 信贷情况 (credit rating); the last column is the class label
    # (whether the loan is approved).
    dataSet = [['青年', '否', '否', '一般', '否'],
               ['青年', '否', '否', '好', '否'],
               ['青年', '是', '否', '好', '是'],
               ['青年', '是', '是', '一般', '是'],
               ['青年', '否', '否', '一般', '否'],
               ['中年', '否', '否', '一般', '否'],
               ['中年', '否', '否', '好', '否'],
               ['中年', '是', '是', '好', '是'],
               ['中年', '否', '是', '非常好', '是'],
               ['中年', '否', '是', '非常好', '是'],
               ['老年', '否', '是', '非常好', '是'],
               ['老年', '否', '是', '好', '是'],
               ['老年', '是', '否', '好', '是'],
               ['老年', '是', '否', '非常好', '是'],
               ['老年', '否', '否', '一般', '否']]
    label = ['年龄', '有工作', '有自己的房子', '信贷情况']
    return dataSet, label
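# Quick look at the data (my own check, not in the original post):
# d, lab = loadDataSet()
# print(len(d), len(d[0]))   # 15 samples, 4 features + 1 class label each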

def shannonEnt(dataSet):
    # Shannon entropy of the class labels: H(D) = -sum(p_i * log2(p_i)).
    labelCount = {}
    numOfData = len(dataSet)
    for data in dataSet:
        classify = data[-1]            # class label is the last column
        if classify not in labelCount:
            labelCount[classify] = 1
        else:
            labelCount[classify] += 1
    H = 0.0
    for value in labelCount.values():
        pi = value / numOfData         # empirical probability of this class
        H -= pi * log(pi, 2)
    return H
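# Sanity check (my own, not in the original post): the data set has 9 '是' and
# 6 '否' labels, so H(D) = -(9/15)*log2(9/15) - (6/15)*log2(6/15) ≈ 0.971,
# matching Example 5.2 in the book.
# d, _ = loadDataSet()
# print(shannonEnt(d))   # expect roughly 0.971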

def chooseBestFeatureSplit(dataSet):
    # ID3 feature selection: pick the feature with the largest information gain
    # g(D, A) = H(D) - H(D|A).
    HD = shannonEnt(dataSet)
    bestGain = 0.0
    bestFeature = -1
    for i in range(len(dataSet[0]) - 1):       # last column is the class label
        feat = [data[i] for data in dataSet]
        condEnt = 0.0                          # conditional entropy H(D|A_i)
        for value in set(feat):
            subData = splitData(dataSet, i, value)
            condEnt += (len(subData) / len(dataSet)) * shannonEnt(subData)
        Gain = HD - condEnt
        if Gain > bestGain:
            bestGain = Gain
            bestFeature = i
    return bestFeature
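# On the full data set this should return 2, i.e. 有自己的房子 (owns a house):
# per Example 5.2 in the book the gains are roughly g(D,年龄)=0.083,
# g(D,有工作)=0.324, g(D,有自己的房子)=0.420, g(D,信贷情况)=0.363.
# d, _ = loadDataSet()
# print(chooseBestFeatureSplit(d))   # expect 2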


def splitData(dataSet, axis, value):
    # Return the rows whose feature at `axis` equals `value`,
    # with that feature column removed.
    retDataSet = []
    for data in dataSet:
        if data[axis] == value:
            reducedData = data[:axis]
            reducedData.extend(data[axis + 1:])
            retDataSet.append(reducedData)
    return retDataSet
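# For example (my own check): splitData(d, 0, '青年') keeps the five 青年 rows
# and drops the 年龄 column, so each returned row has 4 entries instead of 5.
# d, _ = loadDataSet()
# print(splitData(d, 0, '青年'))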

def majorityCnt(classList):
    # Majority vote over the class labels (used when no features are left to split on).
    classCount = dict([(i, classList.count(i)) for i in classList])
    return max(classCount, key=lambda x: classCount[x])
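# For example, majorityCnt(['是', '是', '否']) returns '是'. With this particular
# data set every leaf ends up pure, so the vote is never actually needed, but it
# keeps the recursion safe in general.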


def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # Stop if all remaining samples share one class ...
    if len(set(classList)) == 1:
        return classList[0]
    # ... or if only the class column is left (all features used up): majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureSplit(dataSet)
    bestFeatLabel = labels[bestFeat]

    # Build the tree: one branch per value of the chosen feature.
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]

    featValues = [example[bestFeat] for example in dataSet]
    for value in set(featValues):
        subLabels = labels[:]          # copy so recursion doesn't clobber sibling branches
        subDataSet = splitData(dataSet, bestFeat, value)
        myTree[bestFeatLabel][value] = createTree(subDataSet, subLabels)
    return myTree



dataSet,label=loadDataSet()
print(createTree(dataSet,label))
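If everything is wired up correctly, the printed tree should be equivalent to the one derived in Example 5.3 of the book (dict key order may differ): split first on 有自己的房子 (owns a house); applicants with a house are approved, and the rest are split on 有工作 (has a job).

{'有自己的房子': {'是': '是', '否': {'有工作': {'是': '是', '否': '否'}}}}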

Reposted from blog.csdn.net/zuanfengxiao/article/details/79106347