AIMET API 文档(6)
1.1.9 偏差校正 API
AIMET PyTorch 偏差校正 API
用户指南链接
要了解有关此技术的更多信息,请参阅跨层均衡
偏差校正API
aimet_torch.bias_correction.correct_bias(模型,quant_params,num_quant_samples,data_loader,num_bias_ Correct_samples,conv_bn_dict =无,perform_only_empirical_bias_corr = True,layers_to_ignore =无)
纠正模型每个 Conv 层的偏差(除非被忽略)。使用分析偏差校正和经验偏差校正的组合,即可以使用分析偏差校正来校正的所有层都使用分析偏差校正来校正,并且使用经验方法来校正剩余层。
返回就地校正的浮点模型
参数
- model ( Module) – 待修正模型
- quant_params ( QuantParams) – 用于偏差校正的量化模拟的命名元组
- num_quant_samples ( int) – 通过量化 sim 进行偏差校正的图像样本数。
- data_loader – 模型的数据加载器
- num_bias_ Correct_samples ( int) – 偏差校正的样本数
- conv_bn_dict ( Optional[ Dict[ Module, ConvBnInfoType]]) – 包含与激活相关信息的 conv 和 bn 的字典。如果没有,函数计算它
- Perform_only_empirical_bias_corr ( bool) – 默认 True。如果为 true,则仅对所有层执行经验偏差校正,无论该层是否符合分析偏差校正的条件。
- Layers_to_ignore ( Optional[ List[ Module]]) – 我们需要跳过偏差校正的层名称列表。
ConvBn信息类型
类 aimet_common.bias_correction.ConvBnInfoType(input_bn=None,output_bn=None,in_activation_type=<ActivationType.no_activation: 0>,out_activation_type=<ActivationType.no_activation: 0> )
使用 bn info 和激活类型保存转换的类型 支持的激活类型有 Relu 和 Relu6
参数
- input_bn – 参考层的输入 BatchNorm
- output_bn – 参考输出 BatchNorm 到层
- in_activation_type ( ActivationType) – 激活类型
- out_activation_type ( ActivationType) – 激活类型
激活类型
类 aimet_common.defs.ActivationType
用于识别激活类型的枚举
- no_activation= 0
没有激活 - relu= 1
ReLU激活 - relu6= 2
ReLU6 激活
量化参数
类 aimet_torch.quantsim.QuantParams(weight_bw=8,act_bw=8,round_mode=‘nearest’,quant_scheme=<QuantScheme.post_training_tf_enhanced:2>,config_file=None )
保存量化相关参数的数据类型。
构造函数
参数
- weight_bw ( int) – 用于量化层权重的权重位宽 (4-31)。默认 = 8
- act_bw ( int) – 用于量化层激活的激活位宽 (4-31)。默认 = 8
- round_mode ( str) – 舍入模式。支持的选项是“最近”或“随机”
- quant_scheme ( Union[ QuantScheme, str]) – 量化方案。支持的选项为“tf_enhanced”或“tf”或使用量化方案枚举 QuantScheme.post_training_tf 或 QuantScheme.post_training_tf_enhanced
- config_file ( Optional[ str]) – 模型量化器的配置文件路径
代码示例 #1 经验偏差校正
加载模型
model = MobileNetV2()
model.eval()
应用经验偏差校正
from aimet_torch import bias_correction
from aimet_torch.quantsim import QuantParams
params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest", quant_scheme='tf_enhanced')
# User action required
# The following line of code is an example of how to use the ImageNet data's validation data loader.
# Replace the following line with your own dataset's validation data loader.
data_loader = ImageNetDataPipeline.get_val_dataloader()
# Perform Empirical Bias Correction
bias_correction.correct_bias(model.to(device="cuda"), params, num_quant_samples=1000,
data_loader=data_loader, num_bias_correct_samples=512)
代码示例 #2 分析 + 经验偏差校正
加载模型
model = MobileNetV2()
model.eval()
查找 BN 和 Conv 模块
查找用于分析偏差校正的 BN + Conv 模块对和用于经验偏差校正的剩余 Conv 模块。
module_prop_dict = bias_correction.find_all_conv_bn_with_activation(model, input_shape=(1, 3, 224, 224))
应用分析+经验偏差校正
from aimet_torch import bias_correction
from aimet_torch.quantsim import QuantParams
params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest", quant_scheme='tf_enhanced')
# User action required
# The following line of code is an example of how to use the ImageNet data's validation data loader.
# Replace the following line with your own dataset's validation data loader.
data_loader = ImageNetDataPipeline.get_val_dataloader()
# Perform Bias Correction
bias_correction.correct_bias(model.to(device="cuda"), params, num_quant_samples=1000,
data_loader=data_loader, num_bias_correct_samples=512,
conv_bn_dict=module_prop_dict, perform_only_empirical_bias_corr=False)