关于使用HAWQ量化框架执行训练前推理的性能崩溃问题

问题描述

今天debug量化模型遇到一个比较奇怪的问题,之前从来没有注意过:
现在量化模型的流程是:
1)加载预训练好的浮点数权重模型;
2)将模型架构替换成量化架构(逐模块替换),此时浮点权重会直接被fake quant;
3)训练fake quant后的模型权重(QAT)过程。

直接在第1)步后面验证浮点权重模型效果是完全没问题的:

validate(val_loader, model, criterion, args)

我现在希望在第2)步后面验证权重直接被fake quant后的模型效果,发现所有的sample上推理结果都是0%,即模型失败!
甚至只量化第一层的激活值,也会出现模型失败的情况。于是debug了很久量化函数,并没有发现任何问题。基本的代码思路就是如下图这样(没有act_range_momentum的简单版本):

在这里插入图片描述

问题分析

最后发现问题出在validate函数中,HAWQ框架的validate函数加上了一行freeze_model(model)函数,其作用是跑量化函数时不计算模型的self.min和self.max(默认为0),导致的后果是scale算出来非常非常小,量化值(x/scale)非常非常大,clamp后所有值都被截到了两个表示范围的边缘,变成了-128或127。因此模型就失败了。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

于是模型过量化函数时不会进self.running_stat分之,所以也不会计算self.min和self.max,导致模型失败。

在这里插入图片描述

所以HAWQ中使用validate应该是针对在做/做过量化训练的模型而言的,因此此时就保存好了self.min和self.max值,推理没问题。而对于我们这种情况(也就是加载浮点模型,然后直接过量化验证效果),是没有初始self.min和self.max值的,导致出错。

解决方案

  1. 既然想训练前验证一下,最简单的方法是直接注释掉freeze_model(model),然后推理结果就正常了

  2. 当然也可以给模型一个warmp/初始化的阶段去算最初的self.min和self.max:参考 WinogradAwareNets

感谢和Tengyu、Chenqi共同的讨论和debug,一起总结出了一些宝贵的经验(我的知识漏洞)。