校准数据集准备

注解

如果本过程,您需在示例文件夹内进行,那么您需要先执行文件夹中的 00_init.sh 脚本以获取对应的原始模型和数据集。

在进行模型校准时,需要20~100份的标定样本输入,每一份样本都是一个独立的数据文件。为了确保校准后模型的精度效果,我们希望这些校准样本来自于您训练模型使用的训练集或验证集,不要使用非常少见的异常样本,例如纯色图片、不含任何检测或分类目标的图片等。

您需要把取自训练集/验证集的样本进行前处理(前处理过程与原始浮点模型数据前处理过程一致),处理完后的校准样本会与原始浮点模型具备一样的数据类型(input_type_train)、尺寸(input_shape) 和 layout(input_layout_train), 您可以通过 numpy.save 命令将数据保存为npy文件,工具链校准时会基于 numpy.load 命令进行读取。例如,有一个使用ImageNet训练的用于分类的原始浮点模型,它只有一个输入节点,输入信息描述如下:

  • 输入类型:BGR

  • 输入layout:NCHW

  • 输入尺寸:1x3x224x224

原始浮点模型进行数据前处理时的步骤如下:

  1. 图像长宽等比scale,短边缩放到256。

  2. center_crop 方法获取224x224大小图像。

  3. 对齐输入layout为模型所需的 NCHW

  4. 转换色彩空间为模型所需的 BGR

  5. 图像数值范围调整为模型所需的[0, 255]。

  6. 按通道减mean。

  7. 数据乘以scale系数。

针对上述举例模型的样本处理代码如下(为避免过长代码篇幅,各种简单transformer实现代码未贴出,transformer使用方法可参考 图片处理transformer说明 ):

# 本示例使用skimage,如果是opencv/PIL会有所区别 import skimage import skimage.io import numpy as np from horizon_tc_ui.data.transformer import (CenterCropTransformer, HWC2CHWTransformer, MeanTransformer, RGB2BGRTransformer, ScaleTransformer, ShortSideResizeTransformer) def data_transformer(): transformers = [ # 长宽等比scale,短边缩放至256 ShortSideResizeTransformer(short_size=256), # CenterCrop获取224x224图像 CenterCropTransformer(crop_size=224), # skimage读取结果为NHWC排布,转换为模型需要的NCHW HWC2CHWTransformer(), # skimage读取结果通道顺序为RGB,转换为模型需要的BGR RGB2BGRTransformer(), # skimage读取数值范围为[0.0,1.0],调整为模型需要的数值范围 ScaleTransformer(scale_value=255), # 对输入图片中的所有像素值做减去 mean_value MeanTransformer(means=np.array([103.94, 116.78, 123.68])), # 对输入图片中的所有像素值做乘以data_scale系数 ScaleTransformer(scale_value=0.017) ] return transformers # src_image 标定集中的原图片 # dst_file 存放最终标定样本数据的文件名称 def convert_image(src_image, dst_file, transformers): image = [skimage.img_as_float( skimage.io.imread(src_image)).astype(np.float32)] for trans in transformers: image = trans(image) # 模型指定的input_type_train BGR数值类型是UINT8 image = image[0].astype(np.uint8) # 以二进制形式存储标定样本到数据文件 np.save(dst_file, image) if __name__ == '__main__': # 此处表示原始标定图片集合,伪代码 src_images = ['ILSVRC2012_val_00000001.JPEG', ...] # 此处表示最终标定文件名称(后缀名不限制),伪代码 # calibration_data_bgr是您在配置文件中指定的cal_data_dir dst_files = ['./calibration_data_bgr/ILSVRC2012_val_00000001.npy', ...] transformers = data_transformer() for src_image, dst_file in zip(src_images, dst_files): convert_image(src_image, dst_file, transformers)
注意

请注意,yaml文件中input_shape参数作用为指定原始浮点模型的输入数据尺寸。若为动态输入模型则可通过这个参数设置转换后的输入大小,而校准数据的shape大小应与input_shape保持一致。

例如:若原始浮点模型输入节点shape为?x3x224x224("?"号代表占位符,即该模型第一维为动态输入),转换配置文件中设置input_shape: 8x3x224x224,则需要准备的每份校准数据大小为8x3x224x224。(请知悉,此类输入shape第一维不等于1的模型,不支持通过input_batch参数修改模型batch信息。)