Learning without Forgetting 论文阅读和对应代码详解

论文地址点这里

一. 介绍

基于终身学习思想,我们希望新任务可以分享旧任务的参数进行学习,同时不会出现灾难性遗忘。在这种场景下,需要开发一种简单有效的策略来解决各种分类任务。基于这种策略和单纯的想法,我们的CNN模型有一组共享参数 θ s \theta_s θs,在先前学习过的任务上的特定参数 θ o \theta_o θo,以及为新任务随机初始化特定的参数 θ n \theta_n θn。这样,对于一个网络 θ o 和 θ n \theta_o 和\theta_n θoθn可以看作是作用于参数 θ s \theta_s θs的特定分类器。针对这样的架构,有特定三种方式。
第一是特征提取: θ s \theta_s θs θ o \theta_o θo不改变,构建新的参数层 θ n \theta_n θn去处理新的任务。
第二种是微调: θ o \theta_o θo固定,用较小的学习率更新 θ s \theta_s θs θ n \theta_n θn。当然,也可以微调全连接层
第三种是联合训练: 所有参数联合优化,从每一个任务上交叉采样,这种方式可以认为是终身学习的上限。
这三种方法都有对应的缺点。直观上看,特征提取只会训练到新任务的参数,而卷积层和大部分连接层被固定,在新任务上取得的效果会较差。而对于微调,共享参数遭到改变,旧任务会出现灾难性遗忘。而联合训练则是偏离了我们的假定(旧任务不会出现)。
作者针对这些问题,提出了LwF模型。LwF和联合训练很像,只是不需要用到旧任务的数据。这种模型有以下的三个优点:
分类性能上,LwF优于特征提取,同时在新任务上进行微调,而在旧任务上使用微调参数也能获得较好的结果(就是能够学习又不遗忘)。
计算效率:训练时间比联合训练快,接近于微调,而测试时间比使用多个微调网络执行不同任务快。
部署简单:一旦了解任务,就不需要保留或重新应用训练数据保持自适应网络的性能(出现过的任务能够很好的维持性能)

二. 相关工作

LwF可以看作是蒸馏网络和微调的组合。微调是通过较低学习率,结合之前已有的网络以及相关数据信息处理新任务,以方便找到一个局部最小值。而蒸馏的思想是在一个更简单的网络中学习参数使得网络产生的输出与原始的大型网络产生的输出集合相同。LwF不同之处在于训练出的参数能够较好应对新旧任务,使用相同的数据监督新任务的学习,并对旧任务提供无监督的输出指导(意思是不通过重现旧任务数据而能通过参数调整保证旧任务学习较好)。

2.1 进行对比的方法

特征提取: 特征提取使用一个提前训练好的CNN去提取一个图片的特征。特征提取不修改原始网络,允许新任务从以前任务中学习到的复杂特征中获益。但这些特性并不是针对新任务的,通常可以通过微调来改进。
微调: 微调修改一个现有的CNN的参数来训练一个新的任务。输出层对新任务的权值进行随机初始化扩展,并使用较小的学习率调整所有参数的初始值,以最小化新任务的损失。一些时候,部分网络被冻结防止过拟合。共享参数会调整共享参数使他们对新任务更具有识别力,而低学习率是一种间接机制,以保留在原始任务上学习的表征结构。
多任务学习: 多任务学习通过结合来自所有任务的共同知识同时改善所有的任务,每一个任务为共享或约束的参数提供额外的训练数据,作为其他任务的正则化形式。多任务学习需要所有任务的数据,而对于LwF来说只需要新任务数据。
添加新节点: 在每一个网络层中添加新的节点是在学习新的判别特征的同时保持原有网络参数的一种方法。
(作者之后还介绍了一些相关的方法,仅仅只是说明和LwF的相关性,并没有解释这些算法,对大家理解LwF没什么帮助,所以就不写上去了)

三. LwF(Learning Without Forgetting)

