To help you better understand the config file, this section takes the FCOS-EfficientNetB0 model as an example and adds a brief comment to each of its modules for your reference, as follows:
VERSION = ConfigVersion.v2  # Version number of the config file. Default: v2
training_step = os.environ.get("HAT_TRAINING_STEP", "float")  # Stage the user wants to train, usually set by --stage on the command line
task_name = "fcos_efficientnetb0_mscoco"  # Name of the current training task
num_classes = 80  # Number of classes in the dataset used for training
batch_size_per_gpu = 24  # Batch size on each device
device_ids = [0, 1, 2, 3]  # IDs of the GPUs participating in the training
ckpt_dir = "./tmp_models/%s" % task_name  # Storage path for model checkpoints
cudnn_benchmark = True  # Whether to set torch.backends.cudnn.benchmark=True
seed = None  # Random seed; None means no fixed seed is set
log_rank_zero_only = True  # Whether to print log information only on device 0
bn_kwargs = {}  # Extra BN parameters; {} means the Torch defaults are used
march = March.NASH_E  # Architecture of the computing platform where the model is finally deployed. Default: March.NASH_E
# Configuration of the model used in training
model = dict(
    type="FCOS",  # Type of detection model; the FCOS detector is used here
    backbone=dict(  # Backbone configuration of the detection model
        type="efficientnet",  # Model used by the backbone; EfficientNet is used here
        bn_kwargs=bn_kwargs,  # BN configuration of the backbone
        model_type="b0",  # B0 structure from the EfficientNet model family
        num_classes=1000,  # Number of classes when EfficientNet is a classifier; unused when it serves as a detection backbone
        include_top=False,  # Whether to include EfficientNet's classification layer; not needed since it only extracts features here
        activation="relu",  # Activation layer of the backbone; relu is used here
        use_se_block=False,  # Whether the backbone uses the se_block module. Default: False
    ),  # PS: for a detailed explanation of each backbone parameter, refer to the EfficientNet API documentation
    neck=dict(  # Neck configuration of the detection model
        type="BiFPN",  # BiFPN is used as the neck
        in_strides=[2, 4, 8, 16, 32],  # Strides of the input features
        out_strides=[8, 16, 32, 64, 128],  # Strides of the output features
        stride2channels=dict({2: 16, 4: 24, 8: 40, 16: 112, 32: 320}),  # Mapping from input-feature stride to channel count
        out_channels=64,  # Channels of the output features
        num_outs=5,  # Number of output features
        stack=3,  # Number of BifpnLayer layers
        start_level=2,  # Index of the first output feature of the backbone
        end_level=-1,  # Index of the last output feature of the backbone
        fpn_name="bifpn_sum",  # FPN name, related to how the weights are initialized
    ),  # PS: for a detailed explanation of each neck parameter, refer to the BiFPN API documentation
    head=dict(  # Head configuration of the detection model
        type="FCOSHead",  # FCOSHead is used as the head
        num_classes=num_classes,  # Number of classes of the detection dataset
        in_strides=[8, 16, 32, 64, 128],  # Strides of the input features
        out_strides=[8, 16, 32, 64, 128],  # Strides of the output features
        stride2channels=dict({8: 64, 16: 64, 32: 64, 64: 64, 128: 64}),  # Mapping from input-feature stride to channel count
        upscale_bbox_pred=False,  # Whether bbox_pred needs upscaling
        feat_channels=64,  # Channels of the input features
        stacked_convs=4,  # Number of consecutive convs
        int8_output=False,  # Whether the output is set to int8
        int16_output=True,  # Whether the output is set to int16
        dequant_output=True,  # Whether the output needs to be dequantized
    ),  # PS: for a detailed explanation of each head parameter, refer to the FCOSHead API documentation
    targets=dict(  # Target configuration of the detection model
        type="DynamicFcosTarget",  # DynamicFcosTarget is used to build targets
        strides=[8, 16, 32, 64, 128],  # Strides of the input features
        cls_out_channels=80,  # Number of classes
        background_label=80,  # Class label for the background
        topK=10,  # Max number of positive samples retained per ground truth
        loss_cls=dict(  # Classification loss used when dynamically generating targets
            type="FocalLoss",  # The FocalLoss function is used
            loss_name="cls",  # Name under which this loss term is reported
            num_classes=80 + 1,  # 80 foreground classes plus one background class
            alpha=0.25,  # FocalLoss alpha balancing factor
            gamma=2.0,  # FocalLoss gamma focusing factor
            loss_weight=1.0,  # Weight of this loss term
            reduction="none",  # Keep per-element losses (no reduction)
        ),  # PS: for parameter details, refer to the FocalLoss API documentation
        loss_reg=dict(  # Regression loss; the GIoULoss function is used
            type="GIoULoss", loss_name="reg", loss_weight=2.0, reduction="none"
        ),  # PS: for parameter details, refer to the GIoULoss API documentation
    ),  # PS: for a detailed explanation of each target parameter, refer to the DynamicFcosTarget API documentation
    post_process=dict(  # Post-processing configuration of the detection model
        type="FCOSMultiStrideFilter",  # FCOSMultiStrideFilter is used for post-processing
        strides=[8, 16, 32, 64, 128],  # Strides of the output features
        threshold=-2.944,  # Threshold used in the FilterModule OP
        for_compile=False,  # Whether the model needs to support compilation
        score_threshold=0.05,  # Score threshold used for filtering boxes
        iou_threshold=0.6,  # IOU threshold for NMS
        max_shape=(512, 512),  # Clamp the detection bbox according to max_shape
    ),  # PS: for a detailed explanation of each post-processing parameter, refer to the FCOSMultiStrideFilter API documentation
    loss_cls=dict(  # Loss function used by the cls branch
        type="FocalLoss",
        loss_name="cls",
        num_classes=80 + 1,  # 80 foreground classes plus one background class
        alpha=0.25,
        gamma=2.0,
        loss_weight=1.0,
    ),
    loss_centerness=dict(  # Loss function used by the centerness branch
        type="CrossEntropyLoss", loss_name="centerness", use_sigmoid=True
    ),
    loss_reg=dict(  # Loss function used by the reg branch
        type="GIoULoss",
        loss_name="reg",
        loss_weight=1.0,
    ),
)
# Similar to the model definition, but deploy_model is used for model compilation, so there is no loss.
# deploy_model is usually used in the int_infer stage.
deploy_model = dict(
    type="FCOS",  # Same FCOS detector as the training model
    backbone=dict(  # Backbone identical to the training model's backbone
        type="efficientnet",
        bn_kwargs=bn_kwargs,
        model_type="b0",
        num_classes=1000,
        include_top=False,
        activation="relu",
        use_se_block=False,
    ),
    neck=dict(  # Neck identical to the training model's neck
        type="BiFPN",
        in_strides=[2, 4, 8, 16, 32],
        out_strides=[8, 16, 32, 64, 128],
        stride2channels=dict({2: 16, 4: 24, 8: 40, 16: 112, 32: 320}),
        out_channels=64,
        num_outs=5,
        stack=3,
        start_level=2,
        end_level=-1,
        fpn_name="bifpn_sum",
    ),
    head=dict(  # Head identical to the training model's head except for dequant_output
        type="FCOSHead",
        num_classes=num_classes,
        in_strides=[8, 16, 32, 64, 128],
        out_strides=[8, 16, 32, 64, 128],
        stride2channels=dict({8: 64, 16: 64, 32: 64, 64: 64, 128: 64}),
        upscale_bbox_pred=False,
        feat_channels=64,
        stacked_convs=4,
        int8_output=False,
        int16_output=True,
        dequant_output=False,  # Unlike the training model, outputs are not dequantized for compilation
    ),
    post_process=dict(  # Post-processing configured for compilation; no score/IOU filtering here
        type="FCOSMultiStrideFilter",
        strides=[8, 16, 32, 64, 128],
        threshold=-2.944,
        for_compile=True,  # The deploy model must support compilation
        max_shape=(512, 512),
    ),
)
# Input used when compiling deploy_model: a single 3-channel 512x512 image
deploy_inputs = dict(img=torch.randn((1, 3, 512, 512)))
# Process of converting deploy_model from floating-point to quantized, used to verify whether the model can be compiled.
deploy_model_convert_pipeline = dict(
    type="ModelConvertPipeline",
    qat_mode="fuse_bn",  # Fuse BN layers during the conversion
    converters=[
        dict(type="Float2QAT"),  # Convert the model from float to qat
        dict(type="QAT2Quantize"),  # Convert the model from qat to quantized
    ],
)
# Loading process of the training dataset
data_loader = dict(
    type=torch.utils.data.DataLoader,  # Use torch's native DataLoader
    dataset=dict(  # How the dataset is obtained
        type="Coco",  # Corresponds to COCO's dataset interface
        data_path="./tmp_data/mscoco/train_lmdb/",  # Path to the dataset
        transforms=[  # Data transformation pipeline
            dict(
                type="Resize",  # Resizing operation
                img_scale=(512, 512),  # Image size after resizing
                ratio_range=(0.5, 2.0),  # Image scaling range
                keep_ratio=True,  # Whether to maintain the aspect ratio while scaling
            ),
            dict(type="RandomCrop", size=(512, 512)),  # Random cropping operation
            dict(  # Padding operation
                type="Pad",
                divisor=512,  # Height and width of the padded image are multiples of 512
            ),
            dict(  # Random flipping operation
                type="RandomFlip",
                px=0.5,  # Probability of flipping in the x direction
                py=0,  # Probability of flipping in the y direction
            ),
            dict(type="AugmentHSV", hgain=0.015, sgain=0.7, vgain=0.4),  # HSV color-jitter augmentation
            dict(
                type="ToTensor",  # Convert numpy to tensor
                to_yuv=True,  # Whether the image is converted to YUV format
            ),
            dict(  # Normalization operation, from [0, 255] to [-1, 1]
                type="Normalize",
                mean=128.0,
                std=128.0,
            ),
        ],
    ),
    sampler=dict(type=torch.utils.data.DistributedSampler),  # Sampling method of the dataset in DDP training mode
    batch_size=batch_size_per_gpu,  # Batch size of a single device
    shuffle=True,  # Whether to shuffle the data
    num_workers=8,  # Number of worker processes for data reading
    pin_memory=True,  # Whether to use pin_memory
    collate_fn=hat.data.collates.collate_2d,  # Method to collate and batch multiple images
)  # PS: for a detailed explanation of each DataLoader parameter, refer to torch's official documentation
# Loading process of the validation dataset, similar to the training dataset
val_data_loader = dict(
    type=torch.utils.data.DataLoader,
    dataset=dict(
        type="Coco",
        data_path="./tmp_data/mscoco/val_lmdb/",  # Validation split of the dataset
        transforms=[  # No random augmentation for validation: only resize, pad, tensor conversion and normalization
            dict(
                type="Resize",
                img_scale=(512, 512),
                keep_ratio=True,
            ),
            dict(
                type="Pad",
                size=(512, 512),  # Pad to a fixed 512x512 size
            ),
            dict(
                type="ToTensor",
                to_yuv=True,
            ),
            dict(
                type="Normalize",
                mean=128.0,
                std=128.0,
            ),
        ],
    ),
    batch_size=batch_size_per_gpu,
    shuffle=False,  # Keep the validation order deterministic
    num_workers=8,
    pin_memory=True,
    collate_fn=hat.data.collates.collate_2d,
)
# This function extracts the loss part of the model outputs for the later gradient update
def loss_collector(outputs: dict) -> list:
    """Return the loss values from a model-output dict.

    Only the values are needed for the backward pass, so the keys
    (the loss names) are dropped. Insertion order is preserved.
    """
    # The original looped over .items() and discarded the keys;
    # dict.values() expresses the same thing directly.
    return list(outputs.values())
