Operator Fusion
The operator fusion supported by the training tool falls into two main categories: 1. absorbing BN; 2. fusing Add and ReLU(6).
Absorb BN
The purpose of absorbing BN is to reduce the computational workload of the model. Since BN is a linear transformation, its parameters can be absorbed into the parameters of the preceding Conv when the two occur together,
thus eliminating the BN computation from the deployed model.
The absorption folds the BN parameters into the Conv weight and bias, so Conv2d + BN2d can be simplified to a single Conv2d.
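As a concrete illustration, the folding can be sketched in plain PyTorch as follows. Note that `fold_bn_into_conv` is a hypothetical helper written for this example, not the training tool's API; the tool performs this step automatically inside `fuse_fx`.

```python
import torch
from torch import nn


def fold_bn_into_conv(conv, bn):
    """Fold eval-mode BatchNorm2d into the preceding Conv2d:

    w' = w * gamma / sqrt(var + eps)
    b' = (b - mean) * gamma / sqrt(var + eps) + beta
    """
    fused = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=True,
    )
    # Per-output-channel scale derived from the BN statistics.
    scale = bn.weight.data / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    bias = (
        conv.bias.data
        if conv.bias is not None
        else torch.zeros(conv.out_channels)
    )
    fused.bias.data = (bias - bn.running_mean) * scale + bn.bias.data
    return fused


conv = nn.Conv2d(3, 3, 3)
bn = nn.BatchNorm2d(3)
# Pretend the BN has been trained: give it non-trivial statistics.
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 2.0)
bn.bias.data.uniform_(-1, 1)
bn.eval()  # inference mode: BN uses its running statistics

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    assert torch.allclose(bn(conv(x)), fold_bn_into_conv(conv, bn)(x), atol=1e-5)
```

Because BN uses fixed running statistics in inference mode, the fused Conv reproduces the Conv + BN output up to floating-point rounding.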
Fuse Add, ReLU(6)
Different from CUDA kernel fusion, which merges CUDA kernels to increase computation speed, the fusion supported by the training tool operates mainly at the quantization level.
The BPU hardware provides optimizations for common basic model structures: when computing the Conv -> Add -> ReLU operator combination, it allows the data passed between the operators to retain high precision,
improving the overall numerical accuracy of the model. Therefore, when quantizing the model, Conv -> Add -> ReLU can be treated as a whole.
Since the training tool quantizes the model at the granularity of torch.nn.Module, treating Conv -> Add -> ReLU as a whole during quantization requires merging them into a single Module.
Besides allowing intermediate results to remain in a high-precision state, operator fusion also eliminates the need to convert those intermediate results to a low-precision representation, so execution is faster than without fusion.
Since operator fusion improves both model accuracy and model speed, it should generally be applied to all fusible parts of the model.
Implementation Principle
Thanks to FX's ability to obtain the computational graph, the training tool can automatically analyze the model's computational graph, match the fusible parts against predefined fusion patterns, and implement the fusion by submodule substitution.
An example is given below:
The absorption of BN and the fusion of Add and ReLU(6) are accomplished by the same mechanism, so no distinction between them is made during fusion.
import torch
from torch import nn
from torch.quantization import DeQuantStub

from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization import fuse_fx


class ModelForFusion(torch.nn.Module):
    def __init__(self):
        super(ModelForFusion, self).__init__()
        self.quantx = QuantStub()
        self.quanty = QuantStub()
        self.conv = nn.Conv2d(3, 3, 3)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()

    def forward(self, x, y):
        x = self.quantx(x)
        y = self.quanty(y)
        x = self.conv(x)
        x = self.bn(x)
        x = x + y
        x = self.relu(x)
        x = self.dequant(x)
        return x


float_model = ModelForFusion()
fused_model = fuse_fx(float_model)
print(fused_model)
"""
ModelForFusion(
  (quantx): QuantStub()
  (quanty): QuantStub()
  (conv): Identity()
  (bn): Identity()
  (relu): Identity()
  (dequant): DeQuantStub()
  (_generated_add_0): ConvAddReLU2d(
    (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (relu): ReLU()
  )
)

def forward(self, x, y):
    quantx = self.quantx(x);  x = None
    quanty = self.quanty(y);  y = None
    _generated_add_0 = self._generated_add_0
    add_1 = self._generated_add_0(quantx, quanty);  quantx = quanty = None
    dequant = self.dequant(add_1);  add_1 = None
    return dequant
"""
As you can see, after operator fusion is applied to the model, the BN is absorbed into the Conv, and the Conv, Add, and ReLU are fused into a single Module (_generated_add_0).
The original submodules are replaced with Identity and are no longer called in the forward code.
FX automatically replaces the plus sign in x = x + y with a Module named _generated_add_0 so that operator fusion and quantization-related operations can be applied to it.
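This replacement relies on FX tracing. As a minimal sketch in plain PyTorch (independent of the training tool), tracing records the plus sign as a call_function node whose target is operator.add, which a fusion pass can then locate and substitute with a submodule:

```python
import operator

import torch
from torch import nn


class AddReLU(nn.Module):
    def forward(self, x, y):
        # The plus sign below becomes a call_function node in the
        # traced graph, with operator.add as its target.
        return torch.relu(x + y)


traced = torch.fx.symbolic_trace(AddReLU())
add_nodes = [
    n
    for n in traced.graph.nodes
    if n.op == "call_function" and n.target is operator.add
]
assert len(add_nodes) == 1
```

A fusion pass can rewrite such a node to call a generated submodule like _generated_add_0 instead, which is what makes the addition quantizable as part of a fused Module.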
Operators that can be Fused
The currently supported combinations of fusible operators are shown in the following function definition:
import operator

import torch
from torch import nn

from horizon_plugin_pytorch import nn as horizon_nn


def register_fusion_patterns():
    convs = (
        nn.Conv2d,
        nn.ConvTranspose2d,
        nn.Conv3d,
        nn.Linear,
    )
    bns = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    adds = (
        nn.quantized.FloatFunctional.add,
        horizon_nn.quantized.FloatFunctional.add,
        torch.add,
        operator.add,  # the plus sign used in the code
    )
    relus = (nn.ReLU, nn.ReLU6, nn.functional.relu, nn.functional.relu6)

    for conv in convs:
        for bn in bns:
            for add in adds:
                for relu in relus:
                    # conv bn
                    register_fusion_pattern((bn, conv))(ConvBNAddReLUFusion)
                    # conv relu
                    register_fusion_pattern((relu, conv))(ConvBNAddReLUFusion)
                    # conv add
                    register_fusion_pattern((add, conv, MatchAllNode))(
                        ConvBNAddReLUFusion
                    )  # the output of conv is the first input to add
                    register_fusion_pattern((add, MatchAllNode, conv))(
                        ConvBNAddedReLUFusion
                    )  # the output of conv is the second input to add
                    # conv bn relu
                    register_fusion_pattern((relu, (bn, conv)))(
                        ConvBNAddReLUFusion
                    )
                    # conv bn add
                    register_fusion_pattern((add, (bn, conv), MatchAllNode))(
                        ConvBNAddReLUFusion
                    )
                    register_fusion_pattern((add, MatchAllNode, (bn, conv)))(
                        ConvBNAddedReLUFusion
                    )
                    # conv add relu
                    register_fusion_pattern((relu, (add, conv, MatchAllNode)))(
                        ConvBNAddReLUFusion
                    )
                    register_fusion_pattern((relu, (add, MatchAllNode, conv)))(
                        ConvBNAddedReLUFusion
                    )
                    # conv bn add relu
                    register_fusion_pattern(
                        (relu, (add, (bn, conv), MatchAllNode))
                    )(ConvBNAddReLUFusion)
                    register_fusion_pattern(
                        (relu, (add, MatchAllNode, (bn, conv)))
                    )(ConvBNAddedReLUFusion)
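To give a rough idea of how such nested pattern tuples are matched against the graph, below is a simplified, hypothetical matcher written with plain torch.fx: the first element of a tuple describes the current node, the remaining elements describe its inputs, and MatchAllNode is a wildcard. The training tool's actual matcher is more elaborate (it handles node reuse, overlapping matches, and so on), so treat this only as a sketch of the idea.

```python
import operator

import torch
from torch import nn


class MatchAllNode:
    """Wildcard: matches any input node."""


def matches(node, pattern, modules):
    # A wildcard matches anything.
    if pattern is MatchAllNode:
        return True
    # (op, arg_pattern, ...): match the node itself, then its inputs.
    if isinstance(pattern, tuple):
        op, *arg_patterns = pattern
        if not matches(node, op, modules):
            return False
        return len(node.args) >= len(arg_patterns) and all(
            matches(arg, p, modules)
            for arg, p in zip(node.args, arg_patterns)
        )
    # A Module class matches a call_module node of that type.
    if node.op == "call_module" and isinstance(pattern, type):
        return isinstance(modules[node.target], pattern)
    # A function matches a call_function node with that target.
    if node.op == "call_function":
        return node.target is pattern
    return False


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()

    def forward(self, x, y):
        return self.relu(self.bn(self.conv(x)) + y)


traced = torch.fx.symbolic_trace(Net())
modules = dict(traced.named_modules())
# conv bn add relu, with the conv output as the first input to add
pattern = (nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))
assert any(matches(n, pattern, modules) for n in traced.graph.nodes)
```

Once a match is found, the matched nodes can be replaced by a single fused submodule (such as ConvBNAddReLU2d), which is the submodule substitution step described above.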