Quick Start
Basic Process
The basic process for using the Quantized Awareness Training Tool is as follows:

The following is an example of the MobileNetV2 model from torchvision to introduce you to each stage of the process.
We use the cifar-10 dataset instead of the ImageNet-1K dataset so that the demonstration of the process runs quickly.
# Standard library imports.
import copy
import logging
import os
from typing import Optional, Callable, List, Tuple

# Third-party (PyTorch / torchvision) imports.
import numpy as np
import torch
import torch.nn as nn
import torch.quantization
import torchvision.transforms as transforms
from torch import Tensor
from torch.quantization import DeQuantStub
from torch.utils import data
from torchvision.datasets import CIFAR10
from torchvision.models.mobilenetv2 import MobileNetV2

# Horizon quantization toolchain imports.
from horizon_plugin_pytorch.functional import rgb2centered_yuv
from horizon_plugin_pytorch.march import March, set_march
from horizon_plugin_pytorch.quantization import (
    QuantStub,
    prepare,  # added: the script calls prepare(...) below
    prepare_qat_fx,
    set_fake_quantize,
    FakeQuantState,
)
from horizon_plugin_pytorch.quantization.qconfig import (
    default_calib_8bit_fake_quant_qconfig,
    default_qat_8bit_fake_quant_qconfig,
    default_qat_8bit_weight_32bit_out_fake_quant_qconfig,
    default_calib_8bit_weight_32bit_out_fake_quant_qconfig,
)
from hbdk4 import compiler as hb4

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
class AverageMeter(object):
    """Tracks the most recent value and running average of a metric.

    Attribute names (name, fmt, val, avg, sum, count) are part of the
    public surface: callers read ``.avg`` and ``__str__`` formats
    ``__dict__`` directly.
    """

    def __init__(self, name: str, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        # Drop all accumulated statistics.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        # Record one observation `val` that covers `n` samples.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Build "{name} {val<fmt>} ({avg<fmt>})" with the dynamic format spec.
        spec = self.fmt
        template = "{name} {val" + spec + "} ({avg" + spec + "})"
        return template.format(**vars(self))
def accuracy(output: Tensor, target: Tensor, topk=(1,)) -> List[Tensor]:
    """Computes the accuracy over the k top predictions for the specified
    values of k
    """
    with torch.no_grad():
        k_max = max(topk)
        n_samples = target.size(0)
        # Indices of the k_max highest-scoring classes, transposed to
        # shape (k_max, batch) so row k holds each sample's k-th guess.
        _, top_idx = output.topk(k_max, 1, True, True)
        top_idx = top_idx.t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))
        # A sample counts as correct at k if any of its first k guesses hit.
        return [
            hits[:k].float().sum().mul_(100.0 / n_samples) for k in topk
        ]
def evaluate(
    model: nn.Module, data_loader: data.DataLoader, device: torch.device
) -> Tuple[AverageMeter, AverageMeter]:
    """Run inference over `data_loader` and return (top-1, top-5) meters."""
    meter1 = AverageMeter("Acc@1", ":6.2f")
    meter5 = AverageMeter("Acc@5", ":6.2f")
    with torch.no_grad():
        for batch_img, batch_target in data_loader:
            batch_img = batch_img.to(device)
            batch_target = batch_target.to(device)
            preds = model(batch_img)
            acc1, acc5 = accuracy(preds, batch_target, topk=(1, 5))
            n = batch_img.size(0)
            meter1.update(acc1, n)
            meter5.update(acc5, n)
            # Progress indicator: one dot per batch.
            print(".", end="", flush=True)
        print()
    return meter1, meter5
