quantization.fuse_modules
horizon_plugin_pytorch.quantization.fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=<function fuse_known_modules>, fuse_custom_config_dict=None)
Fuses a list of modules into a single module.
Only the following sequences of modules are fused:
conv, bn;
conv, bn, relu;
conv, relu;
conv, bn, add;
conv, bn, add, relu;
conv, add;
conv, add, relu;
linear, bn;
linear, bn, relu;
linear, relu;
linear, bn, add;
linear, bn, add, relu;
linear, add;
linear, add, relu.
For these sequences, the first element in the output module list performs
the fused operation. The remaining elements are set to nn.Identity().
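The contract described above (the first slot holds the fused operation, every later slot becomes an identity) can be sketched in plain Python. This is only a toy model of the behaviour, not the plugin's implementation: Identity, Scale, Relu, Fused, toy_fuser and fuse_by_name are hypothetical stand-ins, and the "model" is a plain dict rather than a torch module.

```python
# Toy stand-ins; these are NOT horizon_plugin_pytorch or torch classes.
class Identity:
    """Returns its input unchanged (plays the role of nn.Identity)."""

    def __call__(self, x):
        return x


class Scale:
    """Hypothetical conv-like op: multiplies its input by a constant."""

    def __init__(self, k):
        self.k = k

    def __call__(self, x):
        return self.k * x


class Relu:
    def __call__(self, x):
        return max(0.0, x)


class Fused:
    """Runs several ops back to back as a single module."""

    def __init__(self, ops):
        self.ops = list(ops)

    def __call__(self, x):
        for op in self.ops:
            x = op(x)
        return x


def toy_fuser(modules):
    """Return a list of the same length: fused op first, identities after."""
    return [Fused(modules)] + [Identity() for _ in modules[1:]]


def fuse_by_name(model, names, fuser=toy_fuser):
    """Swap the named entries of a dict-based toy model for the fuser output."""
    fused = fuser([model[name] for name in names])
    out = dict(model)  # shallow copy, loosely mirroring inplace=False
    for name, module in zip(names, fused):
        out[name] = module
    return out


model = {"conv1": Scale(2.0), "relu1": Relu()}
fused_model = fuse_by_name(model, ["conv1", "relu1"])
```

Calling the entries of fused_model in their original order gives the same result as the unfused model, because the trailing slots now pass their input through untouched. A function passed as fuser_func is expected to have the same shape as toy_fuser: a list of modules in, a list of equal length out.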
- Parameters:
- model – Model containing the modules to be fused
- modules_to_fuse – list of lists of module names to fuse. Can also be a
list of strings if there is only a single list of
modules to fuse.
- inplace – bool specifying whether fusion happens in place on the model.
By default, a new model is returned.
- fuser_func – Function that takes in a list of modules and outputs a list
of fused modules of the same length. For example,
fuser_func([convModule, BNModule]) returns the list
[ConvBNModule, nn.Identity()]. Defaults to
torch.ao.quantization.fuse_known_modules.
- fuse_custom_config_dict – custom configuration for fusion
    # Example of fuse_custom_config_dict
    fuse_custom_config_dict = {
        # Additional fuser_method mapping
        "additional_fuser_method_mapping": {
            (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
        },
    }
- Returns:
Model with fused modules. A new copy is created if inplace=False.
Examples:
>>> # xdoctest: +SKIP
>>> from horizon_plugin_pytorch.quantization import fuse_modules
>>> m = M().eval()
>>> # m is a module containing the sub-modules below
>>> modules_to_fuse = [['conv1', 'bn1', 'relu1'],
...                    ['submodule.conv', 'submodule.relu']]
>>> fused_m = fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)

>>> m = M().eval()
>>> # Alternatively, provide a single list of modules to fuse
>>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
>>> fused_m = fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)
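The two call styles above differ only in the shape of modules_to_fuse. One way to picture how a flat list of names relates to the nested form is the small helper below. normalize_modules_to_fuse is a hypothetical name used for illustration, and whether the plugin normalizes its input this way internally is an assumption.

```python
def normalize_modules_to_fuse(modules_to_fuse):
    """Hypothetical helper: always return a list of lists of module names."""
    # A non-empty flat list of strings is treated as one fusion group.
    if modules_to_fuse and all(isinstance(item, str) for item in modules_to_fuse):
        return [list(modules_to_fuse)]
    # Otherwise each element is assumed to already be a group of names.
    return [list(group) for group in modules_to_fuse]


single = normalize_modules_to_fuse(["conv1", "bn1", "relu1"])
nested = normalize_modules_to_fuse(
    [["conv1", "bn1", "relu1"], ["submodule.conv", "submodule.relu"]]
)
```

After normalization both inputs have the same shape, so downstream code only has to handle the list-of-lists case.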