UA MATH574M 统计学习 Variable Selection:Cross Validation

UA MATH574M 统计学习 Variable Selection:Cross Validation

故事要从线性回归开始讲起。我们知道线性回归有很多优点,模型简单容易解释性质良好,但它在复杂一点的现实问题中表现很差(因为现实数据质量差、结构复杂、维数高等原因)。线性回归虽然能搞出无偏估计但方差很大,所以MSE也就不小了。现代统计为了在解决现实问题上有所突破,选择给估计量增加一些约束,放弃无偏性,试图让模型的MSE变小,这些约束放在损失函数中被称为惩罚项。一个很重要的问题是应该怎么确定惩罚项的系数,或者说怎么调参。首先有几个基本原则:1)数据是很贵的,所以调参要尽可能不浪费训练模型的数据;2)调参方法不能太繁琐,要在尽可能快的时间内完成调参;3)调参的过程最好可解释、可推广。如果数据真的很充裕,可以考虑在训练集和测试集之外再设置一个validation set用来调参;如果模型比较特殊,比如线性回归、时间序列模型等,可以推导一些理论的criteria来确定参数;除此之外最常用的是重抽样法调参,而重抽样法中最常用的是Cross-validation(CV)方法,所以这一讲主要聊Cross-validation。

考虑监督学习算法 Y = f ( X ) + ϵ Y=f(X)+\epsilon ,数据集为 { ( X i , Y i ) } i = 1 n \{(X_i,Y_i)\}_{i=1}^n ,我们要把这个数据集分成training set和validation set,分别用来训练模型和调参。

LOOCV

最早提出Leave-one-out(LOO)思想的是Mosteller and Tukey (1968),这个思想也在Allen (1971)中被用来计算PRESS(参考回归那个系列的博文UA MATH571A 多元线性回归II 变量选择)。它的idea很简单,就是每一次拿掉一个样本,用剩下的样本训练模型,预测拿掉的这个样本,计算validation error,最后通过最小化validation error来选择超参。如果被拿走的是第 i i 个,用剩下的 n 1 n-1 个样本训练出来的算法是 f ^ i \hat{f}_{-i} ,则validation error也被称为LOOCV score:
L O O C V = 1 n i = 1 n L ( Y i , f ^ i ( X i ) ) LOOCV =\frac{1}{n} \sum_{i=1}^n L(Y_i,\hat{f}_{-i}(X_{i}))
调参的思路是找能最小化LOOCV score的参数即可。关于LOOCV我们想了解两个问题:LOOCV score与监督学习算法的测试误差如何互相联系,根据LOOCV调参是否能保证算法一定具有不错的泛化能力?按LOOCV的定义,我们需要估计 n n 次模型才能计算出LOOCV score,这貌似有点麻烦,有没有更简单一点的不用每leave one out都要估计一下模型的计算方法?

LOOCV score的计算

