这篇教程主要是告诉大家如何利用HAT在雷达点云数据集 KITTI-3DObject 上从头开始训练一个 PointPillars 模型,包括浮点、量化和定点模型。
在开始训练模型之前,第一步是需要准备好数据集,我们在KITTI官网下载 3DObject据集 , 包括4个文件:
left color images of object data setvelodyne point cloudscamera calibration matrices of object data settaining labels of object data set下载上述4个文件后,解压并按照如下方式组织文件夹结构:
为了创建KITTI点云数据,首先需要加载原始的点云数据并生成相关的包含目标标签和标注框的数据标注文件,同时还需要为KITTI数据集生成每个单独的训练目标的点云数据,并将其存储在 data/kitti/gt_database 的 .bin 格式的文件中,此外,需要为训练数据或者验证数据生成 .pkl 格式的包含数据信息的文件。随后,通过运行下面的命令来创建KITTI数据:
执行上述命令后,生成的文件目录如下:
同时,为了提升训练的速度,我们对数据信息文件做了一个打包,将其转换成lmdb格式的数据集。只需要运行下面的脚本,就可以成功实现转换:
上面这两条命令分别对应着转换训练数据集和验证数据集,打包完成之后,data目录下的文件结构应该如下所示:
train_lmdb 和 val_lmdb 就是打包之后的训练数据集和验证数据集,也是网络最终读取的数据集,kitti3d_gt_database 和 kitti3d_dbinfos_train.pkl 是训练是用于采样的样本。
数据集准备好之后,就可以开始训练浮点型的PointPillars检测网络了。在网络训练开始之前,你可以使用以下命令先测试一下网络的计算量和参数数量:
如果你只是单纯的想启动这样的训练任务,只需要运行下面的命令就可以:
由于HAT算法包使用了一种巧妙的注册机制,使得每一个训练任务都可以按照这种train.py加上config配置文件的形式启动。
train.py 是统一的训练脚本,与任务无关,我们需要训练什么样的任务、使用什么样的数据集以及训练相关的超参数设置都在指定的config配置文件里面。
config文件里面提供了模型构建、数据读取等关键的dict。
PointPillars 的网络结构可以参考 论文, 这里不做详细介绍。我们通过在config配置文件中定义 model 这样的一个dict型变量,就可以方便的实现对模型的定义和修改。
其中,model 下面的 type 表示定义的模型名称,剩余的变量表示模型的其他组成部分。这样定义模型的好处在于我们可以很方便的替换我们想要的结构。
训练脚本在启动之后,会调用 build_model 接口,将这样一个dict类型的model变成类型为 torch.nn.Module 类型的model。
跟 model 的定义一样,数据增强的流程是通过在config配置文件中定义 data_loader 和 val_data_loader 这两个dict来实现的,分别对应着训练集和验证集的处理流程。以 data_loader 为例:
其中type直接用的pytorch自带的接口 torch.utils.data.DataLoader,表示的是将 batch_size 大小的图片组合到一起。
这里面唯一需要关注的可能是 dataset 这个变量,data_path 路径也就是我们在第一部分数据集准备中提到的路径。transforms 下面包含着一系列的数据增强。val_data_loader 中只有除了点云Pillar化(Voxelization)和Reformat。
你也可以通过在 transforms 中插入新的dict实现自己希望的数据增强操作。
为了训练一个精度高的模型,好的训练策略是必不可少的。对于每一个训练任务而言,相应的训练策略同样都定义在其中的config文件中,从 float_trainer 这个变量就可以看出来。
float_trainer 从大局上定义了我们的训练方式,包括使用多卡分布式训练(distributed_data_parallel_trainer),模型训练的epoch次数,以及优化器的选择。
同时 callbacks 中体现了模型在训练过程中使用到的小策略以及用户想实现的操作,包括学习率的变换方式(CyclicLrUpdater),在训练过程中验证模型的指标(Validation),以及保存(Checkpoint)模型的操作。当然,如果你有自己希望模型在训练过程中实现的操作,也可以按照这种dict的方式添加。
float_trainer 负责将整个训练的逻辑给串联起来,其中也会负责模型的pretrain。
如果需要复现精度,config中的训练策略最好不要修改。否则可能会有意外的训练情况出现。
通过上面的介绍,你应该对config文件的功能有了一个比较清楚的认识。然后通过前面提到的训练脚本,就可以训练一个高精度的纯浮点的检测模型。 当然训练一个好的检测模型不是我们最终的目的,它只是做为一个pretrain为我们后面训练定点模型服务的。
当我们有了纯浮点模型之后,就可以开始训练相应的定点模型了。和浮点训练的方式一样,我们只需要通过运行下面的脚本就可以训练定点模型了:
可以看到,我们的配置文件没有改变,只改变了 stage 的类型。此时我们使用的训练策略来自于config文件中的 qat_trainer 和 calibration_trainer 。
当我们训练量化模型的时候,需要设置 quantize=True ,此时相应的浮点模型会被转换成量化模型,相关代码如下:
关于量化训练中的关键步骤,比如准备浮点模型、算子替换、插入量化和反量化节点、设置量化参数以及算子的融合等,请阅读 量化感知训练 章节的内容。
正如我们之前所说,量化训练其实是在纯浮点训练基础上的finetue。因此量化训练的时候,我们的初始学习率设置为浮点训练的十分之一,训练的epoch次数也大大减少,最重要的是 model 定义的时候,我们的 retrained 需要设置成已经训练出来的纯浮点模型的地址。
做完这些简单的调整之后,就可以开始训练我们的量化模型了。
完成量化训练后,便可以开始导出定点模型。可以通过下面命令来导出:
模型训练完成之后,我们还可以验证训练出来的模型性能。由于我们提供了float、calibration和qat三阶段的训练过程,相应的我们可以验证这三个阶段训练出来的模型性能,只需要相应的运行以下两条命令即可:
定点模型精度验证可使用下面命令,但需要注意是必须要先导出hbir:
这个显示出来的精度才是最终的int8模型的真正精度,当然这个精度和qat验证阶段的精度应该是保持十分接近的。
除了上述模型验证之外,我们还提供和上板完全一致的精度验证方法,可以通过下面的方式完成:
HAT提供了 infer_hbir.py 脚本对各阶段训练好的模型的推理结果进行可视化展示:
在训练完成之后,可以使用 compile_perf_hbir 的工具用来将量化模型编译成可以上板运行的 hbm 文件,同时该工具也能预估在 BPU 上的运行性能,可以采用以下脚本: