如何开启 AMP

AMP全称为Automatic Mixed Precision,即自动混合精度。AMP开启后,pytorch可以自动地在模型执行时将一些算子(如卷积和全连接)使用 float16 进行计算,以达到提升计算速度、减少显存占用的效果。详见 pytorch官方文档

HAT中已经为AMP做好相关的工作,用户只需要在定义config文件中的 batch_processor 字段时将 enable_amp 参数设置为 True 即可。

注解

在模型验证时为得到准确的指标,一般是不需开启AMP的,在定义 val_batch_processor 字段时请将 enable_amp 参数设置为 False

# configs/example.py # 使用 BasicBatchProcessor batch_processor = dict( type='BasicBatchProcessor', need_grad_update=..., batch_transforms=..., enable_amp=True, ) # 使用 MultiBatchProcessor batch_processor = dict( type="MultiBatchProcessor", need_grad_update=..., batch_transforms=..., loss_collector=..., enable_amp=True, )