def train_one_epoch(
    model: nn.Module,
    criterion: Callable,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    data_loader: data.DataLoader,
    device: torch.device,
) -> None:
    """Train `model` for one full pass over `data_loader`.

    Steps the optimizer every batch; if `scheduler` is given it is also
    stepped per batch. Prints running loss/accuracy at the end of the epoch.
    """
    top1 = AverageMeter("Acc@1", ":6.3f")
    top5 = AverageMeter("Acc@5", ":6.3f")
    avgloss = AverageMeter("Loss", ":1.5f")
    model.to(device)
    for image, target in data_loader:
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        top1.update(acc1, image.size(0))
        top5.update(acc5, image.size(0))
        # BUG FIX: take loss.item() instead of the raw tensor. The original
        # accumulated the loss tensor (with its autograd graph attached) in
        # the meter's running sum, keeping every batch's graph alive for the
        # whole epoch and steadily leaking GPU memory.
        avgloss.update(loss.item(), image.size(0))
        print(".", end="", flush=True)
    print()
    print(
        "Full cifar-10 train set: Loss {:.3f} Acc@1"
        " {:.3f} Acc@5 {:.3f}".format(avgloss.avg, top1.avg, top5.avg)
    )
Getting Floating-point Model
First, the floating-point model is modified as necessary to support quantization-related operations. The operations necessary for model transformation are:
- Insert
QuantStub before the model inputs.
- Insert
DeQuantStub after the model outputs.
Attention is needed when remodeling the model:
- The inserted
QuantStub and DequantStub must be registered as submodules of the model, otherwise their quantized state will not be handled correctly.
- Multiple inputs can share
QuantStub only if the scale is the same, otherwise define a separate QuantStub for each input.
- If you need to specify the source of the data entered during board up as
“pyramid”, please manually set the scale parameter of the corresponding QuantStub to 1/128.
- It is also possible to use
torch.quantization.QuantStub, but only horizon_plugin_pytorch.quantization.QuantStub supports manually fixing the scale with the parameter.
The modified model can seamlessly load the parameters of the pre-modified model, so if there is an existing trained floating-point model, it can be loaded directly, otherwise you need to do floating-point training normally.
Attention
The input image data is typically in centered_yuv444 format when the model is on board, so the image needs to be converted to centered_yuv444 format when the model is trained (note the use of rgb2centered_yuv in the code below).
If it is not possible to convert to centered_yuv444 format for model training, please insert the appropriate color space conversion node on the input when the model is deployed. (Note that this method may result in lower model accuracy)
The example has fewer floating-point and QAT training epochs, just to illustrate the process of using the training tool, and the accuracy does not represent the best level of the model.
######################################################################
# The user can modify the following parameters as required.
# 1. Save paths for model ckpt and compiled outputs.
model_path = "model/mobilenetv2"
# 2. Download the dataset and save the path.
data_path = "data"
# 3. The batch_size used for training.
train_batch_size = 256
# 4. The batch_size used for prediction.
eval_batch_size = 256
# 5. Number of epochs trained.
epoch_num = 10
# 6. The device used for model saving and performing calculations.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
######################################################################
# To prepare the dataset, note the use of rgb2centered_yuv in collate_fn.
def prepare_data_loaders(
    data_path: str, train_batch_size: int, eval_batch_size: int
) -> Tuple[data.DataLoader, data.DataLoader]:
    """Build CIFAR-10 train/eval DataLoaders emitting centered_yuv444 batches.

    The collate function stacks raw uint8 RGB images, converts them to
    centered YUV444 (matching the on-board input format), then normalizes
    by 128 so values land roughly in [-1, 1].
    """
    normalize = transforms.Normalize(mean=0.0, std=128.0)

    def collate_fn(batch):
        # Stack PIL images into a uint8 NCHW tensor.
        imgs = [
            torch.from_numpy(np.array(sample[0], np.uint8, copy=True))
            for sample in batch
        ]
        batched_img = torch.stack(imgs).permute(0, 3, 1, 2)
        batched_target = torch.tensor([sample[1] for sample in batch])
        # RGB -> centered YUV444, then scale to the training range.
        batched_img = rgb2centered_yuv(batched_img)
        batched_img = normalize(batched_img.float())
        return batched_img, batched_target

    train_aug = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.RandAugment()]
    )
    train_dataset = CIFAR10(data_path, True, train_aug, download=True)
    eval_dataset = CIFAR10(data_path, False, download=True)

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        sampler=data.RandomSampler(train_dataset),
        num_workers=8,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    eval_loader = data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        sampler=data.SequentialSampler(eval_dataset),
        num_workers=8,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    return train_loader, eval_loader
