QCNet轨迹预测模型训练

这篇教程主要是告诉大家如何在数据集 Argoverse 2 上从头开始训练一个 QCNet 模型,包括浮点、量化和定点模型。

训练流程

如果你只是想简单的把 QCNet 的模型训练起来,那么可以首先阅读一下这一章的内容。 和其他任务一样,对于所有的训练,评测任务,HAT统一采用 tools + config 的形式来完成。在准备好原始数据集之后,可以通过下面的流程,方便地完成整个训练的流程。

数据集准备

在开始训练模型之前,第一步是需要准备好数据集,可以在 Argoverse 2 数据集 下载。 需要下载:Training , ValidationTest

下载后,解压并按照如下方式组织文件夹结构:

data |-- argoverse-2 |-- train |-- val |-- test

为了提升训练的速度,我们对数据信息文件做了一个打包,将其转换成lmdb格式的数据集。只需要运行下面的脚本,就可以成功实现转换:

python3 tools/datasets/argoverse2_packer.py --src-data-dir ${data-dir} --split-name train --pack-type lmdb --num-workers 10 --target-data-dir ${target-data-dir} python3 tools/datasets/argoverse2_packer.py --src-data-dir ${data-dir} --split-name val --pack-type lmdb --num-workers 10 --target-data-dir ${target-data-dir}

上面这两条命令分别对应转换训练数据集和验证数据集,打包完成之后,${target-data-dir} 目录下的文件结构应该如下所示:

${target-data-dir} |-- train_lmdb |-- val_lmdb

train_lmdbval_lmdb 就是打包之后的训练数据集和验证数据集,接下来就可以开始训练模型。

模型训练

下一步就可以开始训练。训练也可以通过下面的脚本来完成,在训练之前需要确认配置中数据集路径是否已经切换到已经打包好的数据集路径。

python3 tools/train.py --stage "float" --config configs/traj_pred/qcnet_argoverse2.py python3 tools/train.py --stage "calibration" --config configs/traj_pred/qcnet_argoverse2.py python3 tools/train.py --stage "qat" --config configs/traj_pred/qcnet_argoverse2.py

以上命令分别完成浮点模型和定点模型的训练,其中定点模型的训练需要以训练好的浮点模型为基础,具体内容请阅读 量化感知训练 章节的内容。

导出定点模型

完成量化训练后,便可以开始导出定点模型。可以通过下面命令来导出:

python3 tools/export_hbir.py --config configs/traj_pred/qcnet_argoverse2.py

模型验证

在完成训练之后,可以得到训练完成的浮点和量化模型。和训练方法类似,我们可以用相同方法来对训好的模型做指标验证,得到为 FloatCalibrationQat 的指标, 前者可以得到浮点模型的指标、后两者分别为量化校准和量化训练得到模型的指标。

python3 tools/predict.py --stage "float" --config configs/traj_pred/qcnet_argoverse2.py python3 tools/predict.py --stage "calibration" --config configs/traj_pred/qcnet_argoverse2.py python3 tools/predict.py --stage "qat" --config configs/traj_pred/qcnet_argoverse2.py

定点模型精度验证可使用下面命令,但需要注意是必须要先导出hbir:

python3 tools/predict.py --stage "int_infer" --config configs/traj_pred/qcnet_argoverse2.py

模型推理

HAT 提供了 infer_hbir.py 脚本提供了对定点模型的推理结果进行可视化展示:

python3 tools/infer_hbir.py --config configs/traj_pred/qcnet_argoverse2.py --model-inputs ${model_input} --save-path ${save_path}

仿真上板精度验证

除了上述模型验证之外,我们还提供和上板完全一致的精度验证方法,可以通过下面的方式完成:

python3 tools/validation_hbir.py --stage "align_bpu" --config configs/traj_pred/qcnet_argoverse2.py

定点模型检查和编译

在HAT中集成的量化训练工具链主要是为了地平线的计算平台准备的,因此,对于量化模型的检查和编译是必须的。 我们在HAT中提供了模型检查的接口,可以让用户定义好量化模型之后,先检查能否在 BPU 上正常运行:

python3 tools/model_checker.py --config configs/traj_pred/qcnet_argoverse2.py

在模型训练完成后,可以通过 compile_perf_hbir 脚本将量化模型编译成可以上板运行的 hbm 文件,同时该工具也能预估在 BPU 上的运行性能:

python3 tools/compile_perf_hbir.py --config configs/traj_pred/qcnet_argoverse2.py

以上就是从数据准备到生成量化可部署模型的全过程。

训练细节

