QConfig 详解

QConfig 的定义

模型的量化方式由 qconfig 决定,在准备 qat / calibration 模型之前,需要先给模型设置 qconfig。

注意

因历史原因,Plugin 中有不同 qconfig 的定义和用法,早期版本的 qconfig 将在不久的将来被废弃,我们只推荐您使用此文档中介绍的 qconfig 用法。

一个 qconfig 对象可以设置 input / weight / output 三个关键字,分别表示算子输入/权重/输出的量化配置,prepare 模型时会根据这些配置决定是否要在对应位置插入 FakeQuantize / FakeCast 节点,None 表示不插入任何节点。

import torch from horizon_plugin_pytorch.quantization.qconfig import QConfig from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize from horizon_plugin_pytorch.quantization.fake_cast import FakeCast from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver from horizon_plugin_pytorch.dtype import qint8 qconfig = QConfig( input=None, weight=FakeQuantize.with_args( observer=MinMaxObserver, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, ), output=FakeCast.with_args(dtype=torch.float16), # activation=xxx 早期用法,作用与 output 关键字一致,当前仍兼容,但建议您使用 output 关键字。 )

FakeQuantize 的定义

FakeQuantize 是伪量化节点,会对输入进行量化反量化操作,插入伪量化可以在浮点模型的前向中模拟量化产生的误差。horizon_plugin_pytorch 支持 FakeQuantize / PACTFakeQuantize / _LearnableFakeQuantize 三种伪量化,我们只推荐您使用基于统计的 FakeQuantize,可以满足绝大部分需求。标准流程不对 PACTFakeQuantize 和 _LearnableFakeQuantize 两种方法做详细说明,如果一定有需求,请在阅读相关论文后再使用。

# 基于统计的方法 from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize # https://arxiv.org/pdf/1805.06085 from horizon_plugin_pytorch.quantization.pact_fake_quantize import PACTFakeQuantize # https://arxiv.org/pdf/1902.08153 from horizon_plugin_pytorch.quantization._learnable_fake_quantize import _LearnableFakeQuantize

可以调用 FakeQuantize 的 with_args 方法得到构造器,并按上一节的代码示例用它构造 qconfig。with_args 的参数包括 FakeQuantize 和 observer 支持配置的参数,理论上可以配置所有 FakeQuantize 和 observer 类 init 方法声明中的参数,但为了屏蔽无关紧要的细节,我们只推荐您配置 observer 相关参数。

不同 observer 的参数不同,下面列出常用 observer 构造 FakeQuantize 的例子,其他 observer 的具体用法见校准章节。

import torch from horizon_plugin_pytorch.quantization.qconfig import QConfig from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver, FixedScaleObserver, MSEObserver from horizon_plugin_pytorch.dtype import qint8 # MinMaxObserver 的 __init__ 方法包含很多参数,with_args 方法可以控制这些参数。 # 我们只推荐您设置 fq_constructor_1 示例中的几个参数。 # def __init__( # self, # averaging_constant: float = 0.01, # ch_axis: int = -1, # dtype: Union[torch.dtype, QuantDType] = qint8, # qscheme: torch.qscheme = torch.per_tensor_symmetric, # quant_min: int = None, # quant_max: int = None, # is_sync_quantize: bool = False, # factory_kwargs: Dict = None, # ) -> None: fq_constructor_1 = FakeQuantize.with_args( observer=MinMaxObserver, # 适用于 qat 阶段的 input / output / weight 和 calibration 阶段的 weight。 averaging_constant=0.01, # calibration 后进行 qat 时,可将 input / output 的 averaging_constant 置为 0 以固定 scale。 dtype=qint8, # 量化类型, 考虑算子的支持情况进行设置。 qscheme=torch.per_channel_symmetric, # 只有 weight 支持 per channel 量化。 ch_axis=0, # per channel 量化时指定 channel。 ) # 同理,您也可以查看 FixedScaleObserver 和 MSEObserver 的 __init__ 方法了解有哪些可以设置的参数。 fq_constructor_2 = FakeQuantize.with_args( observer=FixedScaleObserver, # 固定 scale,无论何种情况都不会变。 dtype=qint8, # 量化类型, 考虑算子的支持情况进行设置。 scale=INPUT_ABS_MAX / 128, # 设定的 scale 值, 一般设为绝对值最大值除以量化类型最大值。 ) fq_constructor_3 = FakeQuantize.with_args( observer=MSEObserver, # 适用于 calibration 阶段的 input / output。 dtype=qint8, # 量化类型, 考虑算子的支持情况进行设置。 ) qconfig = QConfig( weight=fq_constructor_x, ... )

FakeCast 的定义

FakeCast 是伪转换节点,会将输入转换为 float32 类型,如果数据类型是 float16,那么还会在中间模拟转 float16 产生的截断误差,此节点主要用于标志需要浮点计算的算子。

使用 FakeCast 构造 qconfig 的方法与 FakeQuantize 类似,但只有 dtype 一个参数。

import torch from horizon_plugin_pytorch.quantization.qconfig import QConfig from horizon_plugin_pytorch.quantization.fake_cast import FakeCast qconfig = QConfig( input=FakeCast.with_args(dtype=torch.float16), # 考虑算子的支持情况进行设置。 ... )

构造 QConfig

  1. 按照上文介绍的方法,直接构造 QConfig 对象。这种方法比较灵活,可以配置任何可配置的参数,需要您对 QConfig 有一定的理解。

  2. 使用 get_qconfig 接口。此接口较直接构造 QConfig 对象的方法更简单易用,但不够灵活,高级用法和需求无法使用此接口实现。

import torch from horizon_plugin_pytorch.quantization import get_qconfig from horizon_plugin_pytorch.quantization.observer_v2 import MinMaxObserver from horizon_plugin_pytorch.quantization.qconfig import QConfig from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize from horizon_plugin_pytorch.dtype import qint8 # qconfig_1 / qconfig_2 / qconfig_3 / qconfig_4 等价。 qconfig_1 = QConfig( weight=FakeQuantize.with_args( observer=MinMaxObserver, averaging_constant=0.01, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, ), output=FakeQuantize.with_args( observer=MinMaxObserver, averaging_constant=0, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, ), ) qconfig_2 = QConfig( weight=FakeQuantize.with_args( observer=MinMaxObserver, qscheme=torch.per_channel_symmetric, ch_axis=0, ), output=FakeQuantize.with_args( observer=MinMaxObserver, averaging_constant=0, ), ) qconfig_3 = get_qconfig( observer=MinMaxObserver, # 输入输出 observer 类型, 只支持 horizon_plugin_pytorch.quantization.observer_v2 中的 MinMaxObserver 和 MSEObserver, 默认值为 MinMaxObserver。 in_dtype=None, # 输入数据类型,考虑算子的支持情况进行设置。None 表示 QConfig 的 input 关键字为 None, 默认值为 None。 weight_dtype=qint8, # 权重数据类型,考虑算子的支持情况进行设置。None 表示 QConfig 的 weight 关键字为 None, 默认值为 qint8。 out_dtype=qint8, # 输出数据类型,考虑算子的支持情况进行设置。None 表示 QConfig 的 output 关键字为 None, 默认值为 qint8。 fix_scale=True, # 是否固定输入输出 scale。 ) qconfig_4 = get_qconfig(fix_scale=True)

使用 QConfig

  1. 直接设置 qconfig 属性。此方法优先级最高,其余方法不会覆盖直接设置的 qconfig。
model.qconfig = QConfig(...)
  1. qconfig 模板。在 prepare 接口上指定 qconfig setter 和 example_inputs,自动为模型设置 qconfig。
from horizon_plugin_pytorch.quantization import prepare from horizon_plugin_pytorch.quantization.qconfig_template import ( default_qat_qconfig_setter, ) qat_model = prepare( model, example_inputs=example_inputs, qconfig_setter=default_qat_qconfig_setter, )

QConfig 模板

qconfig 模板基于 subclass trace 方案感知模型的图结构,并按设定的规则自动设置 qconfig,是我们最推荐的设置 qconfig 方法。用法如下:

from horizon_plugin_pytorch.quantization import prepare from horizon_plugin_pytorch.quantization.qconfig_template import ( default_qat_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter ) qat_model = prepare( model, example_inputs=example_inputs, # 用来感知图结构 qconfig_setter=( # qconfig 模板,支持传入多个模板,优先级从高到低。 sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter(table, ratio=0.2), default_qat_qconfig_setter, ), )
注意

模板的优先级低于直接给模型设置 qconfig 属性,如果模型在 prepare 之前已经使用 model.qconfig = xxx 进行了配置,那么模板将不会生效。如果没有特殊需求,我们不推荐将两者混合使用,这很容易引发低级错误。绝大多数情况下,使用模板和 model.qconfig = xxx 两种设置方式中的一种即可满足需求。

模板可分为三类:

  1. 固定模板。固定模板中 calibration / qat / qat_fixed_act_scale 区别在于使用的 observer 类型和 scale 更新逻辑,分别用于校准,qat 训练,固定 activation scale qat 训练。default 模板( default_calibration_qconfig_setter / default_qat_qconfig_setter / default_qat_fixed_act_qconfig_setter )会做三件事:首先,将可以设置的高精度输出都设置上,对于不支持高精度的输出将给出提示;然后,从 grid sample 算子的 grid 输入向前搜索,直到出现第一个 gemm 类算子或者QuantStub,将中间的所有算子都设置为 int16。根据经验这里的 grid 一般表达范围较宽,int8 有较大可能不满足精度需求;最后,将其余算子设置为 int8。int16 模板( qat_8bit_weight_16bit_act_qconfig_setter / qat_8bit_weight_16bit_fixed_act_qconfig_setter / calibration_8bit_weight_16bit_act_qconfig_setter )会做两件事:首先,将可以设置的高精度输出都设置上,对于不支持高精度的输出将给出提示;其次,将其余算子设置为 int16。
from horizon_plugin_pytorch.quantization.qconfig_template import ( default_calibration_qconfig_setter, default_qat_qconfig_setter, default_qat_fixed_act_qconfig_setter, qat_8bit_weight_16bit_act_qconfig_setter, qat_8bit_weight_16bit_fixed_act_qconfig_setter, calibration_8bit_weight_16bit_act_qconfig_setter, )
  1. 敏感度模板。敏感度模板有 sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,三者的区别和固定模板中三者的区别一致,也是分别用于校准,qat 训练,固定 activation scale qat 训练。 敏感度模板的第一个输入是精度 debug 工具产生的敏感度结果,第二个参数可以指定 ratio 或 topk ,敏感度模板会将量化敏感度最高的 topk 个算子设置为 int16。搭配固定模板,可以轻松实现混合精度调优。
from horizon_plugin_pytorch.quantization.qconfig_template import ( default_calibration_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter, sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter, ) table = torch.load("output_0-0_dataindex_1_sensitive_ops.pt") calibration_model = prepare( model, example_inputs=example_input, qconfig_setter=( sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table, ratio=0.2), default_calibration_qconfig_setter, ), )
  1. 自定义模板。自定义模板只有 ModuleNameQconfigSetter,需要传入模块名和对应 qconfig 的字典,一般用于设置 fixed scale 等特殊需求,可以和固定模板,敏感度模板搭配使用。
from horizon_plugin_pytorch.quantization.qconfig_template import ( default_qat_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter, ModuleNameQconfigSetter, ) table = torch.load("output_0-0_dataindex_1_sensitive_ops.pt") module_name_to_qconfig = { "op_1": get_qconfig(), "op_2": QConfig( output=FakeQuantize.with_args( observer=FixedScaleObserver, dtype=qint16, scale=OP2_MAX/QINT16_MAX, ) ), } qat_model = prepare( model, example_inputs=example_input, qconfig_setter=( ModuleNameQconfigSetter(module_name_to_qconfig), sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2), default_qat_qconfig_setter, ), )