# Make the necessary modifications to the floating point model.
class FxQATReadyMobileNetV2(MobileNetV2):
    """MobileNetV2 with Quant/DeQuant stubs wrapped around the backbone.

    The stubs are registered as submodules so the quantization tooling can
    track and convert them.
    """

    def __init__(
        self,
        num_classes: int = 10,
        width_mult: float = 1.0,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
    ):
        super().__init__(
            num_classes, width_mult, inverted_residual_setting, round_nearest
        )
        # Fixed scale 1/128 matches the "pyramid" on-board input source.
        self.quant = QuantStub(scale=1 / 128)
        self.dequant = DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        # quantize input -> backbone -> dequantize output
        return self.dequant(super().forward(self.quant(x)))
# Ensure the checkpoint directory exists. exist_ok=True already makes this
# safe to call unconditionally, so the former os.path.exists() pre-check
# was redundant (and race-prone) and has been removed.
os.makedirs(model_path, exist_ok=True)
# Floating-point model initialization.
float_model = FxQATReadyMobileNetV2()
# Prepare the dataset
train_data_loader, eval_data_loader = prepare_data_loaders(
    data_path, train_batch_size, eval_batch_size
)
# Since the last layer of the model is inconsistent with the pre-trained model, a floating point finetune is required.
optimizer = torch.optim.Adam(
    float_model.parameters(), lr=0.001, weight_decay=1e-3
)
best_acc = 0
for nepoch in range(epoch_num):
    float_model.train()
    train_one_epoch(
        float_model,
        nn.CrossEntropyLoss(),
        optimizer,
        None,  # no LR scheduler for this short finetune
        train_data_loader,
        device,
    )
    # Floating-point Precision Test.
    float_model.eval()
    top1, top5 = evaluate(float_model, eval_data_loader, device)
    print(
        "Float Epoch {}: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
            nepoch, top1.avg, top5.avg
        )
    )
    if top1.avg > best_acc:
        best_acc = top1.avg
        # Save optimal floating-point model parameters.
        torch.save(
            float_model.state_dict(),
            os.path.join(model_path, "float-checkpoint.ckpt"),
        )
Files already downloaded and verified
Files already downloaded and verified
....................................................................................................................................................................................................
Full cifar-10 train set: Loss 2.113 Acc@1 20.826 Acc@5 71.182
........................................
Float Epoch 0: evaluation Acc@1 33.140 Acc@5 85.710
...
....................................................................................................................................................................................................
Full cifar-10 train set: Loss 1.167 Acc@1 58.864 Acc@5 94.682
........................................
Float Epoch 9: evaluation Acc@1 64.490 Acc@5 96.400
Calibration
After the model is transformed and the floating point training is completed, Calibration can be performed. This process is done by inserting Observer in the model and counting the distribution of data in each place during the forward process, so as to calculate a reasonable quantization parameter:
- For some models, the required accuracy can be reached by Calibration alone, without the need for the more time-consuming quantization-aware training.
- Even if the model cannot meet the accuracy requirements after quantization calibration, this process can reduce the difficulty of subsequent quantization awareness training, shorten the training time, and improve the final training accuracy.
######################################################################
# The user can modify the following parameters as required.
# 1. The batch_size used for Calibration.
calib_batch_size = 256
# 2. The batch_size used for Validation.
eval_batch_size = 256
# 3. The amount of data used by Calibration, configured to inf to use all data.
num_examples = float("inf")
# 4. Code name of the target hardware platform.
march = March.NASH
# 5. Example input for model tracing and export hbir.
example_input = torch.rand(1, 3, 32, 32, device=device)
######################################################################
# Before model transformation, the hardware platform on which the model will be executed must be set up.
set_march(march)
# The output model will share the attributes of the input model, so in order not to affect the subsequent use of float_model,
# deepcopy is performed here.
calib_model = copy.deepcopy(float_model)
# 8-bit fake-quant observers for the bulk of the network...
calib_model.qconfig = default_calib_8bit_fake_quant_qconfig
# ...while the classifier keeps a 32-bit output for better final accuracy.
calib_model.classifier.qconfig = (
    default_calib_8bit_weight_32bit_out_fake_quant_qconfig
)
# Transform the model into the Calibration state to characterize the numerical distribution of the data at each location statistically.
# NOTE(review): `prepare` must be importable from
# horizon_plugin_pytorch.quantization — confirm the installed plugin
# version exposes it (older releases only had prepare_qat_fx).
calib_model = prepare(calib_model, example_inputs=example_input)
# Prepare the dataset.
calib_data_loader, eval_data_loader = prepare_data_loaders(
    data_path, calib_batch_size, eval_batch_size
)
# Perform Calibration process (no backward required).
# Note the control of the model state here, the model needs to be in the eval state for the behavior of Bn to match the requirements.
calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
with torch.no_grad():
    cnt = 0
    for image, target in calib_data_loader:
        image, target = image.to(device), target.to(device)
        # Forward pass only: observers record activation statistics.
        calib_model(image)
        print(".", end="", flush=True)
        cnt += image.size(0)
        if cnt >= num_examples:
            break
    print()
