贴图:
这篇大部分参考《机器学习实战》上面的代码，只有部分函数有改动，因为……大改的地方总是调不出答案……啊啊啊
上代码:
from collections import Counter
from math import log


def loadDataSet():
    """Return the sample data set and its feature names.

    Returns:
        A ``(dataSet, label)`` tuple. ``dataSet`` is a list of 15 rows, each
        holding four feature values followed by the class label in the last
        column. ``label`` lists the names of the four feature columns, in
        column order.
    """
    dataSet = [
        ['青年', '否', '否', '一般', '否'],
        ['青年', '否', '否', '好', '否'],
        ['青年', '是', '否', '好', '是'],
        ['青年', '是', '是', '一般', '是'],
        ['青年', '否', '否', '一般', '否'],
        ['中年', '否', '否', '一般', '否'],
        ['中年', '否', '否', '好', '否'],
        ['中年', '是', '是', '好', '是'],
        ['中年', '否', '是', '非常好', '是'],
        ['中年', '否', '是', '非常好', '是'],
        ['老年', '否', '是', '非常好', '是'],
        ['老年', '否', '是', '好', '是'],
        ['老年', '是', '否', '好', '是'],
        ['老年', '是', '否', '非常好', '是'],
        ['老年', '否', '否', '一般', '否'],
    ]
    label = ['年龄', '有工作', '有自己的房子', '信贷情况']
    return dataSet, label


def shannoEnt(dataSet):
    """Return the Shannon entropy (base 2) of the class column of *dataSet*.

    The class label is taken from the last element of every row. An empty
    data set has entropy ``0.0``.
    """
    numOfData = len(dataSet)
    labelCount = Counter(row[-1] for row in dataSet)
    entropy = 0.0
    for count in labelCount.values():
        p = count / numOfData
        entropy -= p * log(p, 2)
    return entropy


def splitData(dataSet, axis, value):
    """Return the rows whose column *axis* equals *value*, minus that column.

    The input rows are not modified; every returned row is a new sequence.
    """
    return [
        row[:axis] + row[axis + 1:]
        for row in dataSet
        if row[axis] == value
    ]


def chooseBestFeatureSplit(dataSet):
    """Return the index of the feature column with the largest information gain.

    Every column except the last (the class label) is a candidate.

    Returns:
        The column index of the best feature, or ``-1`` when there is no
        feature column or no feature has a strictly positive gain.
    """
    baseEntropy = shannoEnt(dataSet)
    total = len(dataSet)
    bestGain = 0.0
    bestFeature = -1
    for i in range(len(dataSet[0]) - 1):
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # floating-point summation order does not depend on the hash seed.
        conditionalEntropy = 0.0
        for value in dict.fromkeys(row[i] for row in dataSet):
            subData = splitData(dataSet, i, value)
            conditionalEntropy += (len(subData) / total) * shannoEnt(subData)
        gain = baseEntropy - conditionalEntropy
        if gain > bestGain:
            bestGain = gain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent element of *classList*.

    Ties are resolved in favour of the element that appears first.

    Raises:
        ValueError: if *classList* is empty.
    """
    classCount = Counter(classList)
    return max(classCount, key=classCount.get)


def createTree(dataSet, labels):
    """Build an ID3 decision tree from *dataSet*.

    Args:
        dataSet: non-empty list of rows; the last element of each row is the
            class label, the preceding elements are feature values.
        labels: names of the feature columns, in column order. This list is
            not modified.

    Returns:
        A class label when the node is a leaf, otherwise a nested dict of the
        form ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # All samples share one class: this node is a pure leaf.
    if len(set(classList)) == 1:
        return classList[0]
    # Only the class column is left (rows always keep it, so the length is 1,
    # never 0): no feature remains to split on, fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureSplit(dataSet)
    # -1 means no feature yields any information gain. Using it as an index
    # would silently split on the class column, so stop with a majority leaf.
    if bestFeat == -1:
        return majorityCnt(classList)

    bestFeatLabel = labels[bestFeat]
    # Build the reduced label list as a new object instead of deleting from
    # the caller's list, which would corrupt it for any later use.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]

    # Create the tree: one branch per distinct value of the chosen feature,
    # visited in first-seen order so the resulting dict order is reproducible.
    myTree = {bestFeatLabel: {}}
    for value in dict.fromkeys(example[bestFeat] for example in dataSet):
        subDataSet = splitData(dataSet, bestFeat, value)
        myTree[bestFeatLabel][value] = createTree(subDataSet, subLabels)
    return myTree


# Build and print the decision tree for the sample data.
dataSet, label = loadDataSet()
print(createTree(dataSet, label))