DenseTNT轨迹预测模型训练

这篇教程主要是告诉大家如何在数据集 Argoverse 1.1 上从头开始训练一个 DenseTNT 模型,包括浮点、量化和定点模型。

训练流程

如果你只是想简单的把 Bev 的模型训练起来,那么可以首先阅读一下这一章的内容。 和其他任务一样,对于所有的训练,评测任务,HAT统一采用 tools + config 的形式来完成。在准备好原始数据集之后,可以通过下面的流程,方便地完成整个训练的流程。

数据集准备

在开始训练模型之前,第一步是需要准备好数据集,可以在 Argoverse 1 数据集 下载。 需要下载:TrainingValidation。 同时需要准备 HD Map数据

下载后,解压并按照如下方式组织文件夹结构:

data |-- argoverse-1 |-- train |-- Argoverse-Terms_of_Use.txt |-- data |-- val |-- Argoverse-Terms_of_Use.txt |-- data |-- map_files

为了提升训练的速度,我们对数据信息文件做了一个打包,将其转换成lmdb格式的数据集。只需要运行下面的脚本,就可以成功实现转换:

python3 tools/datasets/argoverse_packer.py --src-data-dir ${data-dir} --mode train --pack-type lmdb --num-workers 10 --target-data-dir ${target-data-dir} python3 tools/datasets/argoverse_packer.py --src-data-dir ${data-dir} --mode val --pack-type lmdb --num-workers 10 --target-data-dir ${target-data-dir}

上面这两条命令分别对应转换训练数据集和验证数据集,打包完成之后,${target-data-dir} 目录下的文件结构应该如下所示:

${target-data-dir} |-- train_lmdb |-- val_lmdb

train_lmdbval_lmdb 就是打包之后的训练数据集和验证数据集,接下来就可以开始训练模型。

模型训练

下一步就可以开始训练。训练也可以通过下面的脚本来完成,在训练之前需要确认配置中数据集路径是否已经切换到已经打包好的数据集路径。

python3 tools/train.py --stage "float" --config configs/traj_pred/densetnt_vectornet_argoverse1.py python3 tools/train.py --stage "calibration" --config configs/traj_pred/densetnt_vectornet_argoverse1.py python3 tools/train.py --stage "qat" --config configs/traj_pred/densetnt_vectornet_argoverse1.py

以上命令分别完成浮点模型和定点模型的训练,其中定点模型的训练需要以训练好的浮点模型为基础,具体内容请阅读 量化感知训练 章节的内容。

导出定点模型

完成量化训练后,便可以开始导出定点模型。可以通过下面命令来导出:

python3 tools/export_hbir.py --config configs/traj_pred/densetnt_vectornet_argoverse1.py

模型验证

在完成训练之后,可以得到训练完成的浮点、量化或定点模型。和训练方法类似,我们可以用相同方法来对训好的模型做指标验证,得到为 FloatCalibrationQuantized 的指标,分别为浮点、量化和完全定点的指标。

python3 tools/predict.py --stage "float" --config configs/traj_pred/densetnt_vectornet_argoverse1.py python3 tools/predict.py --stage "calibration" --config configs/traj_pred/densetnt_vectornet_argoverse1.py

和训练模型时类似, --stage 后面的参数为 "float""calibration" 时,分别可以完成对训练好的浮点模型、量化模型的验证。

定点模型精度验证可使用下面命令,但需要注意是必须要先导出hbir:

python3 tools/predict.py --stage "int_infer" --config configs/traj_pred/densetnt_vectornet_argoverse1.py

模型推理

HAT 提供了 infer_hbir.py 脚本提供了对定点模型的推理结果进行可视化展示:

python3 tools/infer_hbir.py --config configs/bev/configs/traj_pred/densetnt_vectornet_argoverse1.py --model-inputs img:${img-path} --save-path ${save_path}

仿真上板精度验证

除了上述模型验证之外,我们还提供和上板完全一致的精度验证方法,可以通过下面的方式完成:

python3 tools/validation_hbir.py --stage "align_bpu" --config configs/traj_pred/densetnt_vectornet_argoverse1.py

定点模型检查和编译

在HAT中集成的量化训练工具链主要是为了地平线的计算平台准备的,因此,对于量化模型的检查和编译是必须的。 我们在HAT中提供了模型检查的接口,可以让用户定义好量化模型之后,先检查能否在 BPU 上正常运行:

python3 tools/model_checker.py --config configs/traj_pred/densetnt_vectornet_argoverse1.py

在模型训练完成后,可以通过 compile_perf_hbir 脚本将量化模型编译成可以上板运行的 hbm 文件,同时该工具也能预估在 BPU 上的运行性能:

python3 tools/compile_perf_hbir.py --config configs/traj_pred/densetnt_vectornet_argoverse1.py

以上就是从数据准备到生成量化可部署模型的全过程。

训练细节

在这个说明中,我们对模型训练需要注意的一些事项进行说明,主要为 config 的一些相关设置。

模型构建

DenseTNT 的网络结构可以参考 论文 ,这里不做详细介绍。 我们通过在 config 配置文件中定义 model 这样的一个 dict 型变量,就可以方便的实现对模型的定义和修改。

model = dict( type="MotionForecasting", encoder=dict( type="Vectornet", depth = 3, traj_in_channels = 9, traj_num_vec = 9, lane_in_channels = 11, lane_num_vec = 19, hidden_size = 128, ), decoder=dict( type="Densetnt", in_channels = 128, hidden_size = 128, num_traj = 32, target_graph_depth = 2, pred_steps = 30, top_k = 150, ), target=dict( type="DensetntTarget", ), loss=dict( type="DensetntLoss", ), postprocess=dict( type="DensetntPostprocess", threshold=2.0, pred_steps = 30, mode_num = 6 ), )

数据加载

model 的定义一样,数据增强的流程是通过在 config 配置文件中定义 data_loaderval_data_loader 这两个 dict 来实现的,分别对应着训练集和验证集的处理流程。

data_loader = dict( type=torch.utils.data.DataLoader, dataset=dict( type="Argoverse1Dataset", data_path = os.path.join(data_rootdir, "train_lmdb"), map_path = map_path, pred_step = 20, max_distance = 50.0, max_lane_num = 64, max_traj_num = 32, max_goals_num = 2048, ), sampler=dict(type=torch.utils.data.DistributedSampler), batch_size=batch_size_per_gpu, shuffle=True, num_workers=dataloader_workers, pin_memory=True, collate_fn=hat.data.collates.collate_argoverse, )

batch_processor 中传入一个 loss_collector 函数,用于获取当前批量数据的 loss ,如下所示:

def loss_collector(outputs: dict): losses = [] for _, loss in outputs.items(): losses.append(loss) return losses

验证集的数据转换相对简单很多,如下所示:

val_data_loader = dict( type=torch.utils.data.DataLoader, dataset=dict( type="Argoverse1Dataset", data_path = os.path.join(data_rootdir, "val_lmdb"), map_path = map_path, pred_step = 20, max_distance = 50.0, max_lane_num = 64, max_traj_num = 32, max_goals_num = 2048, ), sampler=dict(type=torch.utils.data.DistributedSampler), batch_size=batch_size_per_gpu, shuffle=True, num_workers=dataloader_workers, pin_memory=True, collate_fn=hat.data.collates.collate_argoverse, )
val_batch_processor = dict( type="MultiBatchProcessor", need_grad_update=False, loss_collector=None, )

训练策略

SceneFlow 数据集上训练浮点模型使用 Cosine 的学习策略配合 Warmup, 以及对 weight 的参数施加 L2 norm。 configs/traj_pred/densetnt_vectornet_argoverse1.py 文件中的 float_trainer, qat_trainer, int_trainer 分别对应浮点、量化、定点模型的训练策略。 下面为 float_trainer 训练策略示例:

float_trainer = dict( type="distributed_data_parallel_trainer", model=model, data_loader=data_loader, optimizer=dict( type=torch.optim.AdamW, eps = 1e-8, betas = (0.9, 0.999), lr = 1e-3, weight_decay = 0.01, ), batch_processor=batch_processor, num_epochs = 30, device = None, callbacks=[ stat_callback, loss_show_update, dict( type="CosLrUpdater", warmup_len = 1, warmup_by = "epoch", step_log_interval = 500, ), val_callback, ckpt_callback, ], train_metrics=dict( type="LossShow", ), sync_bn=True, val_metrics=[val_nuscenes_metric], )

量化训练

关于量化训练中的关键步骤,比如准备浮点模型、算子替换、插入量化和反量化节点、设置量化参数以及算子的融合等,请阅读 量化感知训练 章节的内容。这里主要讲一下中如何定义和使用量化模型。

在模型准备的好情况下,包括量化已有的一些模块完成之后,在训练脚本中统一使用下面的脚本将浮点模型映射到定点模型上来。

model.set_qconfig() horizon.quantization.prepare_qat_fx(model)

量化训练的整体策略可以直接沿用浮点训练的策略,但学习率和训练长度需要适当调整。 因为有浮点预训练模型,所以量化训练的学习率 Lr 可以很小, 一般可以从 0.001 或 0.0001 开始,并可以搭配 StepLrUpdater 做 1-2 次 scale=0.1Lr 调整; 同时训练的长度不用很长。此外 weight decay 也会对训练结果有一定影响。

DenseTNT 示例模型的量化训练策略可见 configs/traj_pred/densetnt_vectornet_argoverse1.py 文件。