# Test pseudo-quantization accuracy.
# Note the control of the model state here.
calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
top1, top5 = evaluate(
    calib_model,
    eval_data_loader,
    device,
)
print(
    "Calibration: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
        top1.avg, top5.avg
    )
)
# Saving Calibration Model Parameters.
torch.save(
    calib_model.state_dict(),
    os.path.join(model_path, "calib-checkpoint.ckpt"),
)
2024-06-11 14:16:18,510 INFO: Begin check qat model...
2024-06-11 14:16:18,834 INFO: All fusable modules are fused in model!
2024-06-11 14:16:18,834 INFO: All modules in the model run exactly once.
2024-06-11 14:16:18,835 WARNING: Please check these modules qconfig if expected:
+---------------+---------------------------------------------------------+-----------------------------------------+
| module name | module type | msg |
|---------------+---------------------------------------------------------+-----------------------------------------|
| quant | <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> | Fixed scale 0.0078125 |
| classifier.1 | <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'> | activation is None. Maybe output layer? |
+---------------+---------------------------------------------------------+-----------------------------------------+
2024-06-11 14:16:18,848 INFO: Check full result in ./model_check_result.txt
2024-06-11 14:16:18,848 INFO: End check
Files already downloaded and verified
Files already downloaded and verified
....................................................................................................................................................................................................
........................................
Calibration: evaluation Acc@1 64.280 Acc@5 96.240
If the quantization accuracy of the model after Calibration meets the requirements, the model_deploy step can be carried out directly, otherwise the quantization_awareness_training needs to be carried out to further improve the accuracy.
Quantization Awareness Training
The quantization awareness training makes the model aware of the impact of quantization during the training process by inserting pseudo-quantization nodes in the model, in this case fine-tuning the model parameters in order to improve the accuracy after quantization.
######################################################################
# The user can modify the following parameters as required.
# 1. The batch_size used for training.
train_batch_size = 256
# 2. The batch_size used for Validation.
eval_batch_size = 256
# 3. Number of epochs trained.
epoch_num = 3
######################################################################
# Prepare the dataset.
train_data_loader, eval_data_loader = prepare_data_loaders(
    data_path, train_batch_size, eval_batch_size
)
# Start again from the float weights; the quantization parameters are
# loaded from the calibration checkpoint below.
qat_model = copy.deepcopy(float_model)
qat_model.qconfig = default_qat_8bit_fake_quant_qconfig
# Classifier keeps a 32-bit output, mirroring the calibration setup.
qat_model.classifier.qconfig = (
    default_qat_8bit_weight_32bit_out_fake_quant_qconfig
)
# Convert the model to QAT state.
qat_model = prepare(qat_model, example_inputs=example_input)
# Load Quantization Parameters in Calibration Models.
qat_model.load_state_dict(calib_model.state_dict())
# Conduct quantized awareness training.
# As a finetune process, quantized awareness training generally requires a small learning rate to be set.
optimizer = torch.optim.Adam(
    qat_model.parameters(), lr=1e-3, weight_decay=1e-4
)
best_acc = 0
for nepoch in range(epoch_num):
    # Note the method of controlling the training state of the QAT model here.
    qat_model.train()
    set_fake_quantize(qat_model, FakeQuantState.QAT)
    train_one_epoch(
        qat_model,
        nn.CrossEntropyLoss(),
        optimizer,
        None,
        train_data_loader,
        device,
    )
    # Note the method of controlling the eval state of the QAT model here.
    qat_model.eval()
    set_fake_quantize(qat_model, FakeQuantState.VALIDATION)
    top1, top5 = evaluate(
        qat_model,
        eval_data_loader,
        device,
    )
    print(
        "QAT Epoch {}: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
            nepoch, top1.avg, top5.avg
        )
    )
    if top1.avg > best_acc:
        best_acc = top1.avg
        # Keep only the best-performing checkpoint.
        torch.save(
            qat_model.state_dict(),
            os.path.join(model_path, "qat-checkpoint.ckpt"),
        )
