引言
决策树(DT,Decision Tree)是一种常用的有监督学习的分类算法,包括ID3、C4.5、C5.0、CART等种类。本文以ID3为例剖析决策树算法的原理和代码实现。
决策树进行分类的过程和程序语言中的if-else十分类似,甚至有萌新刚接触决策树的概念时会疑惑为什么不直接使用if-else来进行数据分类。两者区别在于,if-else是已知具体判断流程与条件的情况下采用的分类的方法;而决策树的使用场景是,你拥有大量已标注的训练样本数据,需要让机器从样本中学习出进行分类的判断条件、流程,然后再对实际数据进行分类。
一个简单的案例
给出如下的一组数据,共有14个样本,每个样本有天气(outlook)、气温(temperature)、湿度(humidity)、是否刮风(windy)四个属性,最后判断是否出去玩(play)。
outlook | temperature | humidity | windy | play |
---|---|---|---|---|
sunny | hot | high | False | no |
sunny | hot | high | True | no |
overcast | hot | high | False | yes |
rainy | mild | high | False | yes |
rainy | cool | normal | False | yes |
rainy | cool | normal | True | no |
overcast | cool | normal | True | yes |
rainy | mild | normal | False | yes |
sunny | mild | normal | True | yes |
overcast | mild | high | True | yes |
overcast | hot | normal | False | yes |
rainy | mild | high | True | no |
sunny | mild | high | False | no |
sunny | cool | normal | False | yes |
而决策树算法的核心在于,如何通过以上数据生成出一个最符合当前样本的树形决策流程。这里先给大家一个最终结果的直观的体验:
上图即是最符合当前14个样本的决策树,如果有测试样本,可以很容易地直接将按照决策树的流程和判断分支进行分类。
而决策树算法的重难点就是如何根据训练样本生成类似于上图这样的树。
决策树生成算法
先抛开很多类似于熵值、信息增益等专业名词,我们用最通俗的语言来描述一下决策树的生成。
回想一下在日常生活中,我们给事物进行分类时,可能会涉及一系列的判断依据,而首先用到的判断依据正是最重要的那个。比方说,让我们判断一个学生是不是好学生,我们第一反应是看他的考试成绩,其次才是出勤率、交作业次数等情况,并且越在后面考虑的因素,在判断中起到的影响越低。
之所以我们第一反应是考虑考试成绩,因为这个因素最能将好学生和差学生区分开:平均考试成绩小于70分的,基本就是差学生无疑了,不需要考虑其他因素;平均成绩大于90分的,基本上是好学生无疑了,也不需要考虑其他因素了;成绩不好不差的那一批不能直接区分,则才会需要用其他的因素来区分。
以上是决策树构建的直观认识,那么用更加科学的方法,决策树的根节点与后续节点的生成遵循什么样方法呢?
这里就要提到一些专业概念了:熵值和信息增益。
熵值
熵值用来表示系统的混乱程度,计算公式如下:
以上 为系统某个东西的出现概率。举个例子:
第一个数组有4个数字,每个数字互不相同,因此每一项在整个数组的出现概率是 ,因此熵值的计算为:
第二个数组也是4个数字,但是每个数字都一模一样,即每个数字的出现概率都是1,因此熵值为:
很明显第一个数组比第二个混乱程度高,因此计算出的熵值更高。特别地,第二个数组熵值为0,即完全有序。
可熵值在决策树的生成中能起到什么用处呢?
我们回到刚才提到区分好学生和差学生的例子,原本好学生和差学生都站在一块,这就是混乱程度高,即熵值高。在分类时我们用考试成绩作为首要的评判标准,用通俗的话讲,是因为这个因素最能将好学生和差学生分开。而用熵值的理论来表述,则是最能够让混乱程度(熵值)降低。
那么我们用什么指标来判断让熵值降低的程度呢?这里又要引出第二个概念:信息增益。
信息增益
信息增益的计算公式为:
其实就是系统改变前后的两个熵值相减,为了方便理解还是举个例子:
上面的数组被划分为了两个数组。
根据熵计算公式,改变前的熵值为:
改变后的我们需要分别计算两个数组的熵值,相加后再取平均数:
最后计算信息增益率:
以上就是信息增益的概念与计算方式。
回到案例
学习了熵值和信息增益后,我们回到开篇提到的根据天气(outlook)、气温(temperature)、湿度(humidity)、是否刮风(windy)来判断是否出去玩(play) 的案例。
outlook、temperature、humidity、windy四个属性哪一个应该最先被考虑,或者说哪一个属性应该作为决策树的根节点?
方法很简单,我们先计算总体样本的熵值,再分别计算用以上4个属性对总样本进行划分后的熵值,对比哪一种属性进行划分的信息增益高,就选哪个属性作为根节点。
先算总体样本的熵值,一共有14个样本,其中9个yes,5个no,因此我们可以计算:
分别计算用outlook、temperature、humidity、windy四个属性划分样本后的信息增益。
用outlook划分
用temperature划分
用humidity划分
用windy划分
比较信息增益
比较上面4个属性划分样本后的信息增益,可以得出用outlook进行划分的信息增益最大,因此outlook成为决策树的根节点。
确定根节点后,由于outlook有3种取值,因此样本又被划分为了3份,即根节点拥有3个孩子节点。其中overcast对应的节点样本全是yes,已经不必再分,但是另外两个孩子节点依然需要继续区分。后续我们就重复上面的操作,进一步划分孩子节点,最终形成决策树:
手写决策树算法
要实现决策树算法,首先我们要实现树的数据结构:
class Tree:
def __init__(self,label):
'''label代表划分本节点的特征名称,如outlook、humidity等
'''
self.label =label
self.child = {}
def add_child(self,key,value):
'''key代表特征的具体取值,如sunny、rainy、overcast
'''
self.child[key] =value
def list_child(self):
return self.child
然后实现决策树算法:
class DecisionTree:
@classmethod
def divide(cls, pdData, col):
''' 根据特征划分样本
'''
result = []
headers = pdData.columns.values
dic = cls.count(pdData, col=col)
for key in dic:
data = pdData[pdData[headers[col]] == key]
result.append(data)
return result
@classmethod
def entropy(cls, pdData):
'''当前节点熵值
'''
total_count = pdData.shape[0]
dic = cls.count(pdData, col=pdData.shape[1]-1)
# print(dic)
result = 0
for key in dic:
result -= (dic[key]/total_count)*np.log2(dic[key]/total_count)
return result
@classmethod
def entropy_if_divided(cls, pdData, char):
'''计算如果用某个特征划分样本后的熵值
'''
result = 0
d = cls.count(pdData, char)
# print(d)
for key in d:
result += cls.entropy(pdData[pdData[char] == key]) * \
((d[key])/pdData.shape[0])
return result
@classmethod
def gain(cls, pdData):
'''信息增益list
'''
l = []
headers = pdData.columns.values
for i in range(len(headers)-1):
l.append(cls.entropy_if_divided(pdData, headers[i]))
return cls.entropy(pdData)-l
@classmethod
def count(cls, pdData, char="", col=-1):
'''计算数据帧某一列有哪几种值以及对应的数量
'''
d = {}
if not char == "":
for item in pdData[char]:
if item in d:
d[item] = d[item]+1
else:
d[item] = 1
elif not col == -1:
for item in pdData.iloc[:, col]:
if item in d:
d[item] = d[item]+1
else:
d[item] = 1
return d
@classmethod
def generate(cls, pdData):
'''生成决策树
'''
headers = pdData.columns.values
max_index = np.argmax(cls.gain(pdData))
tree = Tree(headers[max_index])
li = cls.divide(pdData, col=max_index)
for item in li:
if item.shape[0] == 0: # 当前结点包含的样本集合为空,不能划分。
continue
if (cls.gain(item) == 0).all(): # 当前结点包含的样本全属于同一类别,无需划分
tree.add_child(item.iloc[0, max_index],
item.iloc[0, item.shape[1]-1])
continue
tree.add_child(item.iloc[0, max_index], cls.generate(item))
return tree
@classmethod
def classify_single(cls, data, headers, tree):
'''为单个样本分类
'''
label = tree.label
value = data[np.argwhere(headers == label)][0][0]
try:
child = tree.list_child()[value]
except KeyError as e:
return 0
if type(child) == Tree:
return cls.classify_single(data, headers, child)
else:
return child
@classmethod
def classify(cls, pdData, tree):
'''为输入的样本执行分类
'''
test_data = pdData.values
li = []
for item in test_data:
li.append(cls.classify_single(item, pdData.columns.values, tree))
return li
@classmethod
def accuracy(cls, prediction, reality):
'''评价指标准确率
'''
count = 0
arr = prediction == reality
for item in arr:
if(item == True):
count += 1
return count/len(arr)
@classmethod
def recall(cls, prediction, reality):
'''评价指标召回率
'''
count_tp = 0
count_fn = 0
for i in range(len(prediction)):
if(reality[i] == True and prediction[i] == True):
count_tp += 1
if(reality[i] == True and prediction[i] == False):
count_fn += 1
return count_tp/(count_tp+count_fn)
@classmethod
def precision(cls, prediction, reality):
'''评价指标准确度
'''
count_tp = 0
count_fp = 0
for i in range(len(prediction)):
if(reality[i] == True and prediction[i] == True):
count_tp += 1
if(reality[i] == False and prediction[i] == True):
count_fp += 1
return count_tp/(count_fp+count_tp)
测试代码
训练集:
outlook,temperature,humidity,windy,play
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
测试集:
outlook,temperature,humidity,windy,play
overcast,mild,high,TRUE,yes
sunny,hot,high,FALSE,no
sunny,cool,normal,True,yes
调用代码:
if __name__ == "__main__":
pdData_train=pd.read_csv("train.csv")
pdData_test=pd.read_csv("test.csv")
tree=DecisionTree.generate(pdData_train)
prediction=DecisionTree.classify(pdData_test,tree)
print(prediction)
输出结果:
[‘yes’, ‘no’, ‘yes’]