# This function updates the loss metrics. Usually used to print the loss during
# model training; it is wired in through the "loss_show_update" callback.
def update_loss(metrics, batch, model_outs):
    """Feed *model_outs* into every metric; *batch* is accepted but unused."""
    for m in metrics:
        m.update(model_outs)
# Definition of the callback that prints the loss during the training
loss_show_update = dict(
    type="MetricUpdater",
    metric_update_func=update_loss,  # Function defined above that feeds model outputs to the loss metrics
    step_log_freq=1,  # Log every step
    epoch_log_freq=1,  # Log every epoch
    log_prefix="loss_ " + task_name,  # Prefix shown in front of each log line
)
# Processing method of the training dataset for each iteration
batch_processor = dict(
    type="MultiBatchProcessor",
    need_grad_update=True,  # Whether to perform the gradient update
    loss_collector=loss_collector,  # Method to get the losses from the model outputs
)
# Processing method of the validation dataset for each iteration
val_batch_processor = dict(
    type="MultiBatchProcessor",
    need_grad_update=False,  # No gradient update during validation
)
# Update method of the model metrics (here the metric is mAP); wired in
# through the "val_metric_updater" callback.
def update_metric(metrics, batch, model_outs):
    """Push *model_outs* into each metric; *batch* is accepted but unused."""
    for m in metrics:
        m.update(model_outs)