Files already downloaded and verified
Files already downloaded and verified
2024-06-11 14:20:45,090 INFO: Begin check qat model...
2024-06-11 14:20:45,236 INFO: All fusable modules are fused in model!
2024-06-11 14:20:45,236 INFO: All modules in the model run exactly once.
2024-06-11 14:20:45,237 WARNING: Please check these modules qconfig if expected:
+---------------+---------------------------------------------------------+-----------------------------------------+
| module name | module type | msg |
|---------------+---------------------------------------------------------+-----------------------------------------|
| quant | <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> | Fixed scale 0.0078125 |
| classifier.1 | <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'> | activation is None. Maybe output layer? |
+---------------+---------------------------------------------------------+-----------------------------------------+
2024-06-11 14:20:45,249 INFO: Check full result in ./model_check_result.txt
2024-06-11 14:20:45,249 INFO: End check
2024-06-11 14:20:45,687 WARNING: fast training is experimental
....................................................................................................................................................................................................
Full cifar-10 train set: Loss 1.279 Acc@1 55.026 Acc@5 93.572
........................................
QAT Epoch 0: evaluation Acc@1 62.830 Acc@5 95.950
....................................................................................................................................................................................................
Full cifar-10 train set: Loss 1.151 Acc@1 59.474 Acc@5 94.822
........................................
QAT Epoch 1: evaluation Acc@1 65.940 Acc@5 96.520
....................................................................................................................................................................................................
Full cifar-10 train set: Loss 1.102 Acc@1 61.114 Acc@5 95.546
........................................
QAT Epoch 2: evaluation Acc@1 66.340 Acc@5 96.940
Model Deployment
Once the pseudo-quantization accuracy is up to standard, the processes related to model deployment can be executed.
Export Hbir Model
Model deployment first requires exporting the pseudo-quantization model as a Hbir model.
Attention
- The batch_size of the example_input used in model export determines the batch_size for model simulation and model uploading, if you need to use different batch_size for simulation and uploading, please use different data to export hbir model separately.
- It is also possible to skip the actual calibration and training process in Calibration and Quantization Awareness Training and go directly to the model deployment process first to ensure that there are no operations in the model that cannot be exported or compiled.
######################################################################
# The user can modify the following parameters as required.
# 1. Which model to use as input for the process, you can choose either calib_model or qat_model.
base_model = calib_model
######################################################################
# `export` traces the fake-quantized model and lowers it to an hbir graph.
# NOTE(review): mid-file import kept to match the tutorial's step-by-step flow.
from horizon_plugin_pytorch.quantization.hbdk4 import export

