数据校准
在量化训练(QAT)中,一个重要的步骤是确定量化参数 scale ,一个合理的 scale 能够显著提升模型训练结果和加快模型的收敛速度。Calibration是通过用浮点模型在训练集上跑少数batch的数据(只跑forward过程,没有backward),统计这些数据的分布直方图,通过一定方法去计算出 min_value 和 max_value ,然后可以用这些 min_value 和 max_value 去获取scale。当QAT的训练精度上不去的时候,在QAT的开始之前使用calibration做量化参数的微调,获取scale,可以为QAT提供更好的量化初始化参数,提升收敛速度和精度。
如何定义 Calibration 模型
-
默认不需要对现有模型做任何修改
类似于定义量化模型时需要设置 QAT QConfig ,Calibration时也需要对模型设置 Calibration QConfig 。不过, Calibration QConfig 的设置相对来说比较简单,HAT框架已经实现对模型 Calibration QConfig 的默认设置,用户无需对模型做任何修改,即可使用Calibration。
-
自定义模型子模块 Calibration QConfig
在上文的默认情况下,会为模型的所有Module(继承自nn.Module)设置 Calibration QConfig 。因此,Calibration时也就会对所有 Module 的特征分布进行统计。如果有特殊需求,可以在模型内自定义实现 set_calibration_qconfig 方法:
class Classifier(nn.Module):
def __init__(self,):
...
def forward(self, x):
...
# 自定义要做 Calibration 的模块
def set_calibration_qconfig(self, ):
# 比如可以设置 Loss 的 qconfig 为 None,就会不再对 Loss 做 Calibration,
# 可以一定程度减少统计量,提升 Calibration 速度,降低显存占用
if self.loss is not None:
self.loss.qconfig = None
浮点模型做 Calibration
HAT中集成了Calibration功能,浮点模型做Calibration命令和正常训练相似,只需执行以下命令即可:
python3 tools/train.py --stage calibration ...
需要注意的是 config 文件中 calibration_trainer 中的一些配置:
# Note: The transforms of the dataset during calibration can be
# consistent with that during training or validation, or customized.
# Default used `val_batch_processor`.
calibration_data_loader = copy.deepcopy(data_loader)
calibration_data_loader.pop('sampler') # Calibration do not support DDP or DP
calibration_batch_processor = copy.deepcopy(val_batch_processor)
calibration_trainer = dict(
type="Calibrator",
model=model,
# 1. 设置 data_loader 和 batch_processor
data_loader=calibration_data_loader,
batch_processor=calibration_batch_processor,
# 2. 设置 calibration 迭代的 batch 数目
num_stages=30,
...
)
1. 数据集的设置:
做Calibration的数据集(dataset)不能是测试集(可以是训练集或其他数据),但是做Calibration时用于数据增强的transforms 可以和正常训练时的transforms保持一致,但是也可以设置成和validation的transforms一致,也可以自定义transforms。(哪种实验效果最好,暂时没有定论,都可以尝试。)
2. Calibration 迭代的图片数目(可供参考):
- classification:图片张数一般可以500~1500张就可以取得不错的效果。
- segmentation&&detection:图片张数可以100~300张左右。
注解
这些图片张数具体数目也不是固定的,上方的建议只是从已有的实验中总结的经验,可根据实际情况调整。
使用Calibration模型做QAT训练
qat_trainer = dict(
type="distributed_data_parallel_trainer",
model=model,
model_convert_pipeline=dict(
type="ModelConvertPipeline",
qat_mode="fuse_bn",
# (可选) 设置 QAT 训练时 scale 更新系数
qconfig_params=dict(
activation_qkwargs=dict(
averaging_constant=0,
),
weight_qkwargs=dict(
averaging_constant=1,
),
),
converters=[
dict(type="Float2QAT"),
dict(
type="LoadCheckpoint",
checkpoint_path=os.path.join(
ckpt_dir, "calibration-checkpoint-best.pth.tar"
),
),
],
),
)
QAT时averaging_constant参数设置:
量化时scale参数的更新规则是 scale = (1 - averaging_constant) * scale + averaging_constant * current_scale 。
在已有的一些实验中(主要是图像分类任务实验)发现,做完calibration后,把activation的scale固定住,不进行更新,即设置activation的 averaging_constant=0 , 并设置weight的 averaging_constant=1 ,效果可能会相对略好一些。
注解
这种设置并不适用于所有任务,在lidar任务中,固定scale,精度也可能会变差。可根据实际情况调整。
接下来只需要执行正常的QAT训练命令,即可启动QAT训练:
python3 tools/train.py --stage qat ...