Operator Fusion

The operator fusion supported by the training tool falls into two main categories: 1. absorbing BN; 2. fusing Add and ReLU(6).

Absorb BN

The purpose of absorbing BN is to reduce the computational workload of the model. Since BN is a linear transformation, its parameters can be absorbed into the parameters of the preceding Conv when the two appear together, eliminating the BN computation from the deployed model.

The computation of absorption proceeds as follows:

(Figure: fuse_bn)
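For reference, this is the standard BN-folding identity. For a Conv with weight $W$ and bias $b$, followed by a BN with affine parameters $\gamma, \beta$ and running statistics $\mu, \sigma^2$:

```latex
y = \gamma \cdot \frac{(W * x + b) - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
  = W' * x + b',
\qquad
W' = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot W,
\qquad
b' = \frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta
```

Because both steps are linear (per output channel), the composed operation is again a single Conv with rescaled weight and shifted bias.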

By absorbing BN, Conv2d + BN2d can be simplified to Conv2d:

(Figure: absorb_bn)
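The equivalence can be checked numerically. The sketch below is a hypothetical, numpy-based illustration using a linear layer in place of Conv (Conv is also linear per output channel, so the absorption works identically):

```python
import numpy as np

rng = np.random.default_rng(0)

# A linear layer standing in for Conv.
W = rng.standard_normal((4, 3))      # weight
b = rng.standard_normal(4)           # bias

# BN parameters (inference mode: running statistics are fixed).
gamma = rng.standard_normal(4)
beta = rng.standard_normal(4)
mu = rng.standard_normal(4)
var = rng.random(4) + 0.5
eps = 1e-5

x = rng.standard_normal(3)

# Unfused: linear layer followed by BN.
y_ref = gamma * ((W @ x + b) - mu) / np.sqrt(var + eps) + beta

# Absorb BN into the layer's parameters.
scale = gamma / np.sqrt(var + eps)
W_fused = scale[:, None] * W         # rescale each output channel's weights
b_fused = scale * (b - mu) + beta    # shift the bias

y_fused = W_fused @ x + b_fused

assert np.allclose(y_ref, y_fused)   # identical up to float rounding
```

The fused form performs one matrix multiply instead of a matrix multiply plus a normalization pass.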

Fuse Add, ReLU(6)

Unlike CUDA kernel fusion, which fuses CUDA kernels to increase computation speed, the fusion supported by the training tool operates mainly at the quantization level.

The BPU hardware provides optimizations for common basic model structures: when computing the Conv -> Add -> ReLU operator combination, it allows the data passed between operators to retain high precision, improving the overall numerical accuracy of the model. Therefore, when quantizing the model, Conv -> Add -> ReLU can be treated as a whole.

Since the training tool quantizes the model at the granularity of torch.nn.Module, treating Conv -> Add -> ReLU as a whole during quantization requires merging them into a single Module.
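As an illustration of what merging into a single Module means, here is a toy, numpy-based sketch (not the plugin's actual ConvAddReLU2d implementation), again using a linear layer in place of Conv:

```python
import numpy as np

class LinearAddReLU:
    """Toy fused module: computes relu((W @ x + b) + y) in one call,
    so the intermediate results never leave the module (and, on the
    BPU, never need to be requantized between the three ops)."""

    def __init__(self, W, b):
        self.W = np.asarray(W, dtype=float)
        self.b = np.asarray(b, dtype=float)

    def __call__(self, x, y):
        out = self.W @ x + self.b   # "conv"
        out = out + y               # add (second operand from elsewhere)
        return np.maximum(out, 0)   # relu

m = LinearAddReLU(W=[[1.0, -1.0]], b=[0.5])
print(m(np.array([2.0, 1.0]), np.array([-3.0])))  # → [0.] (negative sum clipped by relu)
```

Calling one module keeps the conv output and the sum as internal intermediates, exactly the property the quantization pass exploits.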

In addition to letting the intermediate results retain high precision, operator fusion also eliminates the need to convert those intermediate results to a low-precision representation, so execution is faster than without fusion.

Since operator fusion improves both model accuracy and model speed, it should generally be done for all fusable parts.

Implementation Principle

Thanks to FX's ability to capture computational graphs, the training tool can automatically analyze the model's computational graph, match the fusible parts against predefined fusion patterns, and implement the fusion by submodule substitution. An example is given below.

BN absorption and the fusion of Add and ReLU(6) are handled by the same mechanism, so no distinction between them is needed during fusion.

```python
import torch
from torch import nn
from torch.quantization import DeQuantStub

from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization import fuse_fx


class ModelForFusion(torch.nn.Module):
    def __init__(self):
        super(ModelForFusion, self).__init__()
        self.quantx = QuantStub()
        self.quanty = QuantStub()
        self.conv = nn.Conv2d(3, 3, 3)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()

    def forward(self, x, y):
        x = self.quantx(x)
        y = self.quanty(y)
        x = self.conv(x)
        x = self.bn(x)
        x = x + y
        x = self.relu(x)
        x = self.dequant(x)
        return x


float_model = ModelForFusion()
fused_model = fuse_fx(float_model)
print(fused_model)
"""
ModelForFusion(
  (quantx): QuantStub()
  (quanty): QuantStub()
  (conv): Identity()
  (bn): Identity()
  (relu): Identity()
  (dequant): DeQuantStub()
  (_generated_add_0): ConvAddReLU2d(
    (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (relu): ReLU()
  )
)

def forward(self, x, y):
    quantx = self.quantx(x);  x = None
    quanty = self.quanty(y);  y = None
    _generated_add_0 = self._generated_add_0
    add_1 = self._generated_add_0(quantx, quanty);  quantx = quanty = None
    dequant = self.dequant(add_1);  add_1 = None
    return dequant
"""
```

As you can see, after the operator fusion is applied, BN is absorbed into Conv, and Conv, Add, and ReLU are fused into a single Module (_generated_add_0). The original submodules are replaced with Identity and are no longer called in the forward code.

FX automatically replaces the plus sign in x = x + y with a Module named _generated_add_0 to support operator fusion and quantization-related operations.
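The reason a bare `+` must become a Module is that quantization needs a stateful hook point for each operation (e.g. observers collecting statistics), and a free-standing operator has none. A minimal pure-Python sketch of this idea (hypothetical names, not the plugin's FloatFunctional):

```python
class QuantizableAdd:
    """Wraps `+` in a callable object so fusion/quantization passes
    have somewhere to attach per-operation state for this add."""

    def __init__(self):
        self.observed_max = 0.0  # stands in for an observer's statistics

    def __call__(self, x, y):
        out = x + y
        self.observed_max = max(self.observed_max, abs(out))
        return out

add_0 = QuantizableAdd()       # plays the role of _generated_add_0
print(add_0(3.0, -5.0))        # → -2.0
print(add_0.observed_max)      # → 2.0 (state a bare `+` could not carry)
```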

Operators that can be Fused

The currently supported combinations of fusable operators are shown in the following function definitions:

```python
import operator

import torch
from torch import nn

from horizon_plugin_pytorch import nn as horizon_nn


def register_fusion_patterns():
    convs = (
        nn.Conv2d,
        nn.ConvTranspose2d,
        nn.Conv3d,
        nn.Linear,
    )
    bns = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    adds = (
        nn.quantized.FloatFunctional.add,
        horizon_nn.quantized.FloatFunctional.add,
        torch.add,
        operator.add,  # the plus sign used in the code
    )
    relus = (nn.ReLU, nn.ReLU6, nn.functional.relu, nn.functional.relu6)

    for conv in convs:
        for bn in bns:
            for add in adds:
                for relu in relus:
                    # conv bn
                    register_fusion_pattern((bn, conv))(ConvBNAddReLUFusion)
                    # conv relu
                    register_fusion_pattern((relu, conv))(ConvBNAddReLUFusion)
                    # conv add
                    # the output of conv is used as the first input to add
                    register_fusion_pattern((add, conv, MatchAllNode))(
                        ConvBNAddReLUFusion
                    )
                    # the output of conv is used as the second input to add
                    register_fusion_pattern((add, MatchAllNode, conv))(
                        ConvBNAddedReLUFusion
                    )
                    # conv bn relu
                    register_fusion_pattern((relu, (bn, conv)))(
                        ConvBNAddReLUFusion
                    )
                    # conv bn add
                    register_fusion_pattern((add, (bn, conv), MatchAllNode))(
                        ConvBNAddReLUFusion
                    )
                    register_fusion_pattern((add, MatchAllNode, (bn, conv)))(
                        ConvBNAddedReLUFusion
                    )
                    # conv add relu
                    register_fusion_pattern((relu, (add, conv, MatchAllNode)))(
                        ConvBNAddReLUFusion
                    )
                    register_fusion_pattern((relu, (add, MatchAllNode, conv)))(
                        ConvBNAddedReLUFusion
                    )
                    # conv bn add relu
                    register_fusion_pattern(
                        (relu, (add, (bn, conv), MatchAllNode))
                    )(ConvBNAddReLUFusion)
                    register_fusion_pattern(
                        (relu, (add, MatchAllNode, (bn, conv)))
                    )(ConvBNAddedReLUFusion)
```
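To make the nested pattern tuples above concrete, here is a toy matcher (pure Python, hypothetical, far simpler than the tool's real FX-based matching): a pattern like (relu, (bn, conv)) means "a relu node whose input is a bn node whose input is a conv node", matched from the output backwards.

```python
MatchAllNode = object()  # wildcard: matches any node


class Node:
    """A node in a tiny toy computational graph."""
    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = tuple(inputs)


def matches(node, pattern):
    """Match a node against a nested pattern tuple, output-first."""
    if pattern is MatchAllNode:
        return True
    if isinstance(pattern, tuple):
        op, *arg_patterns = pattern
        if node.op != op:
            return False
        if len(node.inputs) < len(arg_patterns):
            return False
        return all(matches(n, p) for n, p in zip(node.inputs, arg_patterns))
    return node.op == pattern


# conv -> bn -> relu chain, built as a tiny graph.
conv = Node("conv")
bn = Node("bn", [conv])
relu = Node("relu", [bn])

print(matches(relu, ("relu", ("bn", "conv"))))                 # → True
print(matches(relu, ("relu", ("add", "conv", MatchAllNode))))  # → False
```

Once a pattern matches, the matched subgraph is replaced by the corresponding fused Module (ConvBNAddReLUFusion etc. in the listing above), which is the submodule-substitution step described earlier.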