在这里插入图片描述
(算法如上图所示)
输入:
首先对于一个新任务来说,我们模型已经有了在之前任务训练的出来的共享参数 θ s \theta_s θs和旧任务参数 θ n \theta_n θn,对于当前新到来的任务我们有了 X n 和 Y n X_n和Y_n XnYn
初始化:
首先是根据新任务数据拟合出来一个旧任务标签
Y o = C N N ( X n , θ s , θ o ) Y_o = CNN(X_n,\theta_s,\theta_o) Yo=CNN(Xn,θs,θo)
随机初始化: θ n \theta_n θn
训练:
旧任务预测出的结果:
Y o ′ = C N N ( X n , θ s ′ , θ o ′ ) Y_o'=CNN(X_n,\theta_s',\theta_o') Yo=CNN(Xn,θs,θo)
新任务预测结果:
Y n ′ = C N N ( X n , θ s ′ , θ n ′ ) Y_n'=CNN(X_n,\theta_s',\theta_n') Yn=CNN(Xn,θs,θn)
最后定义损失函数即可:
L o s s = λ 0 L o l d ( Y o , Y o ′ ) + L n e w ( Y n , Y n ′ ) + R ( θ s , θ o , θ n ) Loss = \lambda_0L_{old}(Y_o,Y_o')+L_{new}(Y_n,Y_n')+R(\theta_s,\theta_o,\theta_n) Loss=λ0Lold(Yo,Yo)+Lnew(Yn,Yn)+R(θs,θo,θn)
优化损失即可
这里值得注意的是对于旧任务预测,作者使用了蒸馏的方式,即
L o l d ( Y o , Y o ′ ) = − ∑ i = 1 l y o ′ ( i ) l o g y o ′ ( i ) L_{old}(Y_o,Y_o') = -\sum_{i=1}^{l}y_o'^{(i)}logy_o'^{(i)} Lold(Yo,Yo)=i=1lyo(i)logyo(i)
其中
y o ′ ( i ) = ( y o ′ ( i ) ) 1 / T ∑ j ( y o ′ ( i ) ) 1 / T y_o'^{(i)} = \frac{({y_o'^{(i)})}^{1/T}}{\sum_j({y_o'^{(i)})}^{1/T}} yo(i)=j(yo(i))1/T(yo(i))1/T
这里就是用蒸馏的方式进行改变。

四. 代码解析

代码github在这里,这个代码写的比较简单,适合进行学习用
首先,作者是按照两个类两个类的增加,也就是一个task对应两个类,选用的model为resnet34。
网络层:

class Model(nn.Module):
	def __init__(self, classes, classes_map, args):
		# Hyper Parameters
		self.init_lr = args.init_lr
		self.num_epochs = args.num_epochs
		self.batch_size = args.batch_size
		self.lower_rate_epoch = [int(0.7 * self.num_epochs), int(0.9 * self.num_epochs)] #hardcoded decay schedule
		self.lr_dec_factor = 10
		
		self.pretrained = False
		self.momentum = 0.9
		self.weight_decay = 0.0001
		# Constant to provide numerical stability while normalizing
		self.epsilon = 1e-16

		# Network architecture
		super(Model, self).__init__()
		self.model = models.resnet34(pretrained=self.pretrained)
		self.model.apply(kaiming_normal_init)

		num_features = self.model.fc.in_features
		self.model.fc = nn.Linear(num_features, classes, bias=False)
		self.fc = self.model.fc
		self.feature_extractor = nn.Sequential(*list(self.model.children())[:-1])
		#self.feature_extractor = nn.DataParallel(self.feature_extractor)


		# n_classes is incremented before processing new data in an iteration
		# n_known is set to n_classes after all data for an iteration has been processed
		self.n_classes = 0
		self.n_known = 0
		self.classes_map = classes_map

	def forward(self, x):
		x = self.feature_extractor(x)
		x = x.view(x.size(0), -1)
		x = self.fc(x)
		return x

	def increment_classes(self, new_classes):
		"""Add n classes in the final fc layer"""
		n = len(new_classes)
		print('new classes: ', n)
		in_features = self.fc.in_features
		out_features = self.fc.out_features
		weight = self.fc.weight.data

		if self.n_known == 0:
			new_out_features = n
		else:
			new_out_features = out_features + n
		print('new out features: ', new_out_features)
		self.model.fc = nn.Linear(in_features, new_out_features, bias=False)
		self.fc = self.model.fc
		
		kaiming_normal_init(self.fc.weight)
		self.fc.weight.data[:out_features] = weight
		self.n_classes += n

	def classify(self, images):
		"""Classify images by softmax

		Args:
			x: input image batch
		Returns:
			preds: Tensor of size (batch_size,)
		"""
		_, preds = torch.max(torch.softmax(self.forward(images), dim=1), dim=1, keepdim=False)

		return preds

特征提取层对应resnet34除最后一层全部,这里也就是共享参数。而旧任务参数对应一个全联接层,这里有个增加task参数的方法,也就是每次将最后一层全联接层添加新的参数(由2-4-6-8-…)。
在训练的时候,需要计算旧任务损失,代码如下:

if self.n_classes//len(new_classes) > 1:
	dist_target = prev_model.forward(images)
	logits_dist = logits[:,:-(self.n_classes-self.n_known)]
	dist_loss = MultiClassCrossEntropy(logits_dist, dist_target, 2)
	loss = dist_loss+cls_loss

dist_target就是由上一次模型计算出来的伪旧任务标签,logits_dist是用当前更新过的 θ s \theta_s θs θ n \theta_n θn计算出来的预测旧任务targets。
这里对应的方法是通过蒸馏的方式进行计算出来的,代码如下,T对应的就是蒸馏的温度(设置为2比较好)

def MultiClassCrossEntropy(logits, labels, T):
	# Ld = -1/N * sum(N) sum(C) softmax(label) * log(softmax(logit))
	labels = Variable(labels.data, requires_grad=False).cuda()
	outputs = torch.log_softmax(logits/T, dim=1)   # compute the log of softmax values
	labels = torch.softmax(labels/T, dim=1)
	# print('outputs: ', outputs)
	# print('labels: ', labels.shape)
	outputs = torch.sum(outputs * labels, dim=1, keepdim=False)
	outputs = -torch.mean(outputs, dim=0, keepdim=False)
	# print('OUT: ', outputs)
	return Variable(outputs.data, requires_grad=True).cuda()

最后优化更新即可。

猜你喜欢

转载自blog.csdn.net/qq_45478482/article/details/121482802