PwcNet Optical Flow Prediction Model Training

This tutorial shows how to train a PwcNet model from scratch on the FlyingChairs optical flow dataset using HAT, covering the floating-point, quantized, and fixed-point models.

FlyingChairs is the most widely used dataset in optical flow prediction, and many state-of-the-art optical flow studies validate primarily on it.

Before starting model training, the first step is to prepare the dataset. Here we use the official FlyingChairs.zip archive as the training and validation sets, together with the corresponding label file FlyingChairs_train_val.txt.

The extracted directory structure is as follows:

tmp_data
|-- FlyingChairs
|   |-- FlyingChairs_release
|   |   |-- data
|   |   |-- README.txt
|   |-- FlyingChairs_train_val.txt
|   |-- FlyingChairs.zip

Training Process

If you just want to train a PwcNet model, you can read this section on its own. As with other tasks, HAT uses the tools + config pattern for all training and evaluation tasks. After preparing the raw dataset, the training process can be completed with the following steps.

Dataset Preparation

To improve the training speed, we packed the original dataset and converted it to the LMDB format. The conversion can be done by running the following script:

python3 tools/datasets/flyingchairs_packer.py --src-data-dir ${data-dir} --split-name train --pack-type lmdb --num-workers 10
python3 tools/datasets/flyingchairs_packer.py --src-data-dir ${data-dir} --split-name val --pack-type lmdb --num-workers 10

The two commands above pack the training set and the validation set, respectively. After packing, the directory structure should look as follows:

tmp_data
|-- FlyingChairs
|   |-- FlyingChairs_release
|   |   |-- data
|   |   |-- README.txt
|   |-- FlyingChairs_train_val.txt
|   |-- FlyingChairs.zip
|   |-- train_lmdb
|   |-- val_lmdb

train_lmdb and val_lmdb are the packed training and validation datasets. You can now start training the model.

Model Training

Before training starts, you can check the network's operation count and parameter count with the following command:

python3 tools/calops.py --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py
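The exact counting rules of calops.py are internal to HAT, but as a rough sketch, the multiply-accumulate (MAC) count of a single convolution layer can be estimated as below. The function name and the example layer shape are illustrative, not taken from the HAT code.

```python
def conv2d_macs(h, w, c_in, c_out, k, stride=1):
    """Multiply-accumulate count for one Conv2d layer (bias ignored)."""
    h_out, w_out = h // stride, w // stride
    return k * k * c_in * c_out * h_out * w_out

# Illustrative: a 3x3 conv from 3 to 16 channels on a 256x448 input, stride 2,
# roughly the shape of a first pyramid level in a PWC-Net-style extractor.
macs = conv2d_macs(256, 448, 3, 16, 3, stride=2)
print(macs)  # 12386304 MACs for this single layer
```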

The next step is to start training, which is done with the following commands. Before training, make sure the dataset path in the config has been changed to the packed dataset path.

python3 tools/train.py --stage "float" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py
python3 tools/train.py --stage "calibration" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py
python3 tools/train.py --stage "qat" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

Since the HAT algorithm toolkit uses a registry mechanism, every training task can be launched as train.py plus a config file.

train.py is a unified, task-independent training script. The task to train, the dataset to use, and the training hyperparameters are all specified in the given config file.

The value after --stage in the commands above can be "float", "calibration", or "qat", which run floating-point training, quantization calibration, and quantization-aware training, respectively; the quantization stages depend on the floating-point model produced by the preceding float training.
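The three stages above can be chained in a small driver script. This is a hypothetical convenience wrapper, not part of HAT; stage_cmd and the stage order simply mirror the commands shown earlier.

```python
import subprocess  # needed if you uncomment the launch line below

CONFIG = "configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py"
STAGES = ["float", "calibration", "qat"]  # each stage builds on the previous one

def stage_cmd(stage):
    """Build the train.py command line for one training stage."""
    return ["python3", "tools/train.py", "--stage", stage, "--config", CONFIG]

for stage in STAGES:
    print(" ".join(stage_cmd(stage)))
    # subprocess.run(stage_cmd(stage), check=True)  # uncomment to launch
```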

Exporting the Fixed-Point Model

Once quantization training is complete, you can export the fixed-point model with the following command:

python3 tools/export_hbir.py --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

Model Validation

After training, we have trained floating-point, quantized, or fixed-point models. Validation works the same way as training: we can evaluate a trained model and obtain the Float, Calibration, QAT, and Quantized metrics, i.e., floating-point, calibration, quantization-aware-trained, and fully fixed-point accuracy, respectively.

python3 tools/predict.py --stage "float" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py
python3 tools/predict.py --stage "calibration" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py
python3 tools/predict.py --stage "qat" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

As with training, setting --stage to "float", "calibration", or "qat" validates the trained floating-point model, the calibrated model, or the QAT model, respectively.

The following command verifies the accuracy of the fixed-point model; note that the hbir model must be exported first:

python3 tools/predict.py --stage "int_infer" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

Model Inference

HAT provides the infer_hbir.py script to visualize the inference results of the fixed-point models:

python3 tools/infer_hbir.py --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py --model-inputs img1:${img1-path},img2:${img2-path} --save-path ${save_path}
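Optical flow results are commonly visualized by mapping flow direction to hue and flow magnitude to brightness. The helper below is a generic sketch of that idea; flow_to_hsv is a hypothetical name, and the actual rendering inside infer_hbir.py may use a different color coding.

```python
import numpy as np

def flow_to_hsv(flow, eps=1e-6):
    """Map a (H, W, 2) flow field to HSV: hue = direction, value = magnitude."""
    u, v = flow[..., 0], flow[..., 1]
    mag = np.sqrt(u ** 2 + v ** 2)
    ang = np.arctan2(v, u)                      # direction in [-pi, pi]
    hsv = np.zeros(flow.shape[:2] + (3,), dtype=np.float32)
    hsv[..., 0] = (ang + np.pi) / (2 * np.pi)   # hue in [0, 1]
    hsv[..., 1] = 1.0                           # full saturation
    hsv[..., 2] = mag / (mag.max() + eps)       # value normalized to [0, 1]
    return hsv

flow = np.dstack([np.ones((4, 4)), np.zeros((4, 4))])  # uniform rightward motion
hsv = flow_to_hsv(flow)
```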

Simulated On-Board Accuracy Validation

In addition to the validation above, we provide an accuracy validation method that exactly matches the on-board environment, which can be run as follows:

python3 tools/validation_hbir.py --stage "align_bpu" --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

Fixed-point Model Check and Compilation

Because the quantization training toolchain integrated in HAT targets Horizon processors, the quantized model must be checked and compiled.

We provide a model-checking interface in the training scripts: you can first define a quantized model and then check whether it runs properly on the BPU:

python3 tools/model_checker.py --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

After the model is trained, use the compile_perf_hbir script to compile the quantized model into an HBM file that can run on-board. This tool can also estimate the model's performance on the BPU.

python3 tools/compile_perf_hbir.py --config configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py

The above is the whole process from data preparation to a quantized, deployable model.

Training Details

This section covers points to be aware of during model training, mainly config-related settings.

Model Building

The network structure of PwcNet can be found in the paper and the community TensorFlow implementation, so we skip the details here.

We can easily define and modify the model by defining a dict type variable like model in the config file.

from torch import nn

loss_weights = [0.005, 0.01, 0.02, 0.08, 0.32]
out_channels = [16, 32, 64, 96, 128, 196]
flow_pred_lvl = 2
pyr_lvls = 6
use_bn = True
bn_kwargs = {}
use_res = True
use_dense = True

model = dict(
    type="PwcnetTask",
    backbone=dict(
        type="PwcNet",
        out_channels=out_channels,
        use_bn=use_bn,
        bn_kwargs=bn_kwargs,
        pyr_lvls=pyr_lvls,
        flow_pred_lvl=flow_pred_lvl,
        act_type=nn.ReLU(),
    ),
    head=dict(
        type="PwcnetHead",
        in_channels=out_channels,
        bn_kwargs=bn_kwargs,
        use_bn=use_bn,
        md=4,
        use_res=use_res,
        use_dense=use_dense,
        pyr_lvls=pyr_lvls,
        flow_pred_lvl=flow_pred_lvl,
        act_type=nn.ReLU(),
    ),
    loss=dict(type="LnNormLoss", norm_order=2, power=1, reduction="mean"),
    loss_weights=loss_weights,
)

In addition to the backbone, the model has head and loss modules. In PwcNet, the backbone extracts features from the two input images, the head predicts the optical flow map from those features, and the loss uses LnNormLoss from the paper as the training loss. loss_weights gives the per-level loss weights.
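To make the loss concrete, here is a numpy sketch of what an Ln-norm flow loss with norm_order=2 and power=1 computes, weighted over pyramid levels as loss_weights suggests. This is an illustration of the math, not the HAT implementation of LnNormLoss, which may differ in details.

```python
import numpy as np

loss_weights = [0.005, 0.01, 0.02, 0.08, 0.32]  # from the config above

def ln_norm_loss(pred, target, norm_order=2, power=1):
    """Per-pixel Ln norm of the flow error, reduced by the mean."""
    diff = np.abs(pred - target)
    norm = np.sum(diff ** norm_order, axis=-1) ** (1.0 / norm_order)
    return float(np.mean(norm ** power))

def multi_scale_loss(preds, targets, weights=loss_weights):
    """Weighted sum of per-level losses over the flow pyramid."""
    return sum(w * ln_norm_loss(p, t) for w, p, t in zip(weights, preds, targets))

# A constant (3, 4) error vector has L2 norm 5 at every pixel.
pred = np.zeros((2, 2, 2)); pred[..., 0] = 3.0; pred[..., 1] = 4.0
gt = np.zeros((2, 2, 2))
print(ln_norm_loss(pred, gt))  # 5.0
```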

Data Augmentation

Like the model definition, data augmentation is configured by defining two dicts, data_loader and val_data_loader, in the config file, for the training and validation sets respectively.

Taking data_loader as an example, the augmentations used are RandomCrop, RandomFlip, SegRandomAffine, and FlowRandomAffineScale.

data_loader = dict(
    type=torch.utils.data.DataLoader,
    dataset=dict(
        type="FlyingChairs",
        data_path="./tmp_data/FlyingChairs/train_lmdb/",
        transforms=[
            dict(
                type="RandomCrop",
                size=(256, 448),
            ),
            dict(
                type="RandomFlip",
                px=0.5,
                py=0.5,
            ),
            dict(
                type="ToTensor",
                to_yuv=False,
            ),
            dict(
                type="SegRandomAffine",
                degrees=0,
                translate=(0.05, 0.05),
                scale=(1.0, 1.0),
                interpolation=InterpolationMode.BILINEAR,
                label_fill_value=0,
                translate_p=0.5,
                scale_p=0.0,
            ),
            dict(
                type="FlowRandomAffineScale",
                scale_p=0.5,
                scale_r=0.05,
            ),
        ],
        to_rgb=True,
    ),
    sampler=dict(type=torch.utils.data.DistributedSampler),
    batch_size=batch_size_per_gpu,
    pin_memory=True,
    shuffle=True,
    num_workers=4,
    collate_fn=hat.data.collates.collate_2d,
)

Since the final model running on the BPU takes YUV444 images as input, while training images are generally in RGB format, HAT provides the BgrToYuv444 transform to convert RGB to YUV444.
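For reference, RGB-to-YUV444 conversion is typically a per-pixel linear transform. The sketch below uses one common set of BT.601 full-range coefficients; the exact constants inside HAT's BgrToYuv444 may differ, and rgb_to_yuv444 is a name chosen here for illustration.

```python
import numpy as np

def rgb_to_yuv444(img):
    """Full-range BT.601 RGB -> YUV444 on a (H, W, 3) uint8 image."""
    img = img.astype(np.float32)
    r, g, b = img[..., 0], img[..., 1], img[..., 2]
    y = 0.299 * r + 0.587 * g + 0.114 * b
    u = -0.169 * r - 0.331 * g + 0.5 * b + 128.0    # chroma centered at 128
    v = 0.5 * r - 0.419 * g - 0.081 * b + 128.0
    yuv = np.stack([y, u, v], axis=-1)
    return np.clip(np.round(yuv), 0, 255).astype(np.uint8)

gray = np.full((2, 2, 3), 128, dtype=np.uint8)
print(rgb_to_yuv444(gray)[0, 0])  # [128 128 128]: a gray pixel has no chroma
```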

To speed up training, some transforms can be applied batch-wise in batch_processor.

def loss_collector(outputs: dict):
    return outputs["losses"]

batch_processor = dict(
    type="MultiBatchProcessor",
    need_grad_update=True,
    batch_transforms=[
        dict(type="BgrToYuv444", rgb_input=True),
        dict(
            type="TorchVisionAdapter",
            interface="Normalize",
            mean=128.0,
            std=128.0,
        ),
        dict(
            type="Scale",
            scales=tuple(1 / np.array(train_scales)),
            mode="bilinear",
        ),
    ],
    loss_collector=loss_collector,
)

Here, loss_collector is the function that extracts the loss from the current batch's outputs.
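Note that the Normalize transform above, with mean=128.0 and std=128.0, maps 8-bit YUV values into roughly the [-1, 1] range:

```python
import numpy as np

yuv = np.array([0, 128, 255], dtype=np.float32)  # darkest, mid, brightest values
normalized = (yuv - 128.0) / 128.0               # what Normalize(mean=128, std=128) computes
print(normalized)  # [-1.0, 0.0, ~0.992]
```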

The data conversion for the validation set is relatively simpler, as follows:

val_data_loader = dict(
    type=torch.utils.data.DataLoader,
    dataset=dict(
        type="FlyingChairs",
        data_path="./tmp_data/FlyingChairs/val_lmdb/",
        transforms=[
            dict(
                type="ToTensor",
                to_yuv=False,
            ),
        ],
        to_rgb=True,
    ),
    batch_size=batch_size_per_gpu,
    shuffle=False,
    num_workers=data_num_workers,
    pin_memory=True,
    collate_fn=hat.data.collates.collate_2d,
)

val_batch_processor = dict(
    type="MultiBatchProcessor",
    need_grad_update=False,
    batch_transforms=[
        dict(type="BgrToYuv444", rgb_input=True),
        dict(
            type="TorchVisionAdapter",
            interface="Normalize",
            mean=128.0,
            std=128.0,
        ),
    ],
    loss_collector=None,
)

Training Strategy

The floating-point model is trained on FlyingChairs using a cosine learning-rate schedule with warmup, applying L2 regularization (weight decay) to the parameters.
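A cosine schedule with warmup can be sketched as below. This mirrors the behaviour described by the CosLrUpdater entry in the config (warmup_by="epoch", warmup_len=10), but it is a standalone illustration, not HAT's implementation.

```python
import math

def cosine_lr_with_warmup(epoch, max_epoch, base_lr, warmup_len=10):
    """Linear warmup for warmup_len epochs, then cosine decay toward 0."""
    if epoch < warmup_len:
        return base_lr * (epoch + 1) / warmup_len
    progress = (epoch - warmup_len) / max(1, max_epoch - warmup_len)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# The rate ramps up over the first 10 epochs, then decays smoothly.
for epoch in (0, 9, 10, 100):
    print(epoch, cosine_lr_with_warmup(epoch, 100, 1e-3))
```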

The float_trainer, calibration_trainer, qat_trainer, and int_trainer in configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py define the training strategies for the floating-point, calibration, quantized, and fixed-point models, respectively.

Take the training strategy of float_trainer as an example:

float_trainer = dict(
    type="distributed_data_parallel_trainer",
    model=model,
    data_loader=data_loader,
    optimizer=dict(
        type=torch.optim.Adam,
        params={"weight": dict(weight_decay=4e-4)},
        lr=lr,
    ),
    batch_processor=batch_processor,
    stop_by="epoch",
    num_epochs=max_epoch,
    device=None,
    callbacks=[
        stat_callback,
        loss_metirc_show_update,
        dict(
            type="CosLrUpdater",
            warmup_by="epoch",
            warmup_len=10,
            step_log_interval=1000,
        ),
        val_callback,
        ckpt_callback,
    ],
    train_metrics=[
        dict(type="LossShow"),
        dict(type="EndPointError"),
    ],
    val_metrics=[
        dict(type="EndPointError"),
    ],
    sync_bn=True,
)
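The EndPointError metric used above is the standard optical flow metric: the mean Euclidean distance between predicted and ground-truth flow vectors. A minimal numpy sketch (not HAT's implementation):

```python
import numpy as np

def end_point_error(pred, gt):
    """Average Euclidean distance between predicted and ground-truth flow."""
    return float(np.mean(np.linalg.norm(pred - gt, axis=-1)))

# A constant (3, 4) error vector gives an EPE of exactly 5 pixels.
gt = np.zeros((2, 2, 2))
pred = np.zeros((2, 2, 2)); pred[..., 0] = 3.0; pred[..., 1] = 4.0
print(end_point_error(pred, gt))  # 5.0
```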

Quantization Training

For the key steps of quantization training, e.g., preparing the floating-point model, operator substitution, inserting quantization and dequantization nodes, setting quantization parameters, and operator fusion, please read the Quantization-Aware Training (QAT) section.

Once the model is ready and the relevant modules are quantizable, HAT uses the following code in the training script to map the floating-point model to the quantized model.

model.fuse_model()
model.set_qconfig()
horizon.quantization.prepare_qat(model, inplace=True)

The overall strategy of quantization training can largely follow that of floating-point training, but the learning rate and training length need to be adjusted accordingly.

Because a floating-point pre-trained model is available, the learning rate for quantization training can be fairly small, usually starting from 0.001 or 0.0001, with one or two 0.1x reductions applied via StepLrUpdater; likewise, the training does not need to be long.
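The step schedule just described can be sketched as below; the function and the milestone epochs are illustrative, not taken from the HAT StepLrUpdater implementation.

```python
def step_lr(base_lr, epoch, milestones, scale=0.1):
    """Multiply the learning rate by `scale` at each milestone epoch."""
    lr = base_lr
    for m in milestones:
        if epoch >= m:
            lr *= scale
    return lr

# e.g. QAT starting from 1e-4 with two 0.1x drops at (hypothetical) epochs 10 and 20
for epoch in (0, 10, 20):
    print(epoch, step_lr(1e-4, epoch, milestones=[10, 20]))
```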

In addition, weight decay will also have some influence on the training results.

The quantization training strategy for the PwcNet sample model can be found in the configs/opticalflow_pred/pwcnet/pwcnet_pwcnetneck_flyingchairs.py file.