Prepare in Detail

Definition of Prepare

Prepare is the process of converting a floating-point model into a pseudo-quantized model. This process involves several key steps:

  1. Operator Replacement: Some torch function operators (such as F.interpolate) need FakeQuantize nodes inserted during quantization. These operators are therefore replaced with corresponding Module implementations (e.g., horizon_plugin_pytorch.nn.Interpolate) so that the FakeQuantize nodes can be placed inside the Module. The model is equivalent before and after replacement.

  2. Operator Fusion: The BPU supports fusing specific computational patterns, in which the intermediate results of the fused operators are kept at high precision. The operators to be fused are therefore replaced with a single Module so that the intermediate results are not quantized. The model is likewise equivalent before and after fusion.

  3. Operator Conversion: Floating-point operators are replaced with QAT (Quantization-Aware Training) operators. According to the configured qconfig, QAT operators will add FakeQuantize nodes at the input/output/weights.

  4. Model Structure Check: The QAT model is checked, and a check result file is generated.
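The FakeQuantize nodes mentioned in steps 1 and 3 simulate quantization during training: a value is divided by a scale, rounded onto the integer grid, clamped to the representable range, and mapped back to floating point. A rough self-contained sketch of per-tensor symmetric int8 fake quantization (an illustration of the idea, not the plugin's implementation):

```python
def fake_quantize(x, scale, quant_min=-128, quant_max=127):
    """Simulate int8 quantization: round x onto the quantized grid,
    clamp to the representable range, then map back to float."""
    q = round(x / scale)
    q = max(quant_min, min(quant_max, q))
    return q * scale

# With scale = 0.1 the representable range is [-12.8, 12.7]:
print(fake_quantize(0.53, 0.1))   # snaps onto the 0.1 grid
print(fake_quantize(100.0, 0.1))  # clamps to the top of the range
```

In practice the scale is not fixed by hand: observers attached to each FakeQuantize node estimate it from the value ranges seen during calibration or training.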

Attention

For historical reasons, the Plugin also contains two earlier interfaces, prepare_qat and prepare_qat_fx. These will be gradually deprecated; we recommend using only the prepare interface described in this document.

The usage of the prepare interface is as follows:

from horizon_plugin_pytorch.quantization.prepare import prepare, PrepareMethod
from horizon_plugin_pytorch.quantization.qconfig_template import (
    default_qat_qconfig_setter,
    sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
)

# When using templates, example_inputs and qconfig_setter must be provided.
# When method is PrepareMethod.JIT_STRIP or PrepareMethod.JIT, example_inputs must be provided.
# def prepare(
#     model: torch.nn.Module,
#     example_inputs: Any = None,  # used to get the model's graph structure; must be usable to run forward
#     qconfig_setter: Optional[Union[Tuple[QconfigSetterBase, ...], QconfigSetterBase]] = None,  # qconfig template; supports multiple templates, in descending priority
#     method: PrepareMethod = PrepareMethod.JIT_STRIP,  # prepare method
# ) -> torch.nn.Module:

qat_model = prepare(
    float_model,
    example_inputs=example_inputs,
    qconfig_setter=(
        sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2),
        default_qat_qconfig_setter,
    ),
    method=PrepareMethod.JIT,
)

PrepareMethod

There are four prepare methods, compared as follows:

| Method | Principle | Advantages | Disadvantages |
| --- | --- | --- | --- |
| PrepareMethod.JIT & PrepareMethod.JIT_STRIP | Uses hooks and tensor subclassing to capture the graph structure; operator replacement/fusion is performed on the original forward. | Fully automatic, minimal code modification, hides many detail issues, easy to debug. | Dynamic code blocks need special handling. |
| PrepareMethod.SYMBOLIC | Uses symbolic trace to capture the graph structure; operator replacement/fusion is performed on a recompiled new forward. | Fully automatic, hides many detail issues. | Does not support dynamic control flow or some data types and Python operations; less convenient for debugging. |
| PrepareMethod.EAGER | Does not capture the graph structure; operator replacement/fusion must be done manually. | Flexible, controllable process, easy to debug and to handle special needs. | Requires more manual operations and code modification; high learning cost. |

Currently, JIT and JIT_STRIP are the recommended methods. The difference between them is that JIT_STRIP identifies and skips pre-processing and post-processing based on the positions of QuantStub and DeQuantStub in the model. If the model contains pre-processing or post-processing steps that should not be quantized, use JIT_STRIP; otherwise those steps will be quantized as part of the model. Apart from this difference, the two methods are identical. SYMBOLIC and EAGER are earlier solutions with many usability issues, and we do not recommend them.

Example

import copy

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.quantization import DeQuantStub, QuantStub

from horizon_plugin_pytorch import March, set_march
from horizon_plugin_pytorch.fx.jit_scheme import Tracer
from horizon_plugin_pytorch.quantization import (
    FakeQuantState,
    get_qconfig,
    PrepareMethod,
    prepare,
    set_fake_quantize,
)


class Net(torch.nn.Module):
    def __init__(self, input_size, class_num) -> None:
        super().__init__()
        self.quant0 = QuantStub()
        self.quant1 = QuantStub()
        self.dequant = DeQuantStub()
        self.conv = nn.Conv2d(3, 3, 1)
        self.bn = nn.BatchNorm2d(3)
        self.classifier = nn.Conv2d(3, class_num, input_size)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input, other, target=None):
        # Pre-processing that does not need quantization. Use JIT_STRIP to
        # exclude these operations from the computational graph.
        input = (input - 128) / 128.0
        x = self.quant0(input)
        y = self.quant1(other)

        n = np.random.randint(1, 5)
        # Since the Python code is not regenerated, this dynamic loop is
        # retained in the QAT model.
        for _ in range(n):
            # Dynamic code blocks involving operator replacement or fusion
            # must be marked with Tracer.dynamic_block.
            with Tracer.dynamic_block(self, "ConvBnAdd"):
                x = self.conv(x)
                x = self.bn(x)
                x = x + y

        x = self.classifier(x).squeeze()

        # Since the Python code is not regenerated, this dynamic control
        # flow is retained in the QAT model.
        if self.training:
            assert target is not None
            x = self.dequant(x)
            return F.cross_entropy(torch.softmax(x, dim=1), target)
        else:
            return torch.argmax(x, dim=1)


model = Net(6, 2)
train_example_input = (
    torch.rand(2, 3, 6, 6) * 256,
    torch.rand(2, 3, 6, 6),
    torch.tensor([[0.0, 1.0], [1.0, 0.0]]),
)
eval_example_input = train_example_input[:2]

model.eval()
set_march(March.NASH_E)
model.qconfig = get_qconfig()
qat_model = prepare(
    model,
    example_inputs=copy.deepcopy(eval_example_input),
    method=PrepareMethod.JIT_STRIP,
)

qat_model.graph.print_tabular()
# opcode         name              target                                                     args                              kwargs
# -------------  ----------------  ---------------------------------------------------------  --------------------------------  ----------
# placeholder    input_0           input_0                                                    ()                                {}
# call_module    quant0            quant0                                                     (input_0,)                        {}
# placeholder    input_1           input_1                                                    ()                                {}
# call_module    quant1            quant1                                                     (input_1,)                        {}
# call_module    conv              conv                                                       (quant0,)                         {}
# call_module    bn                bn                                                         (conv,)                           {}
# get_attr       _generated_add_0  _generated_add_0                                           ()                                {}
# call_method    add_2             add                                                        (_generated_add_0, bn, quant1)    {}
# scope_end is automatically inserted during the trace process to mark the boundaries of
# sub-modules or dynamic code blocks; it does not correspond to any computation.
# call_function  scope_end         <function Tracer.scope_end at 0x7f65d90e5e50>              ('_dynamic_block_ConvBnAdd',)     {}
# call_module    conv_1            conv                                                       (add_2,)                          {}
# call_module    bn_1              bn                                                         (conv_1,)                         {}
# get_attr       _generated_add_1  _generated_add_0                                           ()                                {}
# call_method    add_3             add                                                        (_generated_add_1, bn_1, quant1)  {}
# call_function  scope_end_1       <function Tracer.scope_end at 0x7f65d90e5e50>              ('_dynamic_block_ConvBnAdd',)     {}
# call_module    classifier        classifier                                                 (add_3,)                          {}
# call_function  squeeze           <method 'squeeze' of 'torch._C._TensorBase' objects>       (classifier,)                     {}
# call_function  argmax            <built-in method argmax of type object at 0x7f66f04cf820>  (squeeze,)                        {'dim': 1}
# call_function  scope_end_2       <function Tracer.scope_end at 0x7f65d90e5e50>              ('',)                             {}
# output         output            output                                                     ((argmax,),)                      {}

print(qat_model)
# GraphModuleImpl(
#   (quant0): QuantStub(
#     (activation_post_process): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#   )
#   (quant1): QuantStub(
#     (activation_post_process): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#   )
#   (dequant): DeQuantStub()
#   (conv): Identity()  # Since the forward code remains unchanged, conv and bn are still executed, so after fusion Conv and Bn must be replaced with Identity
#   (bn): Identity()
#   (classifier): Conv2d(
#     3, 2, kernel_size=(6, 6), stride=(1, 1)
#     (activation_post_process): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#     (weight_fake_quant): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1.]), zero_point=tensor([0, 0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#   )
#   (loss): CrossEntropyLoss()
#   (_generated_add_0): ConvAdd2d(  # '+' is automatically replaced with its Module form, and Conv and Bn are fused into it
#     3, 3, kernel_size=(1, 1), stride=(1, 1)
#     (activation_post_process): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#     (weight_fake_quant): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1., 1.]), zero_point=tensor([0, 0, 0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#   )
# )

qat_model.train()
set_fake_quantize(qat_model, FakeQuantState.QAT)
for _ in range(3):
    ret = qat_model(*train_example_input)
    ret.backward()
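The Conv+Bn fusion seen in the example (conv and bn replaced by Identity, with their computation absorbed into the generated ConvAdd2d) relies on the standard arithmetic of folding BatchNorm statistics into the convolution weight and bias. A minimal pure-Python sketch of that folding for a single channel (an illustration of the general technique, not the plugin's code):

```python
import math

def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
    into a single conv with an adjusted weight and bias."""
    s = gamma / math.sqrt(var + eps)
    return w * s, (b - mean) * s + beta

# Single-channel "1x1 conv": conv(x) = w * x + b
w, b = 2.0, 1.0
gamma, beta, mean, var = 1.5, 0.25, 0.5, 4.0
wf, bf = fold_bn_into_conv(w, b, gamma, beta, mean, var)

x = 3.0
conv_then_bn = gamma * ((w * x + b) - mean) / math.sqrt(var + 1e-5) + beta
fused = wf * x + bf
assert abs(conv_then_bn - fused) < 1e-9
```

Because folding happens before FakeQuantize nodes are attached, only the folded weight and the final output are fake-quantized; the intermediate conv output never is, which is exactly the point of fusing before conversion.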
Attention
  1. Dynamic code blocks that involve operator replacement or fusion must be marked with Tracer.dynamic_block; otherwise quantization information will be confused or the forward will fail.
  2. If a part of the model whose call count can change at runtime (a sub-module or dynamic block) is executed only once during tracing, it may be fused with the surrounding non-dynamic parts, leading to forward errors.
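The second point can be seen with a toy "recorder" that, like a trace, captures the op sequence of a single forward (a deliberately simplified illustration of the failure mode, not the plugin's tracer):

```python
def forward(n):
    """A forward whose op sequence depends on a runtime value n."""
    ops = []
    for _ in range(n):
        ops += ["conv", "bn", "add"]  # the dynamic block
    ops.append("classifier")
    return ops

# Op sequence recorded while tracing with n = 1 ...
traced_ops = forward(1)
# ... no longer matches what actually executes when n = 2 at runtime:
runtime_ops = forward(2)
assert traced_ops != runtime_ops
```

If the dynamic block happens to run only once at trace time, nothing distinguishes it from the surrounding static ops, so it may be fused across the block boundary; marking it with Tracer.dynamic_block keeps the boundary explicit.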

Model Check

When example_inputs is provided, prepare performs a model structure check by default. If the check completes, a model_check_result.txt file is generated in the working directory. If the check fails, modify the model according to the warning messages, or call horizon_plugin_pytorch.utils.check_model.check_qat_model separately to check the model. The check process is the same as that of check_qat_model in the debug tool; see the check_qat_model documentation for how to interpret the result file.