《DARTS:Differentiable Architecture Search》论文笔记

参考代码:darts

1. 概述

导读:在这篇文章之前的NAS文章很多是采用搜索空间与强化学习(或是演化算法)的组合,整个的网络的搜索流程是使用诸如policy gradient的方法产生结构优化梯度,期间也可以通过如ENAS的方式通过参数共享的方式加速网络搜索。但是更为直接的方式还是通过梯度优化的方式进行网络搜索,直接将离散结构优化问题转换为梯度问题,从而更加高效地进行求解(通过对边上操作的组合做softmax从而使得操作变得可微分,从而避免离散情况无法使用梯度下降)。在这篇文章中也是从整个网络节点构建一个超网络,之后在这个超网络中寻找最优子网络,并且为这个梯度下降求解过程进行了计算过程简化(进行了一阶和二阶梯度近似)从而进一步加快搜索速度,文章的方法在CIFAR-10数据集上获得了 2.76 ± 0.09 2.76\pm 0.09% 2.76±0.09的性能表现。

在文章中将节点数据定义为 x ( i ) x^{(i)} x(i),两个数据节点之间的边是 ( i , j ) (i,j) (i,j),这个边代表的是多个餐座的集合 o ( i , j ) o^{(i,j)} o(i,j)(如卷积/池化/NULL(文中为zero)等),那么从节点 i i i j j j的运算过程可以描述为:
x j = ∑ i < j o ( i , j ) ( x ( i ) ) x^{j}=\sum_{i\lt j}o^{(i,j)}(x^{(i)}) xj=i<jo(i,j)(x(i))
因而整个搜索的流程可以归纳为下图:
在这里插入图片描述
其中包含的步骤:

  • 1)首先定义一个搜索空间,图a;
  • 2)将计算节点通过超网络的形式构建一个大图,图b;
  • 3)通过梯度下降逐渐抽取出网络边中概率最大的操作,图c;
  • 4)筛选最大概率得到最后的网络结构,图d;

2 方法设计

2.1 网络搜索的数学模型

文章中网络搜索的可选操作集合描述为 O \mathcal{O} O,操作 o ( ⋅ ) o(\cdot) o()代表对数据节点 x ( i ) x^{(i)} x(i)采用了某种集合中的确定操作,为了使得整个超网络链接可微分,文章将其通过softmax构建为几个操作和的形式,从而对应的输出可以描述为:
o ‾ ( i , j ) ( x ) = ∑ o ∈ O e x p ( α o ( i , j ) ) ∑ o ‘ ∈ O e x p ( α o ‘ ( i , j ) ) o ( x ) \overline{o}^{(i,j)}(x)=\sum_{o\in\mathcal{O}}\frac{exp(\alpha_o^{(i,j)})}{\sum_{o^{‘}\in\mathcal{O}}exp(\alpha_{o^{‘}}^{(i,j)})}o(x) o(i,j)(x)=oOoOexp(αo(i,j))exp(αo(i,j))o(x)
其中,操作集合 O \mathcal{O} Osoftmax加权的部分其权值集合可以表示为 α = { α ( i , j ) } \alpha=\{\alpha_{(i,j)}\} α={ α(i,j)}。在完成搜索之后可以通过简单取最大概率的形式选择最后的网络结构 o ( i , j ) = arg max ⁡ o ∈ O α o ( i , j ) o^{(i,j)}=\argmax_{o\in\mathcal{O}}\alpha_o^{(i,j)} o(i,j)=oOargmaxαo(i,j)

接下来就是要在搜索空间中寻找最优的概率分布了,这里就设计到两个部分的优化:搜索空间本身自带的参数 w w w,以及边的概率集合 α \alpha α,它们也分别对应两个损失 L t r a i n , L v a l L_{train},L_{val} Ltrain,Lval。因而整体的搜索任务目标是在网络参数 w ∗ w^{*} w前提下通过最小化损失函数 L v a l ( w ∗ , α ∗ ) L_{val}(w^{*},\alpha^{*}) Lval(w,α)获得最佳的子网络结构采样 α ∗ \alpha^{*} α,其中 w ∗ w^{*} w是通过最小化训练损失 w ∗ = arg min ⁡ w L t r a i n ( w , α ∗ ) w^{*}=\argmin_wL_{train}(w,\alpha^{*}) w=wargminLtrain(w,α),具体描述为:
min ⁡ α L v a l ( w ∗ ( α ) , α ) \min_{\alpha}L_{val}(w^{*}(\alpha),\alpha) αminLval(w(α),α)
s . t .   w ∗ ( α ) = arg min ⁡ w L t r a i n ( w , α ) s.t.\ w^{*}(\alpha)=\argmin_{w}L_{train}(w,\alpha) s.t. w(α)=wargminLtrain(w,α)
需要注意的是上述的优化过程是一个递归优化过程,由于网络结构参数是一个高维数据,这就导致了整个优化过程变得困难,对此文章引入了一阶和二阶近似来进行简化。

