meta-learning在工业界的应用

meta learning直接翻译为“元学习”,也称为learning to learn。即让模型学会一种学习能力,如何通俗的理解呢?举个例子:上学的时候总是会感觉班里有个“学神”,学的快不用做很多题就能达到举一反三,一通百通的境界!这种快速适应新的问题并解决的能力我们称为“学习能力”。
万变不离其宗,假如我们让模型具有快速适应新数据。并利用少量的数据,尽可能少的epoch达到一种高精度的效果。这正是工业检测所需要的!因为工业场景的业务周期短,场景多。如果我们能够有一个通用的特征提取器,也就是训练一个基模型。具有快速适应新数据的能力,无疑是深度学习在工业界的一把利剑。
言归正传,下面会根据meta Learing的主要伪代码和公式来进行解析。
论文传送门:https://arxiv.org/pdf/1703.03400v3.pdf

数据集

这里对于数据集的要求是类别要多,样本少。这也是真实业务场景中的数据情况。这里我们选用的是omniglot数据集,共1623个类别,每个类别20张图片。
将1200个类别作为训练数据,423个类别作为测试数据。每次随机抽取5个类别,每个类别中随机抽取1张图片为x_spt,随机抽取15张图片为x_qry(训练模型适应能力的数据)。
在这里插入图片描述

meta learning的训练

为了更容易理解,引入meta和learner两个概念。可以想象理解为meta为老师,learner为学生。老师的目的是训练学生的自学能力,也就是对于新任务(新数据)能否快速学习。将模型对于新数据的测试结果 loss作为学习能力的评判依据。训练模型对于新数据的适应能力。

meta中的forward

 def forward(self, x_spt, y_spt, x_qry, y_qry):
        """

        :param x_spt:   [b, setsz, c_, h, w]
        :param y_spt:   [b, setsz]
        :param x_qry:   [b, querysz, c_, h, w]
        :param y_qry:   [b, querysz]
        :return:
        """
        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        losses_q = [0 for _ in range(self.update_step + 1)]  # losses_q[i] is the loss on step i
        corrects = [0 for _ in range(self.update_step + 1)]


        for i in range(task_num):

            # 1. run the i-th task and compute loss for k=0
            logits = self.net(x_spt[i], vars=None, bn_training=True)
            loss = F.cross_entropy(logits, y_spt[i])
            grad = torch.autograd.grad(loss, self.net.parameters())
            fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))

            # this is the loss and accuracy before first update
            with torch.no_grad():
                # [setsz, nway]
                logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[0] += loss_q

                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[0] = corrects[0] + correct

            # this is the loss and accuracy after the first update
            with torch.no_grad():
                # [setsz, nway]
                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[1] += loss_q
                # [setsz]
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[1] = corrects[1] + correct

            for k in range(1, self.update_step):
                # 1. run the i-th task and compute loss for k=1~K-1
                logits = self.net(x_spt[i], fast_weights, bn_training=True)
                loss = F.cross_entropy(logits, y_spt[i])
                # 2. compute grad on theta_pi
                grad = torch.autograd.grad(loss, fast_weights)
                # 3. theta_pi = theta_pi - train_lr * grad
                fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                # loss_q will be overwritten and just keep the loss_q on last update step.
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct = torch.eq(pred_q, y_qry[i]).sum().item()  # convert to numpy
                    corrects[k + 1] = corrects[k + 1] + correct



        # end of all tasks
        # sum over all losses on query set across all tasks
        loss_q = losses_q[-1] / task_num

        # optimize theta parameters
        self.meta_optim.zero_grad()
        loss_q.backward()
        # print('meta update')
        # for p in self.net.parameters()[:5]:
        # 	print(torch.norm(p).item())
        self.meta_optim.step()


        accs = np.array(corrects) / (querysz * task_num)

        return accs

这里的task_num是meta中的batch size,参数update_step是learner更新的步数,使用的是 θ ‘ = θ − l r ∗ g r a d \theta^` = \theta - lr *grad θ=θlrgrad。每一次更新后的准确率会储存在corrects列表中,通过查看列表中准确率的增长速度可以用来评判learner的学习能力。

meta的训练,使用的是Adam优化器。loss是根据learner更新update_step之后的权重求得,这里使用的数据是y_qry来求取meta的loss。
工厂真实用来微调模型的流程是:采集易分错的数据,循环finetune模型。之后在进入到生产线中使用新数据来测试模型。这里的训练过程和真实场景一致。

详情代码:https://github.com/dragen1860/MAML-Pytorch

猜你喜欢

转载自blog.csdn.net/weixin_42662358/article/details/100543204