先介绍一个看起来是废话的性质:
Leave-one-out Property 如果把 ( X i , f ^ i ( X i ) ) (X_i,\hat{f}_{-i}(X_i)) 添加到去掉的第 i i 个样本的位置,我们就又有了一个 n n 个样本的数据集,用它来训练监督学习算法,记为 f ~ i \tilde{f}_{-i} ,则
f ~ i ( X j ) = f ^ i ( X j ) , j = 1 , , n \tilde{f}_{-i}(X_j) = \hat{f}_{-i}(X_j),\forall j = 1,\cdots,n
这其实是说,如果把监督学习算法比作是贴近样本点的某种曲面,那么把曲面上除样本点以外的点再添加一个在样本里面,重新估计得到的新的曲面就是原来这个曲面。我们可以把算法输出看成 Y Y 的某种线性平滑:
f ^ ( X ) = S Y \hat{f}(X) = SY
其中 S S 是一个 n × n n\times n 的平滑矩阵,它与 X X 的取值有关; f ^ ( X ) = [ f ^ ( X 1 ) , , f ^ ( X n ) ] T \hat{f}(X) = [\hat{f}(X_1),\cdots,\hat{f}(X_n)]^T 。从而
f ^ ( X i ) f ~ i ( X i ) = S i i ( Y i f ^ i ( X i ) ) \hat{f}(X_i) -\tilde{f}_{-i}(X_i) =S_{ii}(Y_i-\hat{f}_{-i}(X_i))
Leave-one-out Property保证了
f ^ ( X i ) f ^ i ( X i ) = S i i ( Y i f ^ i ( X i ) ) f ^ i ( X i ) = f ^ ( X i ) S i i ( Y i f ^ i ( X i ) ) Y i f ^ i ( X i ) = Y i f ^ ( X i ) 1 S i i \hat{f}(X_i) -\hat{f}_{-i}(X_i) =S_{ii}(Y_i-\hat{f}_{-i}(X_i)) \\ \Rightarrow \hat{f}_{-i}(X_i) = \hat{f}(X_i) - S_{ii}(Y_i-\hat{f}_{-i}(X_i))\\ \Rightarrow Y_i - \hat{f}_{-i}(X_i) = \frac{Y_i-\hat{f}(X_i)}{1-S_{ii}}
第二行给出了避免做n次模型估计的代换方法;如果是平方损失,则根据第三行,
L O O C V = 1 n i = 1 n [ Y i f ^ ( X i ) 1 S i i ] 2 LOOCV = \frac{1}{n} \sum_{i=1}^n \left[ \frac{Y_i-\hat{f}(X_i)}{1-S_{ii}} \right]^2
这种计算方法只需要额外估计拟合值与 Y Y 的平滑矩阵 S S 的对角元。Luntz and Brailovsky(1969)证明了LOOCV validation error是几乎无偏的,它的期望等于 n 1 n-1 个样本训练的监督学习算法的预测误差的期望。Devroye(1978)证明了在分类任务中,当去掉一个样本训练出来的算法与用完整数据训练出来的算法非常接近时,LOOCV validation error的方差上界为 1 / n 1/n

K-fold CV

Multifold CV的思想最早是Geisser (1975)提出来的,它的思想是每次移走 d d d > 1 d>1 )个样本,这样一共就有 C n d C_n^d 种移法,用移走 d d 个剩下的样本作为训练集训练模型,然后在这 d d 个样本上计算validation error,最后做这 C n d C_n^d 个validation的简单平均作为最后的validation error,通过最小化这个validation error来调参。Zhang (1993)证明了这种方法与信息准则的渐近等价性。

K-fold CV的思想最早出现在Breiman et al. (1984)中,这种方法按组移动样本而不是针对单个或者多个样本进行移动。假设 n n 个样本被分为 K K 组,每一组具有数目相同的样本数,每一次用除第 k k 组外的其他数据训练模型,然后计算模型在第 k k 组上的validation error,重复 K K 次,计算这些validation error的简单平均作为最后的validation error。记 κ : { 1 , 2 , , n } { 1 , 2 , , K } \kappa:\{1,2,\cdots,n\}\to\{1,2,\cdots,K\} 记录样本与组数的对应关系,则
C V K = 1 n i = 1 n L ( Y i , f ^ κ ( i ) ( X i ) ) CV_K = \frac{1}{n}\sum_{i=1}^n L(Y_i,\hat{f}_{-\kappa(i)}(X_i))
如果取 κ ( i ) = i \kappa(i)=i ,则K-fold CV就变成LOOCV。Zhang (1993)指出用 K = 5 , 10 K=5,10 的效果一般是比较好的。

Generalized CV

Generalized CV(GCV)是对LOOCV的计算的进一步修正,平方损失下计算LOOCV score需要
L O O C V = 1 n i = 1 n [ Y i f ^ ( X i ) 1 S i i ] 2 LOOCV = \frac{1}{n} \sum_{i=1}^n \left[ \frac{Y_i-\hat{f}(X_i)}{1-S_{ii}} \right]^2
如果 S i i S_{ii} 们相差不大,我们就可以用他们的均值来做个近似:
S i i t r ( S ) / n S_{ii} \approx tr(S)/n
并定义基于这种近似计算的LOOCV score为GCV score
G C V = 1 n i = 1 n [ Y i f ^ ( X i ) 1 t r ( S ) / n ] 2 GCV = \frac{1}{n} \sum_{i=1}^n \left[ \frac{Y_i-\hat{f}(X_i)}{1-tr(S)/n} \right]^2
它可以进一步减少计算量,并且这个东西有更好的统计性质。

猜你喜欢

转载自blog.csdn.net/weixin_44207974/article/details/105526548