在这个说明中,我们对模型训练需要注意的一些事项进行说明,主要为 config 的一些相关设置。

模型构建

QCNet 的网络结构可以参考 论文 ,这里不做详细介绍。 我们通过在 config 配置文件中定义 model 这样的一个 dict 型变量,就可以方便的实现对模型的定义和修改。

model = dict( type="QCNetOE", num_historical_steps=num_historical_steps, input_dim=input_dim, deploy=False, encoder=dict( type="QCNetOEEncoder", map_encoder=dict( type="QCNetOEMapEncoder", input_dim=input_dim, hidden_dim=hidden_dim, num_historical_steps=num_historical_steps, num_freq_bands=num_freq_bands, num_layers=num_map_layers, num_heads=num_heads, head_dim=head_dim, dropout=dropout, deploy=False, ), agent_encoder=dict( type="QCNetOEAgentEncoderStream", input_dim=input_dim, hidden_dim=hidden_dim, num_historical_steps=num_historical_steps, time_span=time_span, num_freq_bands=num_freq_bands, num_layers=num_agent_layers, num_heads=num_heads, head_dim=head_dim, num_pl2a=32, num_a2a=36, dropout=dropout, save_memory=True, stream_infer=False, deploy=False, ), ), decoder=dict( type="QCNetOEDecoder", input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, num_historical_steps=num_historical_steps, num_future_steps=num_future_steps, num_modes=num_modes, num_recurrent_steps=num_recurrent_steps, num_t2m_steps=num_t2m_steps, num_freq_bands=num_freq_bands, num_layers=num_dec_layers, num_heads=num_heads, head_dim=head_dim, dropout=dropout, split_rec_modules=split_rec_modules, deploy=False, ), loss=dict( type="QCNetOELoss", output_dim=output_dim, num_historical_steps=num_historical_steps, num_future_steps=num_future_steps, ), postprocess=dict( type="Argoverse2Postprocess", output_dim=output_dim, num_historical_steps=num_historical_steps, ), )

数据加载

model 的定义一样,训练和验证阶段的 dataloader 是在 config 配置文件中定义 data_loaderval_data_loader 这两个 dict 来实现的,分别对应着训练集和验证集的处理流程。

QCNet并没有添加复杂的数据增强,因此 transforms=None. collate_fn 定义了如何将单个数据整合为batch, 其中包含一些数据对齐操作。

data_loader = dict( type=torch.utils.data.DataLoader, dataset=dict( type="Argoverse2PackedDataset", data_path=os.path.join(data_rootdir, "train"), split="train", pack_type="lmdb", input_dim=2, transforms=None, ), sampler=dict(type=torch.utils.data.DistributedSampler), batch_size=batch_size_per_gpu * 2, shuffle=True, num_workers=dataloader_workers, pin_memory=True, collate_fn=partial( collate_qc_argoverse2, num_historical_steps=num_historical_steps, stage="train", add_noise=False, ), ) val_data_loader = dict( type=torch.utils.data.DataLoader, dataset=dict( type="Argoverse2PackedDataset", data_path=os.path.join(data_rootdir, "val"), split="val", pack_type="lmdb", input_dim=2, transforms=None, ), sampler=dict(type=torch.utils.data.DistributedSampler), batch_size=batch_size_per_gpu * 2, shuffle=False, num_workers=dataloader_workers, pin_memory=True, collate_fn=partial(collate_qc_argoverse2, stage="val", add_noise=False), )

config 里还定义了 batch_processor 对batch进行处理:

batch_processor = dict( type="MultiBatchProcessor", need_grad_update=True, loss_collector=loss_collector, ) val_batch_processor = dict( type="MultiBatchProcessor", need_grad_update=False, loss_collector=None, )

其中 batch_processor 中传入一个 loss_collector 函数,用于获取当前批量数据的 loss ,如下所示:

def loss_collector(outputs: dict): losses = [] for _, loss in outputs.items(): losses.append(loss) return losses

训练策略

首先介绍QCNet 在 Argoverse2 数据集上训练浮点模型的策略。我们使用 AdamW 作为优化器,设置 lr=5e-4, weight_decay=1e-4,但是不对 biasnn.Embedding 以及归一化层施加 weight decay, 因此此处使用 custom_param_optimizer 能够简单地达到这一设置。我们使用 Cosine 的学习率更新策略,并将 warmup 长度设置为1个epoch,模型共训练64个epoch。 下面为config中 float_trainer 的完整配置示例:

