机器学习算法梳理(三):决策树

决策树梳理

一、信息论基础

熵是用来衡量一个系统混论程度的物理量,代表一个系统中蕴含多少信息量,信息量越大表明一个系统不确定性就越大,就存在越多的可能性。

  • 信息熵便是信息的期望值,可以记作:
    在这里插入图片描述
  • 条件熵
    在这里插入图片描述
  • 信息增益
    在这里插入图片描述
  • 信息增益率
    在这里插入图片描述
  • 基尼指数
    -

二、决策树的不同分类算法

算法 支持模型 树结构 特征选择 连续值处理 缺失值处理 剪枝
ID3 分类 多叉树 信息增益 不支持 不支持 不支持
C4.5 分类 多叉树 信息增益比 支持 支持 支持
CART 分类/回归 二叉树 基尼系数,均方差 支持 支持 支持
  • ID3
    由于期望信息越小,信息增益越大,从而纯度越高,因此ID3算法的核心思想就是以信息增益度量属性选择,选择分裂后信息增益最大的属性进行分裂。
  • C4.5
    ID3算法存在一个问题,就是偏向于多值属性,例如,如果存在唯一标识属性ID,则ID3会选择它作为分裂属性,这样虽然使得划分充分纯净,但这种划分对分类几乎毫无用处。ID3的后继算法C4.5使用增益率(gain ratio)的信息增益扩充,试图克服这个偏倚。
  • CART
    ID3中根据属性值分割数据,之后该特征不会再起作用,这种快速切割的方式会影响算法的准确率。CART是一棵二叉树,采用二元切分法,每次把数据切成两份,分别进入左子树、右子树。而且每个非叶子节点都有两个孩子,所以CART的叶子节点比非叶子多1。相比ID3和C4.5,CART应用要多一些,既可以用于分类也可以用于回归。CART分类时,使用基尼指数(Gini)来选择最好的数据分割的特征,gini描述的是纯度,与信息熵的含义相似。CART中每一次迭代都会降低GINI系数。

三、回归树原理

在这里插入图片描述

四、决策树防止过拟合手段

  • 预剪枝(pre-pruning):预剪枝就是在构造决策树的过程中,先对每个结点在划分前进行估计,若果当前结点的划分不能带来决策树模型泛华性能的提升,则不对当前结点进行划分并且将当前结点标记为叶结点。
  • 后剪枝(post-pruning):后剪枝就是先把整颗决策树构造完毕,然后自底向上的对非叶结点进行考察,若将该结点对应的子树换为叶结点能够带来泛华性能的提升,则把该子树替换为叶结点。

具体可参考:https://blog.csdn.net/u012328159/article/details/79285214

五、模型评估

分类树的模型评估,可以用分类模型的评估指标。

  • Accuracy
  • Precision
  • Recall
  • F1
  • score
  • ROC曲线和AUC
  • PR曲线

回归树的模型评估,可以用回归模型的评估指标。

  • MSE
  • MAE
  • R2

六、sklearn参数详解,Python绘制决策树

scikit-learn决策树算法类库内部实现是使用了调优过的CART树算法,既可以做分类,又可以做回归。分类决策树的类对应的是DecisionTreeClassifier,而回归决策树的类对应的是DecisionTreeRegressor。两者的参数定义几乎完全相同,但是意义不全相同。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier
 
iris = datasets.load_iris()
X = iris.data[:, [0,2]]
y = iris.target
 
# 训练模型,限制树的最大深度4
clf = DecisionTreeClassifier(max_depth=4)
# 拟合模型
clf = clf.fit(X, y)
 
# 画图
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1), np.arange(y_min, y_max, 0.1))
 
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
 
plt.contourf(xx, yy, Z, alpha=0.4)
plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
plt.show()

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/devcy/article/details/86096958
今日推荐