[Keras Model Quantization] Quantization Aware Training (tfmot)

Introduction to Quantization Aware Training

Quantization aware training is implemented mainly through tfmot.quantization.keras.quantize_model.

Compared with post-training quantization (see the earlier post: [Keras Model Quantization] Post-Training Quantization (TFLite), https://blog.csdn.net/u010637291/article/details/108649829), it is not as easy to use, but quantization aware training usually yields better model accuracy.

My understanding: rather than quantizing the trained parameters directly, it rewrites the model structure for quantization after the model is built. The model therefore does not have to be trained first; the structural quantization can be applied as soon as the Keras model is created.
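A minimal sketch of this point (the layer sizes below are arbitrary, purely for illustration): quantize_model can be applied to a freshly built, untrained model, and the summary then shows the layers wrapped for quantization instead of the original layers.

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Build an (untrained) Keras model and quantize its structure right away.
base_model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(1),
])
q_model = tfmot.quantization.keras.quantize_model(base_model)

# The summary now lists quantize/QuantizeWrapper layers, i.e. the model
# structure (not just the parameters) has been rewritten for quantization.
q_model.summary()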

Reference

[1] Quantization aware training

Overview: https://tensorflow.google.cn/model_optimization/guide/quantization/training

Example: https://tensorflow.google.cn/model_optimization/guide/quantization/training_example

Quantizing custom Keras layers (comprehensive guide): https://tensorflow.google.cn/model_optimization/guide/quantization/training_comprehensive_guide

[2] DeepXi

GitHub: https://github.com/anicolson/DeepXi

API Compatibility

Supported models: tf.keras with only Sequential and Functional models.
TensorFlow version: TF 2.x nightly builds. tf.compat.v1 with a TF 2.X package is not supported.
TensorFlow execution mode: eager execution.

Installation commands:

pip uninstall -y tensorflow
pip install -q tf-nightly
pip install -q tensorflow-model-optimization
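After installing, a quick check (the exact version string will differ on your machine) can confirm that the environment matches the compatibility requirements above:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

print(tf.__version__)          # should be a TF 2.x (nightly) build
print(tf.executing_eagerly())  # should be True: eager execution is required
print(tfmot.__version__)       # confirms tensorflow-model-optimization is installed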

Usage Example

Suppose we already have a Keras Sequential model (the MNIST classifier below):

import tensorflow as tf
from tensorflow import keras

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)

Quantization is then applied as follows:

import tensorflow_model_optimization as tfmot
#quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for quantization aware.
q_aware_model = tfmot.quantization.keras.quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])

q_aware_model.summary()
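Following the official example in [1], the quantization aware model is then fine-tuned briefly and can be converted into an actually quantized TFLite model. A sketch of those remaining steps, continuing from the code above (the batch size and output file name are just placeholders):

# Fine-tune the quantization aware model (the official example uses only a
# small subset of the training data for a single epoch).
q_aware_model.fit(train_images, train_labels,
                  batch_size=500, epochs=1, validation_split=0.1)

_, q_aware_acc = q_aware_model.evaluate(test_images, test_labels, verbose=0)
print('Quantization aware test accuracy:', q_aware_acc)

# Convert to a quantized TFLite model.
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()

with open('q_aware_model.tflite', 'wb') as f:
    f.write(quantized_tflite_model)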

Problems Encountered in Practice

quantize_model does not support custom layers, and I have not found a straightforward solution. For example, the speech enhancement model DeepXi [2] uses a custom-built Attention layer.
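For reference, the comprehensive guide linked in [1] describes an annotation-based workflow (quantize_annotate_layer plus a QuantizeConfig) that at least lets a model containing custom layers pass through quantize_apply; how meaningfully the custom layer itself gets quantized depends on the layer. A minimal sketch, with CustomAttention as a hypothetical stand-in for a layer like DeepXi's Attention and a pass-through config that leaves it un-quantized:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

class CustomAttention(tf.keras.layers.Layer):
    # Hypothetical stand-in for a real custom layer (e.g. DeepXi's Attention).
    def call(self, inputs):
        return inputs

class NoOpQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    # Pass-through config: quantize nothing inside the custom layer.
    def get_weights_and_quantizers(self, layer):
        return []
    def get_activations_and_quantizers(self, layer):
        return []
    def set_quantize_weights(self, layer, quantize_weights):
        pass
    def set_quantize_activations(self, layer, quantize_activations):
        pass
    def get_output_quantizers(self, layer):
        return []
    def get_config(self):
        return {}

# Annotate the model: built-in layers as usual, the custom layer with the
# pass-through config so quantize_apply does not reject it.
annotated_model = tfmot.quantization.keras.quantize_annotate_model(
    tf.keras.Sequential([
        tf.keras.layers.Dense(32, input_shape=(16,)),
        tfmot.quantization.keras.quantize_annotate_layer(
            CustomAttention(), NoOpQuantizeConfig()),
    ]))

# quantize_scope makes the custom classes visible during the rewrite.
with tfmot.quantization.keras.quantize_scope(
        {'CustomAttention': CustomAttention,
         'NoOpQuantizeConfig': NoOpQuantizeConfig}):
    q_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)

q_aware_model.summary()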


Reposted from blog.csdn.net/u010637291/article/details/110276197