float_trainer = dict( type="distributed_data_parallel_trainer", model=model, data_loader=data_loader, optimizer=dict( type="custom_param_optimizer", optim_cls=torch.optim.AdamW, optim_cfgs=dict(lr=5e-4, weight_decay=1e-4), custom_param_mapper={ "bias": dict(weight_decay=0.0), "norm_types": dict(weight_decay=0.0), nn.Embedding: dict(weight_decay=0.0), }, ), batch_processor=batch_processor, num_epochs=64, device=None, callbacks=[ stat_callback, loss_show_update, dict( type="CosLrUpdater", warmup_len=1, warmup_by="epoch", step_log_interval=500, ), val_callback, ckpt_callback, ], train_metrics=dict( type="LossShow", ), sync_bn=True, val_metrics=val_metrics, )

量化训练

关于量化训练中的关键步骤,比如准备浮点模型、算子替换、插入量化和反量化节点、设置量化参数以及算子的融合等,请阅读 量化感知训练 章节的内容。这里主要讲一下中如何定义和使用量化模型。

QCnet 示例模型的量化训练策略可见 configs/traj_pred/qcnet_argoverse2.py 文件,主要分为量化校准 calibration 和量化训练 qat 两个阶段。

量化校准 calibration 的配置为:

sensitive_path1 = os.path.join(ckpt_dir, "output_prob_L1_sensitive_ops.pt") sensitive_path2 = os.path.join(ckpt_dir, "output_pred_L1_sensitive_ops.pt") if os.path.exists(sensitive_path1): sensitive_table1 = torch.load(sensitive_path1) sensitive_table2 = torch.load(sensitive_path2) cali_qconfig_setter = ( sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter( sensitive_table1, ratio=0.15, ), sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter( sensitive_table2, ratio=0.15, ), default_calibration_qconfig_setter, ) float2calibration = dict( type="Float2Calibration", convert_mode="jit-strip", example_data_loader=calibration_data_loader, qconfig_setter=cali_qconfig_setter, ) calibration_trainer = dict( type="Calibrator", model=cali_model, model_convert_pipeline=dict( type="ModelConvertPipeline", qat_mode=qat_mode, qconfig_params=dict( activation_calibration_observer="mse", ), converters=[ dict( type="LoadCheckpoint", checkpoint_path=os.path.join( ckpt_dir, "float-checkpoint-last.pth.tar" ), allow_miss=True, ignore_extra=True, verbose=True, ), float2calibration, ], ), data_loader=calibration_data_loader, batch_processor=calibration_batch_processor, num_steps=calibration_step, device=None, callbacks=[ ckpt_callback, val_callback, ], log_interval=calibration_step / 10, val_metrics=val_metrics, )

量化训练 qat 的配置为:

qat_trainer = dict( type="distributed_data_parallel_trainer", model=cali_model, model_convert_pipeline=dict( type="ModelConvertPipeline", qat_mode=qat_mode, converters=[ float2qat, dict( type="LoadCheckpoint", checkpoint_path=os.path.join( ckpt_dir, "calibration-checkpoint-last.pth.tar" ), verbose=True, ), ], ), data_loader=qat_data_loader, optimizer=dict( type="custom_param_optimizer", optim_cls=torch.optim.AdamW, optim_cfgs=dict(lr=5e-6, weight_decay=1e-4), custom_param_mapper={ "bias": dict(weight_decay=0.0), "norm_types": dict(weight_decay=0.0), nn.Embedding: dict(weight_decay=0.0), }, ), batch_processor=batch_processor, num_epochs=2, device=None, callbacks=[ stat_callback, loss_show_update, dict( type="StepDecayLrUpdater", lr_decay_id=[1], step_log_interval=500, ), val_callback, ckpt_callback, ], train_metrics=dict( type="LossShow", ), val_metrics=val_metrics, )

其中 float2calibrationfloat2qat 分别定义了浮点到校准模型以及浮点到量化训练模型的模型转换过程; cali_qconfig_setterqat_qconfig_setter 分别为calibration模型和qat模型对应的qconfig设置,关于qconfig的设置方法与调试步骤,比如默认qconfig的设置、量化敏感算子设置等,请阅读 量化感知训练-Qconfig详解 章节的内容。

量化敏感度算子排序

量化训练过程中,我们需要将某些量化敏感的算子设置为int16,以满足模型的量化精度需求。量化敏感算子的排序可以通过运行以下命令获得。

python3 tools/quant_analysis.py --config configs/traj_pred/qcnet_argoverse2.py

其中的关键步骤,请阅读 量化感知训练-精度调优工具使用指南 章节的内容。