FX Quantization Principle

Before reading this section, it is recommended to read the torch.fx section of the PyTorch documentation to gain an initial understanding of torch's FX mechanism.

FX uses a symbolic execution approach that builds a graph representation of the model at the nn.Module or function level, enabling automated operator fusion and other graph-based optimizations.
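The principle can be illustrated with a minimal stdlib-only sketch (the names here are illustrative, not the torch.fx API): a proxy object records every operation applied to it into a graph instead of computing values, so running the function once yields its graph.

```python
# Minimal sketch of symbolic tracing: a Proxy records operations into a
# graph instead of computing values. Illustrative only, not torch.fx.

class Node:
    def __init__(self, op, args):
        self.op = op        # e.g. "placeholder", "mul", "add", "output"
        self.args = args


class Proxy:
    def __init__(self, node, graph):
        self.node = node
        self.graph = graph

    def _record(self, op, *args):
        node = Node(op, args)
        self.graph.append(node)
        return Proxy(node, self.graph)

    def __mul__(self, other):
        return self._record("mul", self.node, other)

    def __add__(self, other):
        return self._record("add", self.node, other)


def symbolic_trace(fn):
    """Run fn once with a Proxy input; the recorded graph is the program."""
    graph = []
    inp = Node("placeholder", ())
    graph.append(inp)
    out = fn(Proxy(inp, graph))
    graph.append(Node("output", (out.node,)))
    return graph


def f(x):
    return x * 2 + 1


graph = symbolic_trace(f)
print([n.op for n in graph])  # ['placeholder', 'mul', 'add', 'output']
```

Because the whole program is available as a graph, passes such as operator fusion can rewrite it mechanically, which is what the interfaces below rely on.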

Quantization Process

Fuse (Optional)

FX is aware of the computational graph, so operator fusion can be automated: you no longer need to manually specify which operators to fuse, just call the interface directly.

fused_model = horizon.quantization.fuse_fx(model)
  • Note that fuse_fx has no inplace parameter: it must symbolically trace the model internally to generate a GraphModule, so it cannot modify the model in place.
  • The fused_model and model share almost all attributes (including submodules, operators, etc.), so do not make any changes to the model after fusion, as this may affect the fused_model.
  • It is not necessary to explicitly call the fuse_fx interface, as the subsequent prepare interface integrates the fuse procedure internally.
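The shared-attribute caveat can be demonstrated with a small stand-in (hypothetical classes, not the horizon API): the "fused" container reuses the original submodule objects by reference rather than copying them, so mutating the original leaks into the fused model.

```python
from types import SimpleNamespace


# Hypothetical stand-ins for a model and its submodule (not the horizon API).
class Conv:
    def __init__(self, weight):
        self.weight = weight


class Model:
    def __init__(self):
        self.conv = Conv(weight=1.0)


def fake_fuse(model):
    # The new container holds the *same* submodule objects, not copies --
    # this mirrors how the fused GraphModule shares attributes with model.
    return SimpleNamespace(conv=model.conv)


model = Model()
fused_model = fake_fuse(model)

model.conv.weight = 999.0       # modifying the original after "fuse"...
print(fused_model.conv.weight)  # 999.0 -- the fused model is affected too
```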

Prepare

Before calling the prepare interface, the global march must be set according to the target hardware platform. Internally, prepare performs a fuse procedure (even if the model has already been fused) and then replaces the eligible operators in the model with implementations from horizon.nn.qat.

  • You can choose the appropriate qconfig (Calibration or QAT) as needed; note that the two qconfigs cannot be mixed.
  • Similar to fuse_fx, this interface does not support an inplace parameter; do not make any changes to the input model after prepare.
horizon.march.set_march(horizon.march.March.NASH)

qat_model = horizon.quantization.prepare(
    model,
    qconfig_setter=horizon.quantization.qconfig_template.default_qat_qconfig_setter,
    method=horizon.quantization.PrepareMethod.SYMBOLIC,
)
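Conceptually, the replacement step of prepare can be sketched as a mapping-driven module swap that returns a new container and leaves the input model untouched. All classes and the mapping below are hypothetical stand-ins, not the horizon.nn.qat implementation.

```python
import copy


# Hypothetical float and QAT module stand-ins (not horizon.nn / horizon.nn.qat).
class Conv2d:
    pass


class QATConv2d:
    def __init__(self, float_mod):
        self.float_mod = float_mod  # the QAT module wraps the float weights


class ReLU6:
    pass  # assume no QAT counterpart: left unchanged


QAT_MAPPING = {Conv2d: QATConv2d}


class Model:
    def __init__(self):
        self.conv = Conv2d()
        self.act = ReLU6()


def fake_prepare(model):
    """Return a new container with eligible modules swapped for QAT versions."""
    prepared = copy.copy(model)  # a new container, not an in-place change
    for name, mod in vars(model).items():
        qat_cls = QAT_MAPPING.get(type(mod))
        if qat_cls is not None:
            setattr(prepared, name, qat_cls(mod))
    return prepared


model = Model()
qat_model = fake_prepare(model)
print(type(qat_model.conv).__name__)  # QATConv2d
print(type(model.conv).__name__)      # Conv2d (input model untouched)
```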

Eager Mode Compatibility

In most cases, FX quantization interfaces can directly replace the eager mode interfaces (e.g., prepare_qat -> prepare), but the two sets of interfaces cannot be mixed. In the following cases, the model code requires some structural modification.

  • Operations not supported by FX: torch's symbolic trace supports a limited set of operations. For example, it does not support non-static variables as branch conditions, it does not support packages other than torch (e.g., numpy) by default, and unexecuted conditional branches are discarded.
  • Operations that should not be processed by FX: if torch ops are used in the model's pre- and post-processing, FX will treat them as part of the model when tracing, producing undesired behavior (e.g., replacing some torch function calls with FloatFunctional).
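The branch-discarding behavior can be sketched without torch (an illustrative recorder, not the torch.fx implementation): tracing executes the forward function exactly once, so a Python-level condition such as self.training is evaluated at trace time and only the taken branch ends up in the graph.

```python
def trace(forward, training):
    """Run forward once, recording each 'op'; the record is the captured graph."""
    graph = []
    forward(graph.append, training)
    return graph


def forward(record, training):
    record("backbone")
    # A plain Python condition is evaluated once, at trace time:
    if training:
        record("loss")          # only recorded when tracing in train mode
    else:
        record("post_process")  # only recorded when tracing in eval mode


print(trace(forward, training=True))   # ['backbone', 'loss']
print(trace(forward, training=False))  # ['backbone', 'post_process']
```

A graph traced in train mode therefore contains no post-processing at all, which is why such conditions must be kept out of the traced region.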

Both of these cases can be avoided by using wrap, as illustrated by RetinaNet.

from horizon_plugin_pytorch.fx.fx_helper import wrap as fx_wrap


class RetinaNet(nn.Module):
    def __init__(
        self,
        backbone: nn.Module,
        neck: Optional[nn.Module] = None,
        head: Optional[nn.Module] = None,
        anchors: Optional[nn.Module] = None,
        targets: Optional[nn.Module] = None,
        post_process: Optional[nn.Module] = None,
        loss_cls: Optional[nn.Module] = None,
        loss_reg: Optional[nn.Module] = None,
    ):
        super(RetinaNet, self).__init__()
        self.backbone = backbone
        self.neck = neck
        self.head = head
        self.anchors = anchors
        self.targets = targets
        self.post_process = post_process
        self.loss_cls = loss_cls
        self.loss_reg = loss_reg

    def rearrange_head_out(self, inputs: List[torch.Tensor], num: int):
        outputs = []
        for t in inputs:
            outputs.append(t.permute(0, 2, 3, 1).reshape(t.shape[0], -1, num))
        return torch.cat(outputs, dim=1)

    def forward(self, data: Dict):
        feat = self.backbone(data["img"])
        feat = self.neck(feat) if self.neck else feat
        cls_scores, bbox_preds = self.head(feat)
        if self.post_process is None:
            return cls_scores, bbox_preds
        # By encapsulating operations that don't need to build a graph into a
        # method, FX no longer cares about the logic inside the method and
        # leaves it as is (modules called in the method can still be given a
        # qconfig and replaced by prepare_qat_fx and convert_fx).
        return self._post_process(data, feat, cls_scores, bbox_preds)

    @fx_wrap()  # fx_wrap supports directly decorating a class method
    def _post_process(self, data, feat, cls_scores, bbox_preds):
        anchors = self.anchors(feat)
        # The self.training check must be encapsulated, otherwise the untaken
        # branch will be thrown away after the symbolic trace.
        if self.training:
            cls_scores = self.rearrange_head_out(
                cls_scores, self.head.num_classes
            )
            bbox_preds = self.rearrange_head_out(bbox_preds, 4)
            gt_labels = [
                torch.cat(
                    [data["gt_bboxes"][i], data["gt_classes"][i][:, None] + 1],
                    dim=-1,
                )
                for i in range(len(data["gt_classes"]))
            ]
            gt_labels = [gt_label.float() for gt_label in gt_labels]
            _, labels = self.targets(anchors, gt_labels)
            avg_factor = labels["reg_label_mask"].sum()
            if avg_factor == 0:
                avg_factor += 1
            cls_loss = self.loss_cls(
                pred=cls_scores.sigmoid(),
                target=labels["cls_label"],
                weight=labels["cls_label_mask"],
                avg_factor=avg_factor,
            )
            reg_loss = self.loss_reg(
                pred=bbox_preds,
                target=labels["reg_label"],
                weight=labels["reg_label_mask"],
                avg_factor=avg_factor,
            )
            return {
                "cls_loss": cls_loss,
                "reg_loss": reg_loss,
            }
        else:
            preds = self.post_process(
                anchors,
                cls_scores,
                bbox_preds,
                [torch.tensor(shape) for shape in data["resized_shape"]],
            )
            assert (
                "pred_bboxes" not in data.keys()
            ), "pred_bboxes has been in data.keys()"
            data["pred_bboxes"] = preds
            return data