Accuracy Tuning Tool Guide

Because the float-to-fixed-point conversion introduces error, you will inevitably encounter accuracy drops in the quantized model when using quantization training tools. Typically, there are several reasons for this:

  1. The original floating-point model is unfavorable for quantization, for example it contains shared ops or shared structures.

  2. The QAT network structure or configuration is abnormal, for example there is an unfused pattern in the model, or the model output is not configured with high accuracy.

  3. Some operators are sensitive to quantization, and the quantization error of the operator accumulates layer by layer in the forward propagation process, which ultimately leads to a large error in the model output.

For the above cases, we provide the accuracy tuning tool to help you quickly locate and solve accuracy problems. It mainly includes the following modules:

  • Model Structure Checking: check whether the model contains shared ops, unfused patterns, or quantization configurations that do not meet expectations.

  • QuantAnalysis Class: automatically compare and analyze two models to locate anomalous operators or quantization-sensitive ops in the quantized model.

  • ModelProfiler Class and HbirModelProfiler Class: get information about the numerical characteristics of each op in the model, such as the maximum and minimum values of the inputs and outputs. The functionality of these two classes is identical; the difference is that HbirModelProfiler accepts only the qat hbir model as input. Usually you do not need to call this module manually; you can get the numerical information of both models directly from QuantAnalysis.run.

Quickstart

When encountering accuracy drops in the quantized model, we recommend using the accuracy tuning tool according to the following process.

  1. Check if there are any unfavorable structures or abnormal configurations in the model.

  2. Use QuantAnalysis module to analyze the model as follows:

    1). Find a bad case as an input to the model. The bad case is the input that produces the largest difference between the outputs of the baseline model and the model under analysis.

    2). Perform quantization sensitivity analysis. In our experience, the ops with the top-n L1 sensitivities are usually the quantization-sensitive ops (the value of n varies from model to model, and there is no automatic way to determine it, so try it manually, e.g., the top 10, 20, ...). Set the quantization-sensitive ops to higher-accuracy quantization (e.g., int16) and redo the quantization process.

    3). Alternatively, compare the inputs and outputs of the two models layer by layer to check for abnormal quantized ops, such as a data range that is too large or an unreasonable scale (e.g., ops with physical meaning that should be set to a fixed scale).

The overall flow chart is as follows:

new_debug_flow

A complete example is as follows:

```python
from copy import deepcopy

import torch
from torch import nn
from torch.quantization import DeQuantStub, QuantStub

from horizon_plugin_pytorch.march import March, set_march
from horizon_plugin_pytorch.quantization.qconfig import (
    default_qat_8bit_fake_quant_qconfig,
)
from horizon_plugin_pytorch.quantization.quantize_fx import prepare
from horizon_plugin_pytorch.quantization import hbdk4 as hb4
from horizon_plugin_pytorch.utils.check_model import check_qat_model
from horizon_plugin_profiler import QuantAnalysis, ModelProfiler


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.relu = nn.ReLU()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        x = torch.nn.functional.interpolate(
            x, scale_factor=1.3, mode="bilinear", align_corners=False
        )
        x = self.dequant(x)
        return x


data = torch.rand((1, 3, 32, 32))
float_net = Net()
float_net(data)

set_march(March.NASH)
float_net.qconfig = default_qat_8bit_fake_quant_qconfig
qat_net = deepcopy(float_net)
qat_net = prepare(qat_net, data)

######################### Model Structure Checking ##########################
# verify that the prompted exception layers are as expected
check_qat_model(qat_net, data, save_results=True)
#############################################################################

qat_net(data)

# export hbir model
qat_hbir = hb4.export(qat_net, (data,))

############################## quant analysis ###############################
# 1. initialization
qa = QuantAnalysis(
    baseline_model=float_net,
    analysis_model=qat_net,
    analysis_model_type="fake_quant",
    out_dir="./floatvsqat",
)

# comparing qat and qat hbir is also supported
# qa = QuantAnalysis(
#     baseline_model=qat_net,
#     analysis_model=qat_hbir,
#     analysis_model_type="fake_quant",
#     out_dir="./qatvshbir",
# )

# 2. set the badcase input
qa.set_bad_case(data)
# in practice, it is recommended to use auto_find_bad_case to search for bad
# cases across the dataloader; setting the num_steps parameter to control the
# search range is also supported
# qa.auto_find_bad_case(your_dataloader, num_steps=100)

# 3. run the two models
qa.run()

# 4. compare the two models layer by layer. Verify that the abnormal layers
# indicated by abnormal_layer_advisor.txt are as expected
# qa.compare_per_layer()

# 5. calculate sensitivity nodes. You can set the topk sorted sensitivity
# nodes to high accuracy to try to improve the quantized model accuracy
# !!! sensitivity calculation for the qat_hbir model is not supported
qa.sensitivity()
#############################################################################
```

API Reference

Model Structure Checking

```python
# from horizon_plugin_pytorch.utils.check_model import check_qat_model

def check_qat_model(
    model: torch.nn.Module,
    example_inputs: Any,
    save_results: bool = False,
    out_dir: Optional[str] = None,
):
```

Check whether the calibration/qat model contains structures that are unfavorable for quantization and whether the qconfig configuration is as expected.

Parameters

  • model: model to be checked.

  • example_inputs: model inputs.

  • save_results: whether to save the check results to a txt file. Default is False.

  • out_dir: save path of the result file 'model_check_result.txt'. Default is None, in which case the file is saved to the current path.

Output

  • screen output: abnormal layers found by the check.

  • model_check_result.txt: generated when save_results = True. It consists of 5 main parts:

    1). Unfused pattern.

    2). The calling times of each module. Normally each op is called exactly once; 0 means it is not called, and more than 1 means it is shared multiple times.

    3). The qconfig configuration for each op output.

    4). The qconfig configuration for each op weight (if any).

    5). Exception qconfig hints (if any).

```
Fusable modules are listed below:
name    type
------  -----------------------------------------------------
conv    <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>
relu    <class 'horizon_plugin_pytorch.nn.qat.relu.ReLU'>

Each module called times:
name     called times
-------  --------------
conv     1
relu     1
quant    1
dequant  1

Each layer out qconfig:
+-------------+-----------------------------------------------------------+---------------+---------------+----------------+-----------------------------+
| Module Name | Module Type                                               | Input dtype   | out dtype     | ch_axis        | observer                    |
|-------------+-----------------------------------------------------------+---------------+---------------+----------------+-----------------------------|
| quant       | <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'>   | torch.float32 | qint8         | -1             | MovingAverageMinMaxObserver |
| conv        | <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>     | qint8         | qint8         | -1             | MovingAverageMinMaxObserver |
| relu        | <class 'horizon_plugin_pytorch.nn.qat.relu.ReLU'>         | qint8         | qint8         | qconfig = None |                             |
| dequant     | <class 'horizon_plugin_pytorch.nn.qat.stubs.DeQuantStub'> | qint8         | torch.float32 | qconfig = None |                             |
+-------------+-----------------------------------------------------------+---------------+---------------+----------------+-----------------------------+

Weight qconfig:
+-------------+-------------------------------------------------------+--------------+---------+---------------------------------------+
| Module Name | Module Type                                           | weight dtype | ch_axis | observer                              |
|-------------+-------------------------------------------------------+--------------+---------+---------------------------------------|
| conv        | <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'> | qint8        | 0       | MovingAveragePerChannelMinMaxObserver |
+-------------+-------------------------------------------------------+--------------+---------+---------------------------------------+
```
Note

The prepare interface has this check integrated. Pay attention to the inspection results output by this interface and make targeted adjustments to the model based on them.

QuantAnalysis Class

The QuantAnalysis class can automatically find the bad case with the largest output difference between two models, and use it as input to compare the outputs of the two models layer by layer. In addition, the QuantAnalysis class provides sensitivity calculation: you can try setting the nodes with the topk sensitivity ranking to high accuracy, such as int16 quantization, to improve the accuracy of the quantized model.

```python
class QuantAnalysis(object):
    def __init__(
        self,
        baseline_model: Union[torch.nn.Module, HbirModule],
        analysis_model: Union[torch.nn.Module, HbirModule],
        analysis_model_type: str,
        device_ids: Union[List[int], int] = None,
        post_process: Optional[Callable] = None,
        out_dir: Optional[str] = None,
    )
```

Parameters

  • baseline_model: baseline model (high accuracy).

  • analysis_model: model to be analyzed (the one whose accuracy drops).

  • analysis_model_type: type of the model to be analyzed. Two types are supported:

    • fake_quant: the model to be analyzed is a calibration/qat model with dropped accuracy, in which case the baseline model can be either the original floating-point model or an accuracy-compliant calibration/qat model in a mixed int8/int16 configuration.

    • quantized: the model to be analyzed is a fixed-point model with dropped accuracy, in which case the baseline model must be an accuracy-compliant calibration/qat model.

  • device_ids: GPU device ids to run analysis. Default None.

  • post_process: post process function which performs on model output.

  • out_dir: specify the output directory for the comparison results.

The methods in this class are as follows.

auto_find_bad_case

```python
def auto_find_bad_case(
    self,
    data_generator: Iterable,
    num_steps: Optional[int] = None,
    metric: str = "L1",
    device: Optional[Union[torch.device, str, int]] = None,
    custom_metric_func: Optional[Callable] = None,
    custom_metric_order_seq: Optional[str] = None,
    cached_attrs: Optional[Tuple[str, ...]] = None,
):
```

Automatically find the badcase that produces the worst output difference between the two models.

Parameters

  • data_generator: dataloader or a custom iterator that produces one piece of data per iteration.

  • num_steps: number of iteration steps.

  • metric: specify which metric to use for selecting the badcase. Default is the worst result under L1. Supported: Cosine/MSE/L1/KL/SQNR/custom. custom means using a custom metric, in which case custom_metric_func and custom_metric_order_seq must not be None.

  • device: specify the model run device.

  • custom_metric_func: customize the model output comparison function.

  • custom_metric_order_seq: sorting rule for the custom comparison function; only "ascending"/"descending" is supported.

  • cached_attrs: cached attrs to use as input. Usually used in sequence models; for instance, some results of the first frame must be treated as input when running the second frame. Default None.

Note

The auto_find_bad_case function goes through data_generator, runs the baseline model and the analysis model, computes each output's results on the Cosine/MSE/L1/KL/SQNR metrics, and finds the worst input case under each metric.
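As an illustration of the custom metric path, here is a minimal sketch. It assumes the tool calls custom_metric_func with the two models' outputs and sorts cases by custom_metric_order_seq ("descending" here meaning a larger value is a worse case); check the exact calling convention against your plugin version.

```python
import torch

# Hypothetical custom metric: mean absolute relative error between the
# baseline and analysis outputs. Larger value == worse case.
def mean_relative_error(baseline_out, analysis_out):
    return torch.mean(
        torch.abs(baseline_out - analysis_out)
        / (torch.abs(baseline_out) + 1e-6)
    ).item()

# qa.auto_find_bad_case(
#     your_dataloader,
#     num_steps=100,
#     metric="custom",
#     custom_metric_func=mean_relative_error,
#     custom_metric_order_seq="descending",
# )
```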

Output

  • badcase.txt: It has three parts.

    • The worst input index in data_generator of each output under different metrics.

    • The metric value computed on the worst input of each output.

    • The worst input index under each metric.

```
The bad case input index of each output:
Name/Index    Cosine   MSE    L1    KL   SQNR
----------    ------   ---   ---   ---   ----
box0              16     0     0    97     16
box1              77    32    32    81     77
box2              66    46    56    79     66
scores            61    96   100    96     60
centerness         0     0     0     0      0
yawness           17    76    18    18     39

The metric results of each badcase:
Name/Index    Cosine      MSE         L1         KL            SQNR
----------    ---------   ---------   ---------  -----------   ---------
box0           0.906329   84.7319     2.97311    0.323674       7.40183
box1           0.968623   38.602      2.98777    0.268769      11.4491
box2           0.799405   237.053     4.37895    0.0386395      4.28581
scores         0.388762   0.0080612   0.0395061  1.43469e-05   -0.749675
centerness     0.904206   0.0456813   0.178009   4.34062e-05    5.28536
yawness       -0.325684   4.87353     1.27329    0.141269      -3.66645

The bad case input index of the worst output:
metric    dataloader index
------    ----------------
Cosine    17
MSE       46
L1        56
KL        97
SQNR      39
```
  • badcase.pt: the worst input data under the metric passed via the metric parameter. It is used as the default input of the run function.

set_bad_case

```python
def set_bad_case(
    self,
    data: Any,
    baseline_model_cached_attr=None,
    analysis_model_cached_attr=None,
):
```

Set the badcase manually.

Attention

Usually, we suggest finding the badcase with auto_find_bad_case. If the manually set badcase is not an actual bad case, it is difficult for the quant analysis tool to find quantization-sensitive layers.

Parameters

  • data: badcase input.

  • baseline_model_cached_attr: baseline model cached attr.

  • analysis_model_cached_attr: analysis model cached attr.

load_bad_case

```python
def load_bad_case(self, filename: Optional[str] = None)
```

Load badcase from the specified file.

Parameters

  • filename: specified file path. By default, it loads the badcase.pt saved by auto_find_bad_case from the directory specified by out_dir.

save_bad_case

```python
def save_bad_case(self)
```

Save badcase to the {self.out_dir}/badcase.pt file.

Attention

It is used with set_bad_case. Usually, you do not need to invoke this function.

set_model_profiler_dir

```python
def set_model_profiler_dir(
    self,
    baseline_model_profiler_path: str,
    analysis_model_profiler_path: str,
):
```

Specify the path to save the output of model_profiler manually.

In some cases, the ModelProfiler is defined and run before QuantAnalysis is initialized, in which case you can directly specify the path to the existing ModelProfiler, skipping the run step of QuantAnalysis and comparing the output of the two models directly.

Parameters

  • baseline_model_profiler_path: profiler path for the baseline model.

  • analysis_model_profiler_path: profiler path for the model to be analyzed.

run

```python
def run(
    self,
    device: Optional[Union[torch.device, str, int]] = None,
    index: Optional[int] = None,
)
```

Run the two models and save the results for each layer in the model separately.

Parameters

  • device: model running device.

  • index: which indexed input to use as the example input.

Attention

Only an index found by auto_find_bad_case and shown in badcase.txt is allowed as this parameter.

compare_per_layer

```python
def compare_per_layer(
    self,
    prefixes: Tuple[str, ...] = None,
    types: Tuple[Type, ...] = None,
):
```

Compare the results of each layer in the two models.

Parameters

  • prefixes: the prefix of ops.

  • types: the types of ops.

Note

Usually you do not need to specify the prefixes and types parameters. If you want to skip comparing ops with little quantization impact based on prior experience, or want to save time, you can use these two parameters to restrict the comparison to certain ops or certain types of ops.

Output

  • abnormal_layer_advisor.txt: all anomaly layers, including cases of low similarity/data range too large/inputs not normalized/outputs without high accuracy.

  • profiler.html: show visually all metrics and the data range diff for each layer in the model.

profiler_html

  • compare_per_layer_out.txt: shows the specific information of each layer in the model in table form, including various metrics, data ranges, quantized dtype, etc. Each column from left to right represents:

    • Index: op index.

    • mod_name: name of the op. If the op is a module, the prefix name of the module in the model is shown; if it is a function, nothing is shown.

    • base_op_type: type of the op in the base model, may be module type or function name.

    • analy_op_type: type of the op in the model to be analyzed, could be module type or function name.

    • Shape: shape of the op.

    • quant_dtype: quantized type output of the op.

    • Qscale: quantized scale output of the op.

    • Cosine: cosine similarity of the op output in the two models.

    • MSE: MSE distance of the op output in the two models.

    • L1: L1 distance of the op output in the two models.

    • KL: KL similarity of the op output in the two models.

    • SQNR: SQNR similarity of the op output in the two models.

    • Atol: absolute error of the op output in the two models.

    • Rtol: relative error of the op output in the two models.

    • base_model_min: minimum value of the op output in the baseline model.

    • analy_model_min: minimum value of the op output in the model to be analyzed.

    • base_model_max: maximum value of the op output in the baseline model.

    • analy_model_max: maximum value of the op output in the model to be analyzed.

    • base_model_mean: average of the op output in the baseline model.

    • analy_model_mean: average of the op output in the model to be analyzed.

    • base_model_var: variance of the op output in the baseline model.

    • analy_model_var: variance of the op output in the model to be analyzed.

```
|    | mod_name | base_op_type | analy_op_type | shape | quant_dtype | qscale | Cosine | MSE | L1 | KL | SQNR | Atol | Rtol | base_model_min | analy_model_min | base_model_max | analy_model_max | base_model_mean | analy_model_mean | base_model_var | analy_model_var |
|----|----------|--------------|---------------|-------|-------------|--------|--------|-----|----|----|------|------|------|----------------|-----------------|----------------|-----------------|-----------------|------------------|----------------|-----------------|
| 0 | quant | torch.ao.quantization.stubs.QuantStub | horizon_plugin_pytorch.nn.qat.stubs.QuantStub | torch.Size([1, 3, 32, 32]) | qint8 | 0.0078354 | 0.9999924 | 0.0000052 | 0.0019757 | 0.0000006 | 48.1179886 | 0.0039178 | 1.0000000 | 0.0003164 | 0.0000000 | 0.9990171 | 0.9950994 | 0.5015678 | 0.5014852 | 0.0846284 | 0.0846521 |
| 1 | conv | torch.nn.modules.conv.Conv2d | horizon_plugin_pytorch.nn.qat.conv2d.Conv2d | torch.Size([1, 3, 32, 32]) | qint8 | 0.0060428 | 0.9999037 | 0.0000085 | 0.0023614 | 0.0000012 | 37.1519432 | 0.0096008 | 48.2379990 | -0.7708085 | -0.7674332 | 0.4674263 | 0.4652941 | -0.0411330 | -0.0412943 | 0.0423415 | 0.0422743 |
| 2 | relu | torch.nn.modules.activation.ReLU | horizon_plugin_pytorch.nn.qat.relu.ReLU | torch.Size([1, 3, 32, 32]) | qint8 | 0.0060428 | 0.9998640 | 0.0000037 | 0.0010231 | 0.0000004 | 35.5429153 | 0.0093980 | 48.2379990 | 0.0000000 | 0.0000000 | 0.4674263 | 0.4652941 | 0.0641222 | 0.0639115 | 0.0090316 | 0.0089839 |
| 3 |  | horizon_plugin_pytorch.nn.interpolate.autocasted_interpolate_outer | horizon_plugin_pytorch.nn.interpolate.autocasted_interpolate_outer | torch.Size([1, 3, 41, 41]) | qint8 | 0.0060428 | 0.9234583 | 0.0012933 | 0.0245362 | 0.0001882 | 8.1621437 | 0.1928777 | 340282346638528859811704183484516925440.0000000 | 0.0000000 | 0.0000000 | 0.3509629 | 0.3504813 | 0.0643483 | 0.0639483 | 0.0043305 | 0.0043366 |
| 4 | dequant | torch.ao.quantization.stubs.DeQuantStub | horizon_plugin_pytorch.nn.qat.stubs.DeQuantStub | torch.Size([1, 3, 41, 41]) | torch.float32 |  | 0.9234583 | 0.0012933 | 0.0245362 | 0.0001882 | 8.1621437 | 0.1928777 | 340282346638528859811704183484516925440.0000000 | 0.0000000 | 0.0000000 | 0.3509629 | 0.3504813 | 0.0643483 | 0.0639483 | 0.0043305 | 0.0043366 |
```
  • compare_per_layer_out.csv: shows the specific information of each layer in csv format. The content is exactly the same as compare_per_layer_out.txt; the csv format is convenient to open and analyze in Excel or other software.

sensitivity

```python
def sensitivity(
    self,
    device: Optional[torch.device] = None,
    metric: str = "L1",
    reserve: bool = False,
):
```

Sensitivity ordering of individual nodes in the model. Applies to the accuracy drop problem when converting from float to calibration/qat.

Attention

The sensitivity function is not supported for calculating the sensitivity of hbir models.

Parameters

  • device: specify the model running device.

  • metric: metric for similarity ordering, default is L1, support Cosine/MSE/L1/KL/SQNR.

  • reserve: whether to print sensitivity nodes in reverse order, which helps when reverting some int16 operators to int8 to improve on-board performance.

Output

  • sensitive_ops.txt. The file lists ops in descending order of quantization sensitivity. Each column from left to right represents:

    • op_name: op name.

    • sensitive_type: how the quantization sensitivity is calculated, including:

      • activation: sensitivity when only the output of this op is quantized.

      • weight: sensitivity when only the weight of this op is quantized.

    • op_type: op type.

    • metric: the metric used to calculate sensitivity, by which the file is sorted from most to least sensitive. Supports Cosine/L1/MSE/KL/SQNR; L1 is used by default.

      • L1: value range [0, +∞); the larger the value, the more sensitive the op is to quantization.

      • Cosine: value range [0, 1]; the closer to 0, the more sensitive the op is to quantization.

      • MSE: value range [0, +∞); the larger the value, the more sensitive the op is to quantization.

      • KL: value range [0, +∞); the larger the value, the more sensitive the op is to quantization.

      • SQNR: value range [0, +∞); the smaller the value, the more sensitive the op is to quantization.

    • quant_dtype: the quant dtype of this op output. Usually qint8/qint16.

  • sensitive_ops.pt. A sensitivity-ordered list saved using torch.save, for subsequent loading. The format of the list is described in the Return Value section.
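The comparison metrics above can be sketched roughly as follows. These are reference formulas only; the tool's exact definitions may differ in reduction, flattening, or epsilon handling (x = baseline output, y = analyzed output).

```python
import torch
import torch.nn.functional as F

def l1_distance(x, y):
    # mean absolute difference; larger == more different
    return torch.mean(torch.abs(x - y)).item()

def cosine(x, y):
    # cosine similarity over flattened tensors; closer to 1 == more similar
    return F.cosine_similarity(x.flatten(), y.flatten(), dim=0).item()

def sqnr_db(x, y):
    # signal-to-quantization-noise ratio in dB; larger == closer outputs
    return (10 * torch.log10(torch.sum(x ** 2) / torch.sum((x - y) ** 2))).item()
```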

Return Value

A sensitivity list; each element is a sub-list recording one op's sensitivity information. Each sub-list, from left to right, is [op_name, sensitive_type, op_type, metric, quant_dtype].

An example of the whole List is as follows.

```python
[
    [op1, "activation", op1_type, L1, qint8],
    [op2, "activation", op2_type, L1, qint8],
    [op3, "activation", op3_type, L1, qint8],
    [op1, "weight", op1_type, L1, qint8],
    ...
]
```

You can configure the ops with the top n quantization sensitivities with high accuracy (e.g., int16) to try to improve the quantized model accuracy.
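A minimal sketch of consuming the saved list: pick the names of the top-n most sensitive ops, e.g. from `torch.load("./floatvsqat/sensitive_ops.pt")` (that path matches the example earlier and may differ in your project). Filtering on `sensitive_type == "activation"` is one reasonable choice, not a tool requirement; how you then switch those ops to int16 depends on your qconfig setup.

```python
# sensitivity_list: the list described in the Return Value section,
# already sorted from most to least sensitive by the tool.
def topk_sensitive_op_names(sensitivity_list, n=10, sensitive_type="activation"):
    names = []
    for op_name, s_type, op_type, metric, quant_dtype in sensitivity_list:
        if s_type == sensitive_type and op_name not in names:
            names.append(op_name)
        if len(names) == n:
            break
    return names

# table = torch.load("./floatvsqat/sensitive_ops.pt")
# int16_candidates = topk_sensitive_op_names(table, n=10)
```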

```
op_name    sensitive_type    op_type                                                   L1
---------  ----------------  --------------------------------------------------------  ---------
quant      activation        <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'>   0.0245567
conv       activation        <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>     0.0245275
conv       both              <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>     0.0245275
conv       weight            <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>     0.024501
```

clean

```python
def clean(self)
```

Clears intermediate results. Only files such as comparison results are retained.

ModelProfiler Class

Collects statistics about the inputs and outputs of each operator in the model's forward process.

```python
# from horizon_plugin_profiler import ModelProfiler

class ModelProfiler(object):
    def __init__(
        self,
        model: torch.nn.Module,
        out_dir: str,
    )
```

Parameters

  • model: model that requires statistics.

  • out_dir: path where the associated file is saved.

Note

This class can only be used via a with statement.

```python
with ModelProfiler(net, "./profiler_dir") as p:
    net(data)
    p.get_info_manager().table()
    p.get_info_manager().tensorboard()
```

The methods in this class are as follows.

get_info_manager

def get_info_manager(self)

Get the structure that manages the information for each op.

Return Value

The structure OpRunningInfoManager manages the information stored for each op. Two of the important interfaces are as follows.

table
```python
class OpRunningInfoManager:
    def table(
        self,
        out_dir: str = None,
        prefixes: Tuple[str, ...] = None,
        types: Tuple[Type, ...] = None,
        with_stack: bool = False,
    )
```

Show the model statistics in a table and store them to the statistic.txt file.

Parameters

  • out_dir: storage path of statistic.txt file, default None, store to self.out_dir.

  • prefixes: prefixes of ops in the model to be counted. Default is all ops.

  • types: types of the ops in the model to be counted, defaults to all ops.

  • with_stack: whether to show the position of each op in the code.

Output

The statistic.txt file; each column from left to right represents:

  • Index: op index.

  • Op Name: op type, module class name or function name.

  • Mod Name: if it is module class, the prefix name of the module in the model; if it is function type, the prefix name of the module where the function is located.

  • Attr: input/output/weight/bias.

  • Dtype: data type of the tensor.

  • Scale: scale of the tensor.

  • Min: minimum value of the current tensor.

  • Max: maximum value of the current tensor.

  • Mean: average value of the current tensor.

  • Var: variance of the values in the current tensor.

  • Shape: tensor shape.

```
| Index | Op Name                                                            | Mod Name | Attr   | Dtype         | Scale     | Min        | Max       | Mean       | Var       | Shape                      |
|-------|--------------------------------------------------------------------|----------|--------|---------------|-----------|------------|-----------|------------|-----------|----------------------------|
| 0     | horizon_plugin_pytorch.nn.qat.stubs.QuantStub                      | quant    | input  | torch.float32 |           | 0.0003164  | 0.9990171 | 0.5015678  | 0.0846284 | torch.Size([1, 3, 32, 32]) |
| 0     | horizon_plugin_pytorch.nn.qat.stubs.QuantStub                      | quant    | output | qint8         | 0.0078354 | 0.0000000  | 0.9950994 | 0.5014852  | 0.0846521 | torch.Size([1, 3, 32, 32]) |
| 1     | horizon_plugin_pytorch.nn.qat.conv2d.Conv2d                        | conv     | input  | qint8         | 0.0078354 | 0.0000000  | 0.9950994 | 0.5014852  | 0.0846521 | torch.Size([1, 3, 32, 32]) |
| 1     | horizon_plugin_pytorch.nn.qat.conv2d.Conv2d                        | conv     | weight | torch.float32 |           | -0.5315086 | 0.5750652 | 0.0269936  | 0.1615299 | torch.Size([3, 3, 1, 1])   |
| 1     | horizon_plugin_pytorch.nn.qat.conv2d.Conv2d                        | conv     | bias   | torch.float32 |           | -0.4963555 | 0.4448483 | -0.0851902 | 0.2320642 | torch.Size([3])            |
| 1     | horizon_plugin_pytorch.nn.qat.conv2d.Conv2d                        | conv     | output | qint8         | 0.0060428 | -0.7674332 | 0.4652941 | -0.0412943 | 0.0422743 | torch.Size([1, 3, 32, 32]) |
| 2     | horizon_plugin_pytorch.nn.qat.relu.ReLU                            | relu     | input  | qint8         | 0.0060428 | -0.7674332 | 0.4652941 | -0.0412943 | 0.0422743 | torch.Size([1, 3, 32, 32]) |
| 2     | horizon_plugin_pytorch.nn.qat.relu.ReLU                            | relu     | output | qint8         | 0.0060428 | 0.0000000  | 0.4652941 | 0.0639115  | 0.0089839 | torch.Size([1, 3, 32, 32]) |
| 3     | horizon_plugin_pytorch.nn.interpolate.autocasted_interpolate_outer |          | input  | qint8         | 0.0060428 | 0.0000000  | 0.4652941 | 0.0639115  | 0.0089839 | torch.Size([1, 3, 32, 32]) |
| 3     | horizon_plugin_pytorch.nn.interpolate.autocasted_interpolate_outer |          | output | qint8         | 0.0060428 | 0.0000000  | 0.3504813 | 0.0639483  | 0.0043366 | torch.Size([1, 3, 41, 41]) |
| 4     | horizon_plugin_pytorch.nn.qat.stubs.DeQuantStub                    | dequant  | input  | qint8         | 0.0060428 | 0.0000000  | 0.3504813 | 0.0639483  | 0.0043366 | torch.Size([1, 3, 41, 41]) |
| 4     | horizon_plugin_pytorch.nn.qat.stubs.DeQuantStub                    | dequant  | output | torch.float32 |           | 0.0000000  | 0.3504813 | 0.0639483  | 0.0043366 | torch.Size([1, 3, 41, 41]) |
```
tensorboard
```python
class OpRunningInfoManager:
    def tensorboard(
        self,
        out_dir: str = None,
        prefixes: Tuple[str, ...] = None,
        types: Tuple[Type, ...] = None,
        force_per_channel: bool = False,
    ):
```

Show the input and output histograms for each layer in the tensorboard.

Parameters

  • out_dir: directory where tensorboard related files are kept. Default is self.out_dir/tensorboard.

  • prefixes: prefixes of the ops in the model to be counted, default is all.

  • types: types of the ops in the model to be counted, default is all.

  • force_per_channel: whether to display the histogram with per_channel quantization.

Output

The tensorboard file; a screenshot of it opened is shown below.

tensorboard

HbirModelProfiler Class

The functionality and usage of this class are identical to the ModelProfiler class. Please refer to ModelProfiler Class for usage.

Attention

Due to the special format of the hbir model, the qat hbir model needs to be indexed with [0] when calling forward.

```python
with HbirModelProfiler(qat_hbir, "./hbir_dir") as p:
    qat_hbir[0](data)
    p.get_info_manager().table()
    p.get_info_manager().tensorboard()
```