2.2 优化梯度近似

对上文中的最优化目标函数求去梯度得到:
∇ α L v a l ( w ∗ ( α ) , α ) \nabla_{\alpha}L_{val}(w^{*}(\alpha),\alpha) αLval(w(α),α)
若是考虑了权重 w w w的更新过程,那么上面的梯度就可以描述为:
≈ ∇ α L v a l ( w − ξ ∇ w L t r a i n ( w , α ) , α ) \approx\nabla_{\alpha}L_{val}(w-\xi\nabla_wL_{train}(w,\alpha),\alpha) αLval(wξwLtrain(w,α),α)
其中, ξ \xi ξ是对应学习任务部分的学习率,在实际分析中发现参数 w w w的迭代优化过程其实是相当消耗资源的,那么一个直观的想法就是能不能对 w w w的优化只使用单个训练步骤就可以完成,这样就可以节省掉很大的计算开销。这个情况在网络参数在局部最优值的时候,其 ∇ w L t r a i n ( w , α ) = 0 \nabla_wL_{train}(w,\alpha)=0 wLtrain(w,α)=0,自然就不再需要对应优化过程了,只需要优化网络结构参数 α \alpha α就好了。则对于网络的优化过程可以描述为下面算法的步骤:
在这里插入图片描述
对上面的梯度进行链式法则展开的得到:
∇ α L v a l ( w ‘ , α ) − ξ ∇ α , w 2 L t r a i n ( w , α ) ∇ w ‘ L v a l ( w ‘ , α ) \nabla_{\alpha}L_{val}(w^{‘},\alpha)-\xi\nabla_{\alpha,w}^2L_{train}(w,\alpha)\nabla_{w^{‘}}L_{val}(w^{‘},\alpha) αLval(w,α)ξα,w2Ltrain(w,α)wLval(w,α)
其中, w ‘ = w − ξ ∇ w L t r a i n ( w , α ) w^{‘}=w-\xi\nabla_wL_{train}(w,\alpha) w=wξwLtrain(w,α)代表一次网络的前向,上面计算过程中很大的计算量在后面的部分中,则可以对上面的部分使用下面的式子进行近似逼近:
∇ α , w 2 L t r a i n ( w , α ) ∇ w ‘ L v a l ( w ‘ , α ) ≈ ∇ α L t r a i n ( w + , α ) − ∇ α L t r a i n ( w − , α ) 2 ϵ \nabla_{\alpha,w}^2L_{train}(w,\alpha)\nabla_{w^{‘}}L_{val}(w^{‘},\alpha)\approx\frac{\nabla_{\alpha}L_{train}(w^{+},\alpha)-\nabla_{\alpha}L_{train}(w^{-},\alpha)}{2\epsilon} α,w2Ltrain(w,α)wLval(w,α)2ϵαLtrain(w+,α)αLtrain(w,α)
经过上面逼近其计算过程大大降低只需要模型的两次前向和结构参数的两次后向传播,整体复杂度由 O ( ∣ α ∣ ∣ w ∣ ) O(|\alpha||w|) O(αw)减少为 O ( ∣ α ∣ + ∣ w ∣ ) O(|\alpha|+|w|) O(α+w)

对于上面提到的一阶和二阶近似文章中是使用学习率参数 ξ \xi ξ的取值来区分的:

  • 1)当其为0的时候因为没有了后面的部分,整体计算量大大降低,但是对应的效果就变差了,这个在文中的表1和表2中都有体现;
  • 2)当其大于0的时候就存在后面的梯度近似的部分,使得网络下降更加准确;

那么对学习率带来的影响进行分析,从下图可知通过选择合适的学习率可以帮助网络更快速收敛。
在这里插入图片描述

2.3 最后网络结构的生成

对于最后的网络文章中是选择边的集合中非zero的概率top-k操作保留,文章这样做处于如下的原因:

  • 1)保留两个操作是与之前的工作进行比较,因为之前的也是由两个输入;
  • 2)文章指出zero操作只是影响网络输出的scale,但是这个可以通过BN层抵消掉,所以选择非zero的操作;

3. 实验结果

CIFAR-10数据集:
在这里插入图片描述
ImageNet数据集:
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/m_buddy/article/details/110499707
今日推荐