# The example input's batch size fixes the exported model's batch size.
hbir_qat_model = export(base_model, (example_input,))
2024-06-11 14:24:16,708 INFO: Model ret: Tensor(shape=(1, 10), dtype=torch.float32, device=cuda:0)
Convert to Fixed-point Model
Once the pseudo-quantization accuracy is up to standard, the model can be converted to a fixed-point model. The results of the fixed-point model are generally considered to be identical to those of the compiled model.
Attention
- Hbir models support only a single
Tensor or Tuple[Tensor] as input, and only Tuple[Tensor] as output.
- It is not possible to achieve complete numerical agreement between the fixed-point model and the pseudo-quantization model, so please take the accuracy of the fixed-point model as the standard. If the fixed-point accuracy is not up to standard, you need to continue the quantized awareness training.
# Transform the model to a fixed-point state, note that the march here needs to be distinguished from nash-e/m/p.
# Note that the perf_output_dir parameter here must be the same as in the subsequent internal_compile.
hbir_quantized_model = hb4.convert(
    hbir_qat_model,
    March.NASH_E,
    perf_output_dir=os.path.join(model_path, "perf_out"),
)
# Dataloader for test accuracy of hbir model. Please note that the batch size
# should be same as the example input when exporting hbir.
# eval batch size is 1 because example_input above has batch size 1; the
# train loader returned here is unused.
_, eval_hbir_data_loader = prepare_data_loaders(
    data_path, train_batch_size, 1
)
def evaluate_hbir(
    model: hb4.Module, data_loader: data.DataLoader
) -> Tuple[AverageMeter, AverageMeter]:
    """Measure top-1/top-5 accuracy of an exported hbir model (CPU only)."""
    meter1 = AverageMeter("Acc@1", ":6.2f")
    meter5 = AverageMeter("Acc@5", ":6.2f")
    for img, label in data_loader:
        img, label = img.cpu(), label.cpu()
        # Hbir models expose callables via .functions; outputs are tuples,
        # so take the first (and only) element.
        logits = model.functions[0](img)[0]
        acc1, acc5 = accuracy(logits, label, topk=(1, 5))
        n = img.size(0)
        meter1.update(acc1, n)
        meter5.update(acc5, n)
    return meter1, meter5
# Test the accuracy of fixed-point models.
top1, top5 = evaluate_hbir(
    hbir_quantized_model,
    eval_hbir_data_loader,
)
# A small gap vs the fake-quantized accuracy is expected; the fixed-point
# numbers are the authoritative ones for deployment decisions.
print(
    "Quantized model: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
        top1.avg, top5.avg
    )
)
Files already downloaded and verified
Files already downloaded and verified
Quantized model: evaluation Acc@1 65.250 Acc@5 93.990
Model Compilation
After testing the accuracy of the fixed-point model and confirming that it meets the requirements, the model can be compiled, performance tested and visualized.
######################################################################
# The user can modify the following parameters as required.
# 1. The level of optimization enabled at compile time, the higher the level the faster the compiled model will be executed on the board, but the compilation process will be slower.
compile_opt = 1
######################################################################
# Model compile.
# Produces the deployable .hbm artifact for the NASH_E target.
hb4.compile(
    hbir_quantized_model,
    os.path.join(model_path, "model.hbm"),
    March.NASH_E,
    opt=compile_opt,
)
# Model perf.
# Static performance estimation (FPS / latency / DDR traffic); writes an
# HTML report into model_path.
hb4.hbm_perf(
    os.path.join(model_path, "model.hbm"),
    March.NASH_E,
    output_dir=model_path,
)
[2024-06-11 14:26:34.809] [warning] Performance information does not include operators in the model running on the CPU
[2024-06-11 14:26:34.809] [warning] Invalid debug data size, we will not use debug data.
[2024-06-11 14:26:34.809] [warning] Invalid debug data size, we will not use debug data.
FPS=11624.299999999999, latency = 86 us, DDR = 2576640 bytes (see model/mobilenetv2/forward.html)
HBDK hbm perf SUCCESS
[06h:26m:34s:692209us INFO hbrt4_loader::parsing] pid:107773 tid:107773 hbrt4_loader/src/parsing.rs:31: Load hbm header from file; filename="model/mobilenetv2/model.hbm"
[06h:26m:34s:694516us INFO hbrt4_log::logger] pid:107773 tid:107773 hbrt4_log/src/logger.rs:388: Logger of HBRT4 initialized, version = 4.0.21.post0.dev202406040731+3f96886
[06h:26m:34s:694525us INFO hbrt4_loader::parsing] pid:107773 tid:107773 hbrt4_loader/src/parsing.rs:53: Load hbm from file; filename="model/mobilenetv2/model.hbm"
0
# Model Visualization.
# Renders the hbir graph structure for inspection.
hb4.visualize(hbir_quantized_model)