Quantization Accuracy Tuning Guide

Quantization accuracy tuning involves two aspects:

  1. Model structure and quantization configuration check. The main purpose is to rule out non-tuning issues that affect quantization accuracy, such as incorrect qconfig settings or quantization-unfriendly shared modules.

  2. Mixed precision tuning. Start with a model using high precision operators to quickly achieve the desired accuracy, establishing the upper bound of accuracy and lower bound of performance. Then, use accuracy tuning tools to analyze and adjust quantization configurations to obtain a model that balances accuracy and performance.

Model Structure and Quantization Configuration Check

After preparing the model, first check for quantization configuration errors and model structures which are not quantization-friendly. You can use the check_qat_model interface in the debug tools. Please refer to Accuracy Tuning Tool Guide for interface usage.

Attention

check_qat_model is integrated into the prepare interface, and you can directly view the model_check_result.txt in the running directory.

Operator Fusion

Check if there are modules in the model that can be fused but are not. When deploying the model on BPU, operators like conv, bn, add, and relu will be fused. In a QAT model, these operators will be replaced with a single module to avoid inserting fake quantization nodes. If these operators are not fused, additional quantization nodes will be inserted, potentially causing slight impacts on accuracy and performance. The following example shows that conv and relu are not fused.

```
Fusable modules are listed below:
name    type
------  -----------------------------------------------------
conv    <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>
relu    <class 'horizon_plugin_pytorch.nn.qat.relu.ReLU'>
```

Possible causes and solutions for operator fusion errors in different prepare methods:

  1. PrepareMethod.JIT and PrepareMethod.JIT_STRIP:

    a. Operator fusion in dynamic code blocks requires marking with dynamic_block.

    b. Code whose call count varies is executed only once during tracing. Use example_inputs that make the code execute multiple times.

  2. PrepareMethod.EAGER: Missing or incorrect fuse operations. You need to check and fix the handwritten fuse logic.

  3. PrepareMethod.SYMBOLIC: Fusable modules are wrapped by fx.wrap. Move them out of the wrap so that these modules exist in the graph, or fuse them manually as in PrepareMethod.EAGER.
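As an illustration of manual fusion for the EAGER case, the sketch below uses stock PyTorch's fuse_modules on a hypothetical conv-bn-relu block; the plugin's own fuse API may differ, so treat this as an analogy rather than the plugin's exact interface.

```python
import torch
from torch import nn
from torch.ao.quantization import fuse_modules

class Block(nn.Module):
    # Hypothetical conv-bn-relu pattern, which is fusable on BPU.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = Block().eval()
# After fusion, conv/bn/relu become one module, so only a single
# fake-quantization node is inserted for the whole pattern.
fused = fuse_modules(m, [["conv", "bn", "relu"]])
```

After fusion, `fused.bn` and `fused.relu` are replaced by `nn.Identity`, and the fused module produces the same output as the original block.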

Shared Modules

Since the plugin inserts quantization nodes by module replacement, only one set of quantization parameters can be collected per module. When a module is called multiple times and the output data distributions differ significantly across calls, sharing one set of quantization parameters causes large errors, so the shared module should be copied. If the output distributions of the calls are similar, copying is unnecessary. Below we explain the concept of shared modules, which helps decide whether to copy a shared module during layer-by-layer comparison.

Three common understandings of "shared modules" and how they differ:

A. A module whose output feeds multiple downstream modules. Module A may look shared, but it is called only once and its output distribution does not vary, so it does not affect quantization accuracy and will not show up in the call count check.

B. A module is called multiple times, but the output data distributions are similar. Although this shows up in the call count check, it has little impact on quantization accuracy and needs no modification.

C. A module is called multiple times, with significantly different output distributions each time. This will be reflected in the call count check and significantly impacts quantization accuracy, requiring manual splitting.

In model_check_result.txt, you can see the call count of each module. Normally, each op is called only once. Zero means it is not called, and more than one indicates multiple calls. In the following example, conv is a shared module.

```
Each module called times:
name     called times
-------  --------------
conv     2
quant    1
dequant  1

# Corresponding code
# def forward(self, x):
#     x = self.quant(x)
#     x = self.conv(x)
#     x = self.conv(x)
#     x = self.dequant(x)
#     return x
```
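Splitting a shared module (case C above) can be sketched as giving each call site its own copy, so that each copy collects its own quantization parameters after prepare. The Net module below is a hypothetical minimal example, not the plugin's API:

```python
import copy
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        # Split the shared module: the copy keeps the same weights but
        # will collect its own quantization statistics after prepare.
        self.conv2 = copy.deepcopy(self.conv)

    def forward(self, x):
        x = self.conv(x)
        x = self.conv2(x)  # was: x = self.conv(x)
        return x
```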

QConfig Configuration Errors

Using qconfig incorrectly can cause the model to quantize in unexpected ways, resulting in low accuracy (e.g. mixing template and qconfig attribute settings). Here, we mainly check if the input and output of each operator meet expectations by looking at model_check_result.txt:

  1. Whether the dtype matches the settings.

  2. Whether high precision output is enabled.

```
Each layer out qconfig:
+---------------+-----------------------------------------------------------+---------------+---------------+----------------+-----------------------------+
| Module Name   | Module Type                                               | Input dtype   | out dtype     | ch_axis        | observer                    |
|---------------+-----------------------------------------------------------+---------------+---------------+----------------+-----------------------------|
| quant         | <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'>   | torch.float32 | qint8         | -1             | MovingAverageMinMaxObserver |
| conv          | <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>     | qint8         | qint8         | -1             | MovingAverageMinMaxObserver |
| relu          | <class 'horizon_plugin_pytorch.nn.qat.relu.ReLU'>         | qint8         | qint8         | qconfig = None |                             |
| dequant       | <class 'horizon_plugin_pytorch.nn.qat.stubs.DeQuantStub'> | qint8         | torch.float32 | qconfig = None |                             |
+---------------+-----------------------------------------------------------+---------------+---------------+----------------+-----------------------------+

# This check result shows that all modules are int8 quantized. If you
# configured int16, it means the configuration did not take effect, and
# you need to check the usage of qconfig.

Weight qconfig:
+---------------+-------------------------------------------------------+----------------+-----------+---------------------------------------+
| Module Name   | Module Type                                           | weight dtype   | ch_axis   | observer                              |
|---------------+-------------------------------------------------------+----------------+-----------+---------------------------------------|
| conv          | <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'> | qint8          | 0         | MovingAveragePerChannelMinMaxObserver |
+---------------+-------------------------------------------------------+----------------+-----------+---------------------------------------+
```

Additionally, model_check_result.txt flags abnormal qconfigs. These are configurations the tool singles out for you to double-check; verify that they meet your expectations.

  1. Weight int16. J6M does not support the input and weight both being int16. If an int16 weight is found, check whether the input is also int16.

  2. Fixed scale. Check if the fixed scale settings are as expected.

```
Please check if these OPs qconfigs are expected..
+-----------------+----------------------------------------------------------------------------+------------------------------+
| Module Name     | Module Type                                                                | Msg                          |
|-----------------+----------------------------------------------------------------------------+------------------------------|
| convmod1.add    | <class 'horizon_plugin_pytorch.nn.qat.conv2d.ConvAddReLU2d'>               | qint16 weight!!!             |
| convmod2.conv2d | <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>                      | qint16 weight!!!             |
| convmod3.add    | <class 'horizon_plugin_pytorch.nn.qat.conv2d.ConvAddReLU2d'>               | qint16 weight!!!             |
| shared_conv     | <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>                      | qint16 weight!!!             |
| shared_conv(1)  | <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>                      | qint16 weight!!!             |
| sub[sub]        | <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'> | Fixed scale 3.0517578125e-05 |
+-----------------+----------------------------------------------------------------------------+------------------------------+
```

Mixed Precision Tuning

Tuning Pipeline

The pipeline starts with all-int16 accuracy tuning to confirm the model's accuracy upper bound and to catch tool usage issues or quantization-unfriendly modules.

  1. After confirming that all-int16 accuracy meets the requirements, perform all-int8 accuracy tuning. If the accuracy is not good enough, perform int8/int16 mixed precision tuning. Starting with an all-int8 model, gradually increase the proportion of int16 operators, balancing accuracy and performance.

  2. If all-int16 accuracy does not meet the requirements, perform int16/fp16 mixed precision tuning. Ideally, int16/fp16 mixed precision tuning can solve all accuracy problems. Based on this, perform int8/int16/fp16 mixed precision tuning, fixing all fp16 operator configurations, and adjust the proportion of int16 operators as described in step 1.

Basic Tuning Methods

The goal of basic tuning methods is to get a quantization model quickly.

Calibration

  1. Adjust calibration steps. More calibration data is generally better, but due to diminishing returns, beyond a certain data volume the accuracy improvement becomes very limited. If the training set is small, use all of it for calibration; if it is large, select a subset to balance time and accuracy. Running on the order of 10 to 100 calibration steps is recommended.

  2. Adjust batch size. Generally, a larger batch size is preferable. However, if the data is noisy or the model has many outliers, it may be necessary to reduce the batch size.

  3. Use inference-stage preprocessing with training data for calibration. Calibration data should reflect the real distribution; augmentations such as flipping are fine, but avoid augmentations that destroy the real distribution, such as rotation and mosaic.

QAT

  1. Adjust learning rate.

    a. Initial Learning Rate: Disable warmup and decay strategies. Use different fixed learning rates (e.g., 1e-3, 1e-4, etc.) to fine-tune for a few steps and select the learning rate with the best accuracy. If the floating-point model uses different learning rates for different modules, replicate this in QAT.

    b. Scheduler: Align the learning rate decay strategy with the floating-point model's, but ensure no warmup strategy is used.

  2. Experiment with fixing or updating input/output scales. When the calibration model's accuracy is good, fixing the input/output scales during QAT training can achieve better results; otherwise, do not fix the scales. Both strategies need to be tested, as there is no definitive metric to guide the choice.

  3. Training steps. Generally, QAT training steps should not exceed 20% of the floating-point training steps. Adjust training steps based on training loss and evaluation results.

Special considerations:

  1. Apart from the specific adjustments mentioned above, ensure all other QAT training configurations align with the floating-point training.

  2. If the floating-point training uses the freeze-BN trick, set the QAT mode to WithBN.

```python
from horizon_plugin_pytorch.qat_mode import QATMode, set_qat_mode

set_qat_mode(QATMode.WithBN)
```
  3. If accuracy issues arise despite tuning, such as non-converging loss or NaNs, first fine-tune the floating-point model with the same configuration to rule out issues caused by misalignment or overfitting.

Advanced Tuning Methods

Advanced tuning methods generally require more time and experiments. They are used when higher accuracy is required.

Setting Fixed Scale

Some parts of the model may require fixed scales due to difficulty in determining the optimal quantization scale through statistics. Common situations requiring fixed scales include operators with known output ranges.

For example: If input data represents speed in km/h with a range of [0, 200], the quantstub output range is fixed, and the scale should be set to 200 divided by the quantization range. Statistical methods might underestimate this range due to outliers being averaged out, causing error for samples exceeding this range.

Consider a case where the value ranges of inputs a and b are known and the value range of input c is uncertain: except for quantstub_c and the last add, all other operators need a fixed scale.
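For the speed example above, the fixed scale is the known maximum divided by the quantization range. A minimal sketch, assuming symmetric quantization with 2^(bits-1) positive levels (the same convention as the 0.0078 × 128 arithmetic used later in this guide); the helper name is illustrative, not plugin API:

```python
def fixed_scale(max_abs_value: float, bits: int = 8) -> float:
    """Scale so that max_abs_value maps to the top of the quantized range."""
    quant_range = 2 ** (bits - 1)  # 128 for int8, 32768 for int16
    return max_abs_value / quant_range

# Speed input in [0, 200] km/h
scale_int8 = fixed_scale(200.0, bits=8)    # 1.5625
scale_int16 = fixed_scale(200.0, bits=16)  # 0.006103515625
```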

Calibration

Experiment with different calibration methods. The plugin supports various methods, including min max, percentile, kl, mse, and mix. For tuning experiences, refer to the calibration guide.

QAT

  1. Adjust weight decay. Weight decay affects the range of weight values, and a smaller range is more quantization-friendly. Adjust weight decay during both the floating-point and QAT phases if necessary.

  2. Adjust data augmentation. Quantized models have lower learning capabilities compared to floating-point models. Strong data augmentation may hinder QAT model convergence, so it's generally advisable to moderate the augmentation.

INT8/INT16 & INT16/FP16 Mixed Precision Tuning

The basic idea of mixed precision tuning is to incrementally introduce higher precision operators into a lower precision model until the desired accuracy is achieved. For int8/int16 tuning, start with an all-int8 model and incrementally increase the number of int16 operators. For int16/FP16 tuning, start with an all-int16 model and incrementally increase the number of FP16 operators.

The calibration and QAT tuning in this pipeline can follow the basic and advanced tuning methods chapters. Increasing the number of high-precision int16/fp16 operators relies on a series of debug results produced by the accuracy debugging tools.

QAT accuracy debugging is entirely based on the comparison between the floating-point model and the calibration model. Generally, it is not recommended to compare the floating-point model with the QAT model directly, as they become incomparable after QAT training. Please refer to Accuracy Tuning Tool Guide for background information. First, provide a dataset to find bad cases with significant calibration error, then compare layer-by-layer and calculate quantization sensitivity based on these bad cases.

Find Badcase

All accuracy debugging operations are based on bad cases. You need to provide a sample set with low quantization accuracy. This process will traverse the sample set and compare the output error of each sample in the floating-point model and the calibration model. Generally, you do not need to provide an error measurement function.

Attention

When starting to find bad cases, the model needs to include some post-processing logic (remove or replace the parts of the original post-processing that make the model outputs completely incomparable). For example:

  1. Sigmoid should not be removed. Values in the sigmoid saturation domain are not sensitive to errors, but values near 0 are highly sensitive. Removing sigmoid will prevent accurate reflection of quantization errors within different domains.

  2. NMS should be removed. Small errors can lead to completely different NMS results, making the output incomparable.

The debug tool supports automatic replacement of sort/topk/argmax. Besides these operators, you need to check if there are similar operators in the model and post-processing, and delete all parts after these operators.
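The two rules above can be sketched as a debug-only wrapper: keep sigmoid, drop everything from NMS onward. The TinyDetector and DebugWrapper modules below are hypothetical stand-ins for your model and post-processing, not the debug tool's API:

```python
import torch
from torch import nn

class TinyDetector(nn.Module):
    # Hypothetical detector head returning raw scores and box predictions.
    def __init__(self):
        super().__init__()
        self.cls_head = nn.Conv2d(8, 2, 1)
        self.box_head = nn.Conv2d(8, 4, 1)

    def forward(self, x):
        return self.cls_head(x), self.box_head(x)

class DebugWrapper(nn.Module):
    # Debug-only post-processing: keep sigmoid (point 1), drop NMS (point 2).
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        cls_score, bbox_pred = self.model(x)
        return torch.sigmoid(cls_score), bbox_pred

debug_model = DebugWrapper(TinyDetector())
```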

Use the auto_find_bad_case interface in the debugging tool to find bad cases.

```python
from horizon_plugin_profiler.model_profilerv2 import QuantAnalysis

# 1. Initialize the quantization analysis object
qa = QuantAnalysis(float_net, calibration_model, "fake_quant")

# 2. Find bad cases. If there is too much data, specify num_steps to
#    search a subset of the data
qa.auto_find_bad_case(dataloader)
```

After finding the bad cases, check the result file. For each output of the model, the debug tool will find the worst sample under each error metric. In the example below, the model has three outputs. The first table shows the worst sample index for each output under each metric, the second table shows the corresponding errors, and the third table shows the worst bad case index under the current metric for all outputs.

```
The bad case input index of each output:
Name/Index    Cosine    MSE    L1    KL    SQNR
------------  --------  -----  ----  ----  ------
0-0           4         4      1     4     1
0-1           14        1      1     1     1
0-2           12        11     9     12    9

The metric results of each badcase:
Name/Index    Cosine    MSE         L1         KL          SQNR
------------  --------  ----------  ---------  ----------  ---------
0-0           0.969289  2.35591     0.996721   0.0828096   11.9335
0-1           0.974127  0.00964048  0.0404785  0.00104532  12.6742
0-2           0.450689  4.76039     1.08771    0.320661    -0.821809

The bad case input index of the worst output:
metric    dataloader index
--------  ------------------
Cosine    11
MSE       17
L1        17
KL        12
SQNR      0
```
Attention

Only debug the outputs related to the accuracy drop. Different model outputs require different error metrics. Typically, L1 and Cosine are suitable for most issues: L1 suits tasks that require reflecting absolute error, such as bbox regression, while Cosine suits tasks that require reflecting overall distribution error, such as classification.
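The two recommended metrics reduce to plain tensor computations. A sketch for intuition; the function names are illustrative, not the debug tool's API:

```python
import torch
import torch.nn.functional as F

def l1_error(float_out: torch.Tensor, quant_out: torch.Tensor) -> float:
    # Mean absolute error: reflects absolute error, e.g. for bbox regression.
    return (float_out - quant_out).abs().mean().item()

def cosine(float_out: torch.Tensor, quant_out: torch.Tensor) -> float:
    # Cosine similarity of flattened outputs: reflects overall
    # distribution error, e.g. for classification logits.
    return F.cosine_similarity(
        float_out.flatten(), quant_out.flatten(), dim=0
    ).item()
```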

Layer-by-Layer Comparison

Layer-by-layer comparison uses the bad cases to run the floating-point model and the calibration model separately, comparing the output of each layer. This can be done using the compare_per_layer interface in the debug tool. This method is suitable for detailed accuracy analysis. If the accuracy drop is not significant, you can skip this step and analyze it later with sensitivity results.

```python
from horizon_plugin_profiler.model_profilerv2 import QuantAnalysis

# 1. Initialize the quantization analysis object
qa = QuantAnalysis(float_net, calibration_model, "fake_quant")

# 2. Find bad cases. If there is too much data, specify num_steps to
#    search a subset of the data
qa.auto_find_bad_case(dataloader)

# 3. Run the model with bad cases to collect per-layer information
qa.run()

# 4. Compare the information of each layer
qa.compare_per_layer()
```

The results of the layer-by-layer comparison can be viewed in generated text files. In the text file, you can see where the error starts to magnify from top to bottom.

```
|   | mod_name                             | base_op_type                                                               | analy_op_type                                            | shape                        | quant_dtype | qscale    | Cosine    | MSE        | L1        | KL        | SQNR       | Atol        | Rtol            | base_model_min | analy_model_min | base_model_max | analy_model_max | base_model_mean | analy_model_mean | base_model_var | analy_model_var | max_atol_diff | max_qscale_diff |
|---|--------------------------------------|----------------------------------------------------------------------------|----------------------------------------------------------|------------------------------|-------------|-----------|-----------|------------|-----------|-----------|------------|-------------|-----------------|----------------|-----------------|----------------|-----------------|-----------------|------------------|----------------|-----------------|---------------|-----------------|
| 0 | backbone.quant                       | horizon_plugin_pytorch.quantization.stubs.QuantStub                        | horizon_plugin_pytorch.nn.qat.stubs.QuantStub            | torch.Size([4, 3, 512, 960]) | qint16      | 0.0078125 | 0.9999999 | 0.0000000  | 0.0000000 | 0.0000000 | inf        | 0.0000000   | 0.0000000       | -1.0000000     | -1.0000000      | 0.9843750      | 0.9843750       | -0.1114422      | -0.1114422       | 0.1048462      | 0.1048462       | 0.0000000     | 0.0000000       |
| 1 | backbone.mod1.0                      | torch.nn.modules.conv.Conv2d                                               | horizon_plugin_pytorch.nn.qat.conv_bn2d.ConvBN2d         | torch.Size([4, 32, 256, 480]) | qint16     | 0.0003464 | 0.4019512 | 0.4977255  | 0.4296696 | 0.0115053 | -0.3768753 | 14.2325277  | 5343133.5000000 | -5.9793143     | -15.0439596     | 8.7436047      | 17.3419971      | 0.1503869       | 0.0484397        | 0.4336909      | 0.3707855       | 14.2325277    | 41088.2696619   |
| 2 | backbone.mod1.1                      | torch.nn.modules.batchnorm.BatchNorm2d                                     | torch.nn.modules.linear.Identity                         | torch.Size([4, 32, 256, 480]) | qint16     | 0.0003464 | 0.9999998 | 0.0000000  | 0.0000000 | 0.0000000 | inf        | 0.0000000   | 0.0000000       | -15.0439596    | -15.0439596     | 17.3419971     | 17.3419971      | 0.0484397       | 0.0484397        | 0.3707855      | 0.3707855       | 0.0000000     | 0.0000000       |
| 3 | backbone.mod2.0.head_layer.conv.0.0  | torch.nn.modules.conv.Conv2d                                               | horizon_plugin_pytorch.nn.qat.conv_bn2d.ConvBNReLU2d     | torch.Size([4, 64, 256, 480]) | qint16     | 0.0004594 | 0.5790146 | 49.3001938 | 4.0790396 | 0.0415040 | 0.3848250  | 164.3788757 | 4046676.7500000 | -164.3788757   | 0.0000000       | 140.9307404    | 25.1951389      | -0.5375993      | 0.2460073        | 53.5661125     | 0.2699530       | 164.3788757   | 357789.3997821  |
| 4 | backbone.mod2.0.head_layer.conv.0.1  | torch.nn.modules.batchnorm.BatchNorm2d                                     | torch.nn.modules.linear.Identity                         | torch.Size([4, 64, 256, 480]) | qint16     | 0.0004594 | 0.7092103 | 0.3265578  | 0.2332140 | 0.0003642 | 3.0239849  | 17.1071243  | 1.0000000       | -17.1071243    | 0.0000000       | 25.1951389     | 25.1951389      | 0.0127933       | 0.2460073        | 0.6568668      | 0.2699530       | 17.1071243    | 37235.6102222   |
| 5 | backbone.mod2.0.head_layer.conv.0.2  | torch.nn.modules.activation.ReLU                                           | torch.nn.modules.linear.Identity                         | torch.Size([4, 64, 256, 480]) | qint16     | 0.0004594 | 1.0000001 | 0.0000000  | 0.0000000 | 0.0000000 | inf        | 0.0000000   | 0.0000000       | 0.0000000      | 0.0000000       | 25.1951389     | 25.1951389      | 0.2460073       | 0.2460073        | 0.2699530      | 0.2699530       | 0.0000000     | 0.0000000       |
| 6 | backbone.mod2.0.head_layer.short_add | horizon_plugin_pytorch.nn.quantized.functional_modules.FloatFunctional.add | horizon_plugin_pytorch.nn.qat.conv_bn2d.ConvBNAddReLU2d  | torch.Size([4, 32, 256, 480]) | qint16     | 0.0004441 | 0.5653002 | 1.6375992  | 0.4214573 | 0.0008538 | 1.6659310  | 39.9804993  | 1.0000000       | -39.9804993    | 0.0000000       | 19.6796150     | 19.6796150      | 0.0330326       | 0.4544899        | 2.4056008      | 0.5625318       | 39.9804993    | 90017.7454165   |
| 7 | backbone.mod2.0.head_layer.relu      | torch.nn.modules.activation.ReLU                                           | torch.nn.modules.linear.Identity                         | torch.Size([4, 32, 256, 480]) | qint16     | 0.0004441 | 1.0000000 | 0.0000000  | 0.0000000 | 0.0000000 | inf        | 0.0000000   | 0.0000000       | 0.0000000      | 0.0000000       | 19.6796150     | 19.6796150      | 0.4544899       | 0.4544899        | 0.5625318      | 0.5625318       | 0.0000000     | 0.0000000       |
```
Attention

After finding the operators with significant accuracy drop, first check base_model_min / base_model_max / analy_model_min / analy_model_max to confirm if there are substantial errors in the extreme values.

  1. Significant errors in min / max values: This indicates that the output range of the operator in the floating-point model significantly exceeds the range obtained from calibration. Check the operator's scale and compute the maximum value obtained from calibration using the dtype and scale. For instance, if the scale is 0.0078 and the dtype is int8, the maximum value is 0.0078 × 128 = 0.9984. Compare this with base_model_max and analy_model_max. A very small scale may be caused by unreasonable calibration data (too little data or a biased distribution resulting in an excessively small output range), a shared module, etc.

  2. No significant errors in min / max values: Similarly, compute the calibration maximum value as described in point 1 and compare it with base_model_max and analy_model_max to confirm that the floating-point operator's output range does not significantly exceed the calibrated range. These issues may be caused by insufficient resolution of the quantization dtype or an excessively large numerical range.

    a. Check if the current quantization dtype matches the numerical range. Generally, if the maximum value exceeds 10, it is not recommended to use int8 quantization.

    b. Identify why calibration produced such a large numerical range. Possible causes include outliers or unreasonable settings.
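The check described in points 1 and 2 is simple arithmetic. A sketch, following the guide's int8 convention (scale 0.0078 → 0.0078 × 128 = 0.9984); the helper name is illustrative:

```python
def calib_max(scale: float, bits: int = 8) -> float:
    # Maximum representable magnitude for a symmetric quantized dtype,
    # per the int8 example above: 0.0078 * 128 = 0.9984.
    return scale * 2 ** (bits - 1)

# Compare this value with base_model_max / analy_model_max from the
# per-layer comparison table.
calib_max(0.0078, bits=8)
```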

Calculating Quantization Sensitivity

This step evaluates the impact of each quantization node on model accuracy. The sensitivity interface in the debug tool can be used to obtain quantization sensitivity. The evaluation method is to use the bad cases as model input, enable each quantization node individually, and compare the error between the calibration model and the floating-point model. The error metric is the same as the one used when finding bad cases.

```python
from horizon_plugin_profiler.model_profilerv2 import QuantAnalysis

# 1. Initialize the quantization analysis object
qa = QuantAnalysis(float_net, calibration_model, "fake_quant")

# 2. Find bad cases. If there is too much data, specify num_steps to
#    search a subset of the data
qa.auto_find_bad_case(dataloader)

# 3. Run the model with bad cases to collect per-layer information
qa.run()

# 4. Compare the information of each layer
qa.compare_per_layer()

# 5. Calculate quantization sensitivity
qa.sensitivity()
```

In the results of quantization sensitivity, operators ranked higher have a greater impact on model accuracy and need to be set to a higher precision data type.

The sensitive_type column has two values: weight and activation, representing the cases of enabling only the weight quantization node or the output quantization node of the operator, respectively.

```
op_name                                                                                        sensitive_type  op_type                                                                     L1       quant_dtype
---------------------------------------------------------------------------------------------  --------------  --------------------------------------------------------------------------  -------  -----------
bev_stage2_e2e_dynamic_head.head.transformer.decoder.layers.5.cross_attn.quant                 activation      <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'>                     1.59863  qint8
bev_stage2_e2e_dynamic_head.head.transformer.decoder.layers.5.norm3.var_mean.pre_mean          activation      <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'>  1.52816  qint16
bev_stage2_e2e_dynamic_head.head.ref_pts_quant                                                 activation      <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'>                     1.16427  qint8
bev_stage2_e2e_dynamic_head.head.fps_quant                                                     activation      <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'>                     1.13563  qint8
bev_stage2_e2e_dynamic_head.head.transformer.decoder.mem_bank_layer.emb_fps_queue_add          activation      <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'>  1.11997  qint8
bev_stage2_e2e_dynamic_head.head.transformer.decoder.mem_bank_layer.temporal_norm2.weight_mul  activation      <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'>  1.09876  qint8
```
Attention

Even for very quantization-unfriendly models, some operators should have relatively low quantization sensitivity. Therefore, in a normal sensitivity table the sensitivity values should vary, with the last few operators' sensitivities close to zero. If the errors of the last few operators are still large, check whether incorrect post-processing operations, such as NMS, were not thoroughly removed.

You need to use a sensitivity template to set the qconfig. See the qconfig section for detailed usage. Below is an example of setting the top 20% most sensitive operators to int16 in an int8/int16 mixed precision tuning. Adjust the int16 ratio continuously until you find the minimal int16 ratio that meets the accuracy requirements.

```python
# Calibration stage
qat_model = prepare(
    model,
    example_inputs=example_input,
    qconfig_setter=(
        sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table, ratio=0.2),
        default_calibration_qconfig_setter,
    ),
)

# QAT stage
qat_model = prepare(
    model,
    example_inputs=example_input,
    qconfig_setter=(
        sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter(table, ratio=0.2),
        default_qat_qconfig_setter,
    ),
)
```

Currently, there is no interface to set fp16 in bulk based on sensitivity. You need to set a small number of fp16 operators based on the int16 sensitivity results using ModuleNameQconfigSetter. Below is an example of setting the single most sensitive int16 operator to fp16 in int16/fp16 mixed precision optimization.

```python
module_name_to_qconfig = {
    "op_1": get_qconfig(
        in_dtype=torch.float16,
        weight_dtype=torch.float16,
        out_dtype=torch.float16,
    ),
}

# Calibration stage
qat_model = prepare(
    model,
    example_inputs=example_input,
    qconfig_setter=(
        ModuleNameQconfigSetter(module_name_to_qconfig),
        calibration_8bit_weight_16bit_act_qconfig_setter,
    ),
)

# QAT stage
qat_model = prepare(
    model,
    example_inputs=example_input,
    qconfig_setter=(
        ModuleNameQconfigSetter(module_name_to_qconfig),
        qat_8bit_weight_16bit_act_qconfig_setter,
    ),
)
```

J6M has limited floating-point compute power. If fp16 is not absolutely necessary, try to use int8/int16 mixed precision tuning. When the all-int16 model cannot meet the accuracy requirements, introduce a small number of fp16 operators into the all-int16 model.

There are two scenarios in which an all-int16 model fails to meet accuracy standards:

  1. Dual int16 needed: This manifests as certain operators having high sensitivity under both activation and weight sensitivity types. Setting weight and activation to int16 meets the accuracy requirements. Since J6M does not support using int16 for both activation and weight simultaneously, adjust the floating-point model to make one of them more quantization-friendly. Common methods include increasing weight decay and adding normalization operators.

  2. Dual int16 not needed: This manifests as certain operators having high sensitivity under either activation or weight sensitivity type, generally due to plugin usage issues or certain operators requiring fixed scale settings. Accuracy debugging can find specific problems.

INT8 / INT16 / FP16 Mixed Precision Tuning

Before performing int8/int16/fp16 mixed precision tuning, you should have completed int16/fp16 mixed precision tuning. Reuse the fp16 configuration from int16/fp16 mixed precision tuning and perform accuracy debugging based on the int8/fp16 mixed calibration model. Refer to the accuracy debugging methods in the previous section, continuously adjusting the int16 ratio until you find the minimal int16 ratio that meets the accuracy requirements.

Below is an example of setting the top 1 most sensitive int16 operator to fp16 and the top 20% most sensitive int8 operators to int16 in int8/int16/fp16 mixed precision optimization.

```python
module_name_to_qconfig = {
    "op_1": get_qconfig(
        in_dtype=torch.float16,
        weight_dtype=torch.float16,
        out_dtype=torch.float16,
    ),
}

# Calibration stage
qat_model = prepare(
    model,
    example_inputs=example_input,
    qconfig_setter=(
        ModuleNameQconfigSetter(module_name_to_qconfig),
        sensitive_op_8bit_weight_16bit_act_calibration_setter(table, ratio=0.2),
        default_calibration_qconfig_setter,
    ),
)

# QAT stage
qat_model = prepare(
    model,
    example_inputs=example_input,
    qconfig_setter=(
        ModuleNameQconfigSetter(module_name_to_qconfig),
        sensitive_op_8bit_weight_16bit_act_qat_setter(table, ratio=0.2),
        default_qat_qconfig_setter,
    ),
)
```