# Update method of the validation metrics during the model validation process
val_metric_updater = dict(
    type="MetricUpdater",
    metric_update_func=update_metric,  # Function defined above that feeds model outputs to the metrics
    step_log_freq=500,  # Log every 500 steps
    epoch_log_freq=1,  # Log every epoch
    log_prefix="Validation " + task_name,
)
# Set the frequency at which the training logs are printed
stat_callback = dict(
    type="StatsMonitor",
    log_freq=1,
)
# Trace the model and save the corresponding pt file
trace_callback = dict(
    type="SaveTraced",
    save_dir=ckpt_dir,  # Directory where the traced model is saved
    trace_inputs=deploy_inputs,  # Example inputs used for tracing
)
# Save the weights of the model
ckpt_callback = dict(
    type="Checkpoint",
    save_dir=ckpt_dir,  # Directory where checkpoints are stored
    name_prefix=training_step + "-",  # Checkpoint files are prefixed with the current training stage
    save_interval=1,  # Save every epoch
    strict_match=True,  # Require checkpoint keys to match the model exactly when loading
    mode="max",  # Higher monitored metric is better
    monitor_metric_key="mAP",  # Metric used to decide which checkpoint is best
)
# Validate the model after training
val_callback = dict(
    type="Validation",
    data_loader=val_data_loader,  # Validation dataset defined above
    batch_processor=val_batch_processor,  # How each validation iteration is processed
    callbacks=[val_metric_updater],  # Callbacks invoked during validation
    val_model=None,  # Separate model for validation; None presumably validates the training model itself — verify against the Validation callback docs
    init_with_train_model=False,  # Whether to initialize the validation model with the training weights
    val_interval=1,  # Validate every epoch
    val_on_train_end=True,  # Also run validation when training finishes
)
# Settings for floating-point model training
float_trainer = dict(
    type="distributed_data_parallel_trainer",  # DDP training
    model=model,  # Model involved in the training
    data_loader=data_loader,  # Dataset involved in the training
    optimizer=dict(  # Optimizer settings
        type=torch.optim.SGD,
        params={"weight": dict(weight_decay=4e-5)},  # Per-group settings: weight decay for the "weight" parameter group
        lr=0.14,  # Initial learning rate
        momentum=0.937,  # SGD momentum
        nesterov=True,  # Use Nesterov momentum
    ),
    batch_processor=batch_processor,  # How each iteration of the training dataset is processed
    num_epochs=300,  # Number of epochs for model training
    device=None,  # Device for model training
    callbacks=[  # Callbacks called during model training
        stat_callback,
        loss_show_update,
        dict(type="ExponentialMovingAverage"),  # Keep an EMA of the model weights
        dict(
            type="CosLrUpdater",  # Cosine learning-rate schedule
            warmup_len=2,  # Warm up for 2 epochs
            warmup_by="epoch",  # Warmup length is measured in epochs
            step_log_interval=1,
        ),
        val_callback,
        ckpt_callback,
    ],
    train_metrics=dict(  # Metric used in the training process to print the loss
        type="LossShow",
    ),
    sync_bn=True,  # Whether to synchronize BN across devices
    val_metrics=dict(  # Metric used in the validation process
        type="COCODetectionMetric",
        ann_file="./tmp_data/mscoco/instances_val2017.json",  # COCO annotation file used for evaluation
    ),
)
calibration_data_loader = copy.deepcopy(data_loader)  # Dataset used for calibration (a copy of the training loader)
calibration_data_loader.pop("sampler")  # Calibration can only run on a single device, so no sampler is required
calibration_batch_processor = copy.deepcopy(val_batch_processor)  # How each calibration iteration is processed (no gradient update)
# Settings for the calibration step
calibration_trainer = dict(
    type="Calibrator",
    model=model,
    model_convert_pipeline=dict(  # Convert the model from float to the calibration form
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",  # Fuse BN layers during the conversion
        converters=[
            dict(
                type="LoadCheckpoint",  # Load the float checkpoint before calibration
                checkpoint_path=os.path.join(
                    ckpt_dir, "float-checkpoint-best.pth.tar"
                ),
            ),
            dict(type="Float2Calibration"),  # Convert the model from float to calibration
        ],
    ),
    data_loader=calibration_data_loader,
    batch_processor=calibration_batch_processor,
    num_steps=10,  # Number of calibration steps
    device=None,
    callbacks=[
        stat_callback,
        val_callback,
        ckpt_callback,
    ],
    val_metrics=dict(
        type="COCODetectionMetric",
        ann_file="./tmp_data/mscoco/instances_val2017.json",
    ),
    log_interval=1,
)
# Settings for the QAT model training. For parameter meanings, refer to float_trainer
qat_trainer = dict(
    type="distributed_data_parallel_trainer",
    model=model,
    model_convert_pipeline=dict(  # Converts the model from float to qat
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(type="Float2QAT"),  # Converts the model from a float model to qat
            dict(  # Loads the calibration checkpoint after the conversion
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "calibration-checkpoint-best.pth.tar"
                ),
            ),
        ],
    ),
    data_loader=data_loader,
    optimizer=dict(
        type=torch.optim.SGD,
        params={"weight": dict(weight_decay=4e-5)},  # Same per-group weight-decay setting as the float trainer
        lr=0.001,  # Learning rate is usually set to one-tenth of the float training
        momentum=0.9,
    ),
    batch_processor=batch_processor,
    num_epochs=10,  # Number of qat training epochs is usually much smaller than in float training
    device=None,
    callbacks=[
        stat_callback,
        loss_show_update,
        dict(
            type="StepDecayLrUpdater",  # Step-decay learning-rate schedule
            lr_decay_id=[4],  # Decay point(s) for the LR — see the StepDecayLrUpdater docs for the exact unit
            step_log_interval=500,
        ),
        val_callback,
        ckpt_callback,
    ],
    train_metrics=dict(
        type="LossShow",
    ),
    val_metrics=dict(
        type="COCODetectionMetric",
        ann_file="./tmp_data/mscoco/instances_val2017.json",
    ),
)
# Settings for quantized model training. Usually, no training is performed at this stage;
# the callbacks only save the quantized model parameters and the pt file of the model
int_infer_trainer = dict(
    type="Trainer",
    model=deploy_model,  # The deploy model defined above (no losses)
    model_convert_pipeline=dict(  # Converts the model from float to quantized
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(type="Float2QAT"),  # Converts the model from float to qat
            dict(  # Loads the qat checkpoint
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "qat-checkpoint-best.pth.tar"
                ),
                ignore_extra=True,  # Tolerate extra checkpoint keys (deploy_model omits the loss modules)
            ),
            dict(type="QAT2Quantize"),  # Converts the model from qat to quantized
        ],
    ),
    data_loader=None,  # No data needed since no training happens
    optimizer=None,
    batch_processor=None,
    num_epochs=0,  # Epoch=0 to skip the training
    device=None,
    callbacks=[
        ckpt_callback,  # Saves the quantized model parameters
        trace_callback,  # Saves the quantized model pt file
    ],
)
# Model compilation settings
compile_dir = os.path.join(ckpt_dir, "compile")  # Output directory for compilation artifacts
compile_cfg = dict(
    march=march,  # Target platform architecture defined at the top of the file
    name="fcos_effb0_test_model",  # Name of the compiled model
    out_dir=compile_dir,
    hbm=os.path.join(compile_dir, "model.hbm"),  # Path of the compiled hbm file
    layer_details=True,  # Whether to output per-layer details
    input_source=["pyramid"],  # Input source of the compiled model
    opt="O3",  # Compiler optimization level
)
# Settings for the float model predictor
float_predictor = dict(
    type="Predictor",  # Predictor
    model=model,  # Model used for prediction
    model_convert_pipeline=dict(  # Loads the float model checkpoint before predicting
        type="ModelConvertPipeline",
        converters=[
            dict(
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "float-checkpoint-best.pth.tar"
                ),
            ),
        ],
    ),
    data_loader=[val_data_loader],  # Dataset used for prediction
    batch_processor=val_batch_processor,  # How each iteration of the validation dataset is processed
    device=None,
    metrics=dict(  # Metrics to print
        type="COCODetectionMetric",
        ann_file="./tmp_data/mscoco/instances_val2017.json",
    ),
    callbacks=[
        val_metric_updater,
    ],
    log_interval=50,  # Log every 50 steps
)
# Settings for the qat model predictor
qat_predictor = dict(
    type="Predictor",
    model=model,
    model_convert_pipeline=dict(  # Converts the model from float to qat
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(type="Float2QAT"),  # Converts the model from float to qat
            dict(  # Loads the qat checkpoint
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "qat-checkpoint-best.pth.tar"
                ),
                ignore_extra=True,  # Tolerate extra keys in the checkpoint
            ),
        ],
    ),
    data_loader=[val_data_loader],
    batch_processor=val_batch_processor,
    device=None,
    metrics=dict(
        type="COCODetectionMetric",
        ann_file="./tmp_data/mscoco/instances_val2017.json",
    ),
    callbacks=[
        val_metric_updater,
    ],
    log_interval=50,  # Log every 50 steps
)
# Settings for the quantized model predictor
int_infer_predictor = dict(
    type="Predictor",
    model=model,
    model_convert_pipeline=dict(  # Converts the model from float all the way to quantized
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(type="Float2QAT"),  # Converts the model from float to qat
            dict(  # Loads the qat checkpoint
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "qat-checkpoint-best.pth.tar"
                ),
                ignore_extra=True,  # Tolerate extra keys in the checkpoint
            ),
            dict(type="QAT2Quantize"),  # Converts the model from qat to quantized
        ],
    ),
    data_loader=[val_data_loader],
    batch_processor=val_batch_processor,
    device=None,
    metrics=dict(
        type="COCODetectionMetric",
        ann_file="./tmp_data/mscoco/instances_val2017.json",
    ),
    callbacks=[
        val_metric_updater,
    ],
    log_interval=50,  # Log every 50 steps
)