Training a model with the HAT algorithm toolkit is usually done with a single command that takes a config file path (written as /PATH/TO/CONFIG below). The config file defines the model structure, dataset loading, and the entire training process.
This section introduces some fixed global keywords in the config file and their configuration descriptions, giving you an overview of the config file.
training_stage: The stage of model training; one of float, qat, or int_infer.
device_ids: List of GPUs used for model training.
cudnn_benchmark: Whether to turn on the cuDNN benchmark; usually defaults to True.
seed: The random seed; usually defaults to None (no fixed seed).
log_rank_zero_only: Simplifies the log printing in multi-card training by outputting logs only on card 0. Usually defaults to True.
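Taken together, these global settings typically sit at the top of a config file as plain Python variables. A minimal sketch (the specific values below are illustrative, not taken from any shipped config):

```python
# Illustrative global settings for a HAT-style config file.
# The exact set of accepted values depends on the toolkit version.
training_stage = "float"      # one of: "float", "qat", "int_infer"
device_ids = [0, 1, 2, 3]     # GPUs used for training
cudnn_benchmark = True        # enable cuDNN autotuning
seed = None                   # set an int here for reproducible runs
log_rank_zero_only = True     # only rank 0 writes logs in multi-card runs
```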
model: The structure of the model used in training. type is the model's type, e.g., Classifier, Segmentor, or RetinaNet, corresponding to classification, segmentation, and detection models, respectively. It is built into a concrete class at runtime, and all other parameters are used to initialize that class.
deploy_model: The model structure that participates in the deploy process, mainly used for model compilation. Compared to model, in most cases you only need to set the loss function and the post-processing part to None.
deploy_inputs: Simulated inputs for the deploy procedure. Values do not matter here, just make sure the format meets the input requirements.
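As a hedged sketch, a model entry, its deploy counterpart, and the simulated deploy inputs might look like the following. The class names (Classifier, ResNet18, CrossEntropyLoss) and parameter names are illustrative assumptions, not the toolkit's actual registry names; only the overall shape matters:

```python
import torch

# Hypothetical model config: `type` names the class to build; the remaining
# keys are passed to that class's constructor.
model = dict(
    type="Classifier",                # built into a concrete class at runtime
    backbone=dict(type="ResNet18"),   # nested dicts are built the same way
    losses=dict(type="CrossEntropyLoss"),
)

# Deploy variant: same structure, but with the loss (and any
# post-processing) set to None, since they are not compiled.
deploy_model = dict(
    type="Classifier",
    backbone=dict(type="ResNet18"),
    losses=None,
)

# Simulated inputs for the deploy procedure: the values are irrelevant,
# only the format (shape and dtype) must match the real inputs.
deploy_inputs = dict(img=torch.randn((1, 3, 224, 224)))
```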
data_loader: Dataset loading in the training phase. Its type is the concrete class torch.utils.data.DataLoader, and the other parameters are used to initialize it; the interface documentation on the PyTorch website describes these parameters. Here dataset specifies which dataset to read, e.g., ImageNet, MSCOCO, VOC, etc., and transforms specifies the data augmentation operations applied when reading the data.
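A sketch of such an entry is below. In real configs, type is typically the class torch.utils.data.DataLoader itself; a string stands in for it here so the fragment is self-contained. The dataset and transform names are hypothetical:

```python
# Illustrative data_loader config. Dataset/transform type names are assumed,
# not the toolkit's actual registry names.
data_loader = dict(
    type="torch.utils.data.DataLoader",  # usually the class itself, not a string
    dataset=dict(
        type="ImageNet",                 # which dataset to read (assumed name)
        data_path="/PATH/TO/IMAGENET",
        transforms=[                     # augmentation applied while reading
            dict(type="RandomResizedCrop", size=224),
            dict(type="RandomHorizontalFlip"),
        ],
    ),
    batch_size=64,
    shuffle=True,
    num_workers=4,
)
```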
val_data_loader: Dataset loading in the model-performance validation phase. It differs from data_loader in that its data_path is different, and the transforms and sample steps are removed.
batch_processor: The operations the model performs at each training iteration, including forward propagation, backward propagation, and parameter updates. If the batch_transforms parameter is present, some data augmentation operations are performed on the GPU, which can greatly speed up training.
val_batch_processor: The operations performed by the model at each iteration stage during the validation process, containing only forward propagation.
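The training and validation processors usually differ only in whether gradients are updated. A minimal sketch, assuming hypothetical type, flag, and transform names:

```python
# Illustrative batch processor configs; all names here are assumptions.
batch_processor = dict(
    type="MultiBatchProcessor",       # hypothetical processor class name
    need_grad_update=True,            # training: forward + backward + step
    batch_transforms=[                # augmentation executed on the GPU
        dict(type="BgrToYuv444"),     # assumed GPU-side transform name
    ],
)

val_batch_processor = dict(
    type="MultiBatchProcessor",
    need_grad_update=False,           # validation: forward propagation only
)
```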
metric_updater: The method for updating metrics during training, used to check whether the performance of the model being trained is improving. It is usually used together with train_metrics under float_trainer: train_metrics defines the specific metric, while metric_updater only provides the updating method.
val_metric_updater: The method for updating metrics during performance validation, used to verify the final performance of the trained model. Like metric_updater, it is usually used together with val_metrics under float_trainer.
float_trainer: Configuration of the floating-point training process. Its type is distributed_data_parallel_trainer, meaning distributed training is supported. The other parameters define the model, dataset loading, optimizer, number of training epochs, and so on; callbacks lists the operations performed during training, such as model saving, learning-rate updates, and precision validation. It is a variable called directly by tools/train.py.
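A trainer entry ties the earlier pieces together. In the sketch below the optimizer keys and callback names are assumptions, and the small inline dicts stand in for the full model and data_loader dicts defined elsewhere in the config:

```python
# Illustrative float_trainer config; callback and optimizer names are assumed.
float_trainer = dict(
    type="distributed_data_parallel_trainer",  # supports distributed training
    model=dict(type="Classifier"),       # normally the `model` dict above
    data_loader=dict(type="DataLoader"), # normally the `data_loader` dict above
    optimizer=dict(type="SGD", lr=0.1),  # hypothetical optimizer config
    num_epochs=100,
    callbacks=[
        dict(type="StatsMonitor"),       # assumed name: training statistics
        dict(type="Checkpoint"),         # assumed name: model saving
        dict(type="Validation"),         # assumed name: precision validation
    ],
)
```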
qat_trainer: Configuration of the QAT model training process. Its parameters have essentially the same meaning as those of float_trainer. It is a variable called directly by tools/train.py.
int_infer_trainer: Contains no training process; it is used only to verify the accuracy of the fixed-point model. It is a variable called directly by tools/train.py.
compile_cfg: Compile-related configuration. out_dir is the output path of the compiled HBM file (deployment model).
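Unlike the entries above, compile_cfg is a plain dict with no type key. A sketch, with all keys other than out_dir being hypothetical:

```python
# Illustrative compile configuration: a regular dict, no `type` key,
# so no class is built from it. Keys other than out_dir are assumptions.
compile_cfg = dict(
    march="bernoulli2",           # hypothetical target-architecture name
    out_dir="./compiled_model",   # output path of the compiled HBM file
    opt="O2",                     # hypothetical optimization level
)
```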
These variables are called global keywords because they are defined in almost every config file and serve essentially the same purpose everywhere. Reading this document should give you a general idea of what a config file can do.
This section describes the configuration of global keywords whose value is of type dict.
Global keywords of the dict type can be further divided into the following two types:
Those with type, such as model, data_loader, float_trainer, etc.
Those without type, such as compile_cfg, etc.
The difference is that a global keyword containing type essentially describes a class: the type value can be either a string or a concrete class, and even a string is eventually built into the corresponding class at runtime. The values of all keys in the dict other than type are used to initialize that class. Like global keywords, these keys can themselves be either plain values or dicts containing a type variable, such as the dataset property in data_loader, and the transforms property under that dataset.
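This build mechanism can be illustrated with a toy registry. This is a simplified sketch of the general pattern, not the toolkit's actual implementation:

```python
# Toy registry illustrating how a `type`-bearing dict becomes an object.
REGISTRY = {}

def register(cls):
    """Register a class under its own name so a string `type` can find it."""
    REGISTRY[cls.__name__] = cls
    return cls

def build(cfg):
    """Recursively build objects from dicts containing a `type` key."""
    if isinstance(cfg, dict) and "type" in cfg:
        cfg = dict(cfg)                 # avoid mutating the caller's config
        obj_type = cfg.pop("type")
        # `type` may be a string (looked up in the registry) or a class.
        cls = REGISTRY[obj_type] if isinstance(obj_type, str) else obj_type
        # All remaining keys (built recursively) initialize the class.
        return cls(**{k: build(v) for k, v in cfg.items()})
    return cfg                          # plain values pass through unchanged

@register
class Dataset:
    def __init__(self, data_path):
        self.data_path = data_path

@register
class Classifier:
    def __init__(self, num_classes, dataset):
        self.num_classes = num_classes
        self.dataset = dataset

cfg = dict(type="Classifier", num_classes=10,
           dataset=dict(type="Dataset", data_path="/PATH/TO/DATA"))
model = build(cfg)   # a Classifier holding a built Dataset
```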
A global keyword without a type variable is a regular dict; at runtime the code simply reads the corresponding values from its keys.
All provided configurations are guaranteed to work and to reproduce the reported accuracy. If you need to modify a configuration because of your environment or training-time budget, you may need to adjust the training strategy as well; directly changing individual values in the config file does not always produce the desired results.