持续创作,加速成长!这是我参与「掘金日新计划 · 6 月更文挑战」的第18天,点击查看活动详情
能将几个弱学习器结合成一个强学习器的任意集成方法,被称为提升法。而大部分的提升法思路都是循环训练多个预测器,新的预测器不断对前一个预测器进行修正。当前比较流行的方法有AdaBoost和Gradient Boosting,下面将详细讲讲这两个方法:
AdaBoost
AdaBoosting技术实现的是:新的训练器会对前一个训练器的欠拟合的数据进行重新预测,聚焦于不断解决那些困难的问题。他的流程大致如下图1:
图1 基于AdaBoosting循环
使用AdaBoosting技术时,我们需要先设置并训练一个基础分类器,然后开始第一次对训练集的预测。然后对分类错误的训练实例增加相对权重。之后将新的训练集放于第二个训练器中,按照上面次序不断循环设置的全部预测器。
ok,让我们来仔细看看AdaBoosting的算法中的更新权重的公式,我们首先需要设置每个实例权重 的初始值为
m=len(X_train)
sample_weights = np.ones(m) / m
接下来我们对实例训练后,需要计算出加权误差率 :
r = sample_weights[y_pred != y_train].sum() / sample_weights.sum()
接下来计算权重 :
alpha = learning_rate * np.log((1 - r) / r)
是需要我们设置的超参数-学习率。可见预测的越准确,权重就会越高。错的越多,则权重还可能为负。
最后更新权重并进行归一化(就是对所有的实例除以 ),更新的规则如下所示:
sample_weights[y_pred != y_train] *= np.exp(alpha)
sample_weights /= sample_weights.sum()
最后得出的预测就是将之前所有的预测器之间的预测加权( )后选择获得投票最多的:
也可以用Scikit-Learn的AdaBoostingClassifier,这是一个多分类的版本,叫做SAMME(基于多类指数损失函数的逐步添加模型),当他只有两类时才是Adaboosting。如果我们的预测器含有predict_proba()方法,我们也可以将AdaBoostingClassifier的algorithm改为"SAMME.R"这个是基于类的预测,通常效果更好。
Gradient Boosting
与AdaBoost不同的是,Gradient Boosting他在新的预测器对前一个预测器的残差进行拟合。大概的效果如下:
# 获得初始的
y1 = xxx1_reg.predict(X)
# 获取预测值和标签的差值
y2 = y1 - xxx2_reg.predict(X)
# 继续获取预测值和标签的差值
y3 = y2 - xxx3_reg.predict(X)
# 预测值就是所有预测器的预测之和
y_pred = sum( reg.predict(X_test) for reg in (xxx1_reg,xxx2_reg,xxx3_reg) )
复制代码
依然在Scikit-Learn可以使用GradientBoostingRegressor(这里使用的预测器和随机森林一样也用了决策树)来进行回归任务。我们需要注意几个超参数:
learning_rate
-这个参数是对每棵树的贡献进行缩放的,如果设置的比较低,就需要多一些预测器,同时泛化的效果也更好些。n_estimatirs
-我们设置的预测器数量- 同时我们也可以根据需要设置决策树的超参数
我们来看一个案例,随着书的数量增加时,误差的变化:
图2 误差变化
我们看到了其实误差在大概60不到点出就出现了最低值,后面其实增加再多的树都无法再降低误差,但是由于设置了固定的数目,所以程序一直执行到设置的树的数量。
其实我们如果不知道大概的树的数量,我们可以使用循环一点点增加预测器的数目,然后不断判断误差是否再减少,当不在减少到比如5次是就不在训练了:
#最大树的数量
max_trees_num = xxx
for n_estimators in range(1,max_trees_num):
# 每次修改GradientBoostingRegressor的预测器数量
gbrt.n_estimators = n_estimators
# ... 省略非核心代码
# 判断误差是否在减少,如果没有减少则nondecreasing_count加1
# nondecreasing_count 累加到5 就停止训练
if new_error < min_error:
min_error = new_error
nondecreasing_count = 0;
else:
nondecreasing_count += 1
if nondecreasing_count == 5:
break;
复制代码