import copy
import os
import shutil

import cv2
import horizon_plugin_pytorch as horizon
import numpy as np
import torch
from horizon_plugin_pytorch.dtype import qint8, qint16
from horizon_plugin_pytorch.march import March
from horizon_plugin_pytorch.quantization import get_qconfig, observer_v2
from horizon_plugin_pytorch.quantization.qconfig_setter import (
    ConvDtypeTemplate,
    MatmulDtypeTemplate,  # configure Matmul ops for single/dual int16 inputs by name or prefix
    ModuleNameTemplate,  # per-module-name dtype/threshold config (activations, weights, fixed scales); granularity: global, sub-model, or single op
    QconfigSetter,
    SensitivityTemplate,
)
from horizon_plugin_pytorch.quantization.qconfig_template import (  # noqa F401
    calibration_8bit_weight_16bit_act_qconfig_setter,
    default_calibration_qconfig_setter,
    default_qat_fixed_act_qconfig_setter,
    sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter,
    sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
)
from PIL import Image
from torchvision.transforms.functional import pil_to_tensor

try:
    from torchvision.transforms.functional_tensor import resize
except ImportError:
    # torchvision 0.18
    from torchvision.transforms._functional_tensor import resize

from hat.data.collates.nusc_collates import collate_nuscenes
from hat.data.datasets.nuscenes_dataset import CLASSES
from hat.utils.config import ConfigVersion

VERSION = ConfigVersion.v2  # config schema version
# Current training stage ("float", "calibration", "qat", ...) selected via env var.
training_step = os.environ.get("HAT_TRAINING_STEP", "float")

task_name = "bev_sparse_henet_tinym_nuscenes_one_batch_01"

num_classes = 10  # number of nuScenes detection classes
batch_size_per_gpu = 2
dataloader_workers = 0
device_ids = [0]  # 1 node
ckpt_dir = "./tmp_models/%s" % task_name
cudnn_benchmark = False
seed = None  # random seed; None means do not fix the seed
log_rank_zero_only = True  # only log on rank 0
bn_kwargs = {}
march = March.NASH_E  # target compute architecture for deployment
convert_mode = "jit-strip"
qat_mode = "fuse_bn"

num_query = 900
query_align = 128

orig_shape = (3, 900, 1600)
resize_shape = (3, 396, 704)
data_shape = (3, 256, 704)
val_data_shape = (3, 256, 704)

bev_range = (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0)
position_range = (-61.2, -61.2, -10.0, 61.2, 61.2, 10.0)
vt_input_hw = (16, 44)

data_rootdir = "/home/ws11126/UISEE/dataset/J6/"
meta_rootdir = "/home/ws11126/UISEE/dataset/J6"
anchor_file = "/home/ws11126/UISEE/J6/horizon_j6_open_explorer_v3.7.0-py310_20251215/samples/ai_toolchain/horizon_model_train_sample/scripts/nuscenes_kmeans900.npy"

num_epochs = 3
# 28130 is the sample count of the nuScenes train split; num_steps_per_epoch
# is the number of iterations per epoch given the global batch size.
num_steps_per_epoch = int(28130 // (len(device_ids) * batch_size_per_gpu))
num_steps = num_steps_per_epoch * num_epochs
# 1 step: take one batch -> forward -> loss -> backward -> parameter update;
# 1 epoch: the number of steps needed to traverse the whole training set.
embed_dims = 256
num_groups = 8
num_levels = 1
drop_out = 0.1
num_single_frame_decoder = 0  # 1
num_decoder = 6
num_depth_layers = 3

num_anchors = 384
temp_anchors = 128

# Float training model: SparseBEV-style sparse 4D detector — HENet backbone,
# FPN neck, an auxiliary dense-depth branch, and a sparse transformer head.
model = dict(
    type="SparseBEVOE",
    compiler_model=False,  # training graph (with losses), not the deploy graph
    backbone=dict(
        type="HENet",
        in_channels=3,
        block_nums=[4, 3, 8, 6],  # number of blocks in each stage
        embed_dims=[64, 128, 192, 384],  # channel width of each stage
        attention_block_num=[0, 0, 0, 0],  # attention blocks appended at each stage tail
        mlp_ratios=[2, 2, 2, 3],
        mlp_ratio_attn=2,
        act_layer=["nn.GELU", "nn.GELU", "nn.GELU", "nn.GELU"],  # activation used per stage
        use_layer_scale=[True, True, True, True],  # learnable scaling on residual branches
        layer_scale_init_value=1e-5,
        num_classes=1000,
        include_top=False,  # backbone only; classification top is dropped
        extra_act=[False, False, False, False],
        final_expand_channel=0,  # channel expansion before the final pooling; 0 disables it
        feature_mix_channel=1024,  # channel expansion before the classification head
        block_cls=["GroupDWCB", "GroupDWCB", "AltDWCB", "DWCB"],  # block type per stage
        down_cls=["S2DDown", "S2DDown", "S2DDown", "None"],  # downsample module per stage
        patch_embed="origin",
    ),
    neck=dict(
        type="MMFPN",
        in_strides=[2, 4, 8, 16, 32],
        in_channels=[64, 64, 128, 192, 384],
        fix_out_channel=256,
        out_strides=[4, 8, 16, 32],
    ),
    depth_branch=dict(  # for auxiliary supervision only
        type="DenseDepthNetOE",
        embed_dims=embed_dims,
        num_depth_layers=num_depth_layers,
        loss_weight=0.2,
    ),
    head=dict(
        type="SparseBEVOEHead",
        enable_dn=True,  # use denoising queries during training
        level_index=[2],
        cls_threshold_to_reg=0.05,
        instance_bank=dict(
            type="MemoryBankOE",
            num_anchor=num_anchors,
            embed_dims=embed_dims,
            num_memory_instances=num_anchors,
            anchor=anchor_file,  # k-means anchors precomputed on nuScenes
            num_temp_instances=temp_anchors,
            confidence_decay=0.6,
        ),
        anchor_encoder=dict(
            type="SparseBEVOEEncoder",
            pos_embed_dims=128,
            size_embed_dims=32,
            yaw_embed_dims=32,
            vel_embed_dims=64,
            vel_dims=3,
        ),
        num_single_frame_decoder=num_single_frame_decoder,
        # Single-frame decoder blocks followed by temporal decoder blocks.
        operation_order=[
            "deformable",
            "ffn",
            "norm",
            "refine",
        ]
        * num_single_frame_decoder
        + [
            "temp_interaction",
            "interaction",
            "norm",
            "deformable",
            "ffn",
            "norm",
            "refine",
        ]
        * (num_decoder - num_single_frame_decoder),
        ffn=dict(
            type="AsymmetricFFNOE",
            in_channels=embed_dims * 2,
            pre_norm=True,
            embed_dims=embed_dims,
            feedforward_channels=embed_dims * 4,
            num_fcs=2,
            ffn_drop=drop_out,
        ),
        deformable_model=dict(
            type="DeformableFeatureAggregationOE",
            embed_dims=embed_dims,
            num_groups=num_groups,
            num_levels=num_levels,
            num_cams=6,
            attn_drop=0.15,
            use_camera_embed=True,
            residual_mode="cat",
            kps_generator=dict(
                type="SparseBEVOEKeyPointsGenerator",
                num_pts=8,
            ),
        ),
        refine_layer=dict(
            type="SparseBEVOERefinementModule",
            embed_dims=embed_dims,
            num_cls=num_classes,
            refine_yaw=True,
        ),
        target=dict(  # label assignment, including denoising groups
            type="SparseBEVOETarget",
            num_dn_groups=5,
            num_temp_dn_groups=3,
            dn_noise_scale=[2.0] * 3 + [0.5] * 7,
            max_dn_gt=32,
            add_neg_dn=True,
            cls_weight=2.0,
            box_weight=0.25,
            reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
            # Per-class override of reg_weights.
            cls_wise_reg_weights={
                CLASSES.index("traffic_cone"): [
                    2.0,
                    2.0,
                    2.0,
                    1.0,
                    1.0,
                    1.0,
                    0.0,
                    0.0,
                    1.0,
                    1.0,
                ],
            },
        ),
        cls_allow_reverse=[CLASSES.index("barrier")],
        loss_cls=dict(
            type="FocalLoss",
            loss_name="cls",
            num_classes=num_classes + 1,
            gamma=2.0,
            alpha=0.25,
            loss_weight=2.0,
        ),
        loss_reg=dict(type="L1Loss", loss_weight=0.25),
        loss_cns=dict(type="CrossEntropyLoss", use_sigmoid=True),  # cns: target class-probability loss
        loss_yns=dict(type="GaussianFocalLoss"),  # yns: yaw prediction loss
        decoder=dict(type="SparseBEVOEDecoder"),
        reg_weights=[2.0] * 3 + [1.0] * 7,
    ),
)
# Essentially the same definition as `model`; deploy_model is used for model
# compilation and on-device deployment, so it has no loss/target parts.
deploy_model = dict(
    type="SparseBEVOE",
    compiler_model=True,  # deploy/compile graph
    backbone=dict(
        type="HENet",
        in_channels=3,
        block_nums=[4, 3, 8, 6],
        embed_dims=[64, 128, 192, 384],
        attention_block_num=[0, 0, 0, 0],
        mlp_ratios=[2, 2, 2, 3],
        mlp_ratio_attn=2,
        act_layer=["nn.GELU", "nn.GELU", "nn.GELU", "nn.GELU"],
        use_layer_scale=[True, True, True, True],
        layer_scale_init_value=1e-5,
        num_classes=1000,
        include_top=False,
        extra_act=[False, False, False, False],
        final_expand_channel=0,
        feature_mix_channel=1024,
        block_cls=["GroupDWCB", "GroupDWCB", "AltDWCB", "DWCB"],
        down_cls=["S2DDown", "S2DDown", "S2DDown", "None"],
        patch_embed="origin",
    ),
    neck=dict(
        type="MMFPN",
        in_strides=[2, 4, 8, 16, 32],
        in_channels=[64, 64, 128, 192, 384],
        fix_out_channel=256,
        out_strides=[4, 8, 16, 32],
    ),
    depth_branch=dict(  # for auxiliary supervision only
        type="DenseDepthNetOE",
        embed_dims=embed_dims,
        num_depth_layers=num_depth_layers,
        loss_weight=0.2,
    ),
    head=dict(
        type="SparseBEVOEHead",
        enable_dn=True,
        level_index=[2],
        cls_threshold_to_reg=0.05,
        instance_bank=dict(
            type="MemoryBankOE",
            num_anchor=num_anchors,
            embed_dims=embed_dims,
            num_memory_instances=num_anchors,
            anchor=anchor_file,
            num_temp_instances=temp_anchors,
            confidence_decay=0.6,
        ),
        anchor_encoder=dict(
            type="SparseBEVOEEncoder",
            pos_embed_dims=128,
            size_embed_dims=32,
            yaw_embed_dims=32,
            vel_embed_dims=64,
            vel_dims=3,
        ),
        num_single_frame_decoder=num_single_frame_decoder,
        operation_order=[
            "deformable",
            "ffn",
            "norm",
            "refine",
        ]
        * num_single_frame_decoder
        + [
            "temp_interaction",
            "interaction",
            "norm",
            "deformable",
            "ffn",
            "norm",
            "refine",
        ]
        * (num_decoder - num_single_frame_decoder),
        ffn=dict(
            type="AsymmetricFFNOE",
            in_channels=embed_dims * 2,
            pre_norm=True,
            embed_dims=embed_dims,
            feedforward_channels=embed_dims * 4,
            num_fcs=2,
            ffn_drop=drop_out,
        ),
        deformable_model=dict(
            type="DeformableFeatureAggregationOE",
            embed_dims=embed_dims,
            num_groups=num_groups,
            num_levels=num_levels,
            num_cams=6,
            attn_drop=0.0,  # no attention dropout at inference/deploy time
            use_camera_embed=True,
            residual_mode="cat",
            kps_generator=dict(
                type="SparseBEVOEKeyPointsGenerator",
                num_pts=8,
            ),
        ),
        refine_layer=dict(
            type="SparseBEVOERefinementModule",
            embed_dims=embed_dims,
            num_cls=num_classes,
            refine_yaw=True,
        ),
    ),
)


def get_deploy_input():
    """Build random example inputs for compiling the deploy model.

    Returns a dict with 6 camera images, per-camera projection matrices,
    and the cached instance-bank state sized by the module-level
    ``num_anchors`` constant.
    """
    inputs = {}
    inputs["img"] = torch.randn((6, 3, 256, 704))
    inputs["projection_mat"] = torch.randn((6, 4, 4))
    inputs["cached_anchor"] = torch.randn((1, num_anchors, 11))
    inputs["cached_feature"] = torch.randn((1, num_anchors, 256))
    return inputs


deploy_inputs = get_deploy_input()


def get_eval_trace_input():
    """Build random example inputs for tracing the eval-mode model.

    Extends the deploy inputs with temporal state (confidence, validity
    mask, timestamp) and the lidar-to-global / lidar-to-image transforms.
    """
    inputs = {}
    inputs["img"] = torch.randn((6, 3, 256, 704))
    inputs["projection_mat"] = torch.randn((6, 4, 4))
    inputs["cached_anchor"] = torch.randn((1, num_anchors, 11))
    inputs["cached_feature"] = torch.randn((1, num_anchors, 256))
    inputs["cached_confidence"] = torch.randn((1, num_anchors))
    inputs["mask"] = torch.ones((1)).bool()
    inputs["timestamp"] = torch.randn((1))
    inputs["lidar2global"] = torch.randn((1, 4, 4))
    inputs["lidar2img"] = torch.randn((6, 4, 4))
    return inputs


eval_trace_inputs = get_eval_trace_input()


def get_train_trace_input():
    """Build random example inputs for tracing the train-mode model.

    Includes ground-truth style tensors (boxes+labels, instance ids,
    intrinsics, lidar points) needed by the loss branches.
    """
    inputs = {}
    inputs["img"] = torch.randn((6, 3, 256, 704))
    inputs["timestamp"] = torch.randn((1))
    inputs["lidar2global"] = torch.randn((1, 4, 4))
    inputs["lidar2img"] = torch.randn((6, 4, 4))
    inputs["lidar_bboxes_labels"] = torch.randn((1, 20, 10))
    inputs["instance_ids"] = torch.randn((1, 20))
    inputs["camera_intrinsic"] = torch.randn((6, 3, 3))
    inputs["points"] = torch.rand((1, 20000, 3))
    return inputs


train_trace_inputs = get_train_trace_input()

# Training dataset: nuScenes LMDB with multi-view image + BEV augmentation.
train_dataset = dict(
    type="NuscenesBevDataset",
    data_path=os.path.join(data_rootdir, "train_lmdb"),
    transforms=[
        dict(type="MultiViewsImgResize", scales=(0.40, 0.47)),
        dict(type="MultiViewsImgCrop", size=(256, 704), random=False),
        dict(type="MultiViewsImgFlip"),
        dict(type="MultiViewsImgRotate", rot=(-5.4, 5.4)),
        # BEV-space box rotation; removing this augmentation hurts accuracy.
        dict(type="BevBBoxRotation", rotation_3d_range=(-0.3925, 0.3925)),
        dict(type="MultiViewsPhotoMetricDistortion"),
        dict(
            type="MultiViewsGridMask",  # random grid-shaped masking
            use_h=True,
            use_w=True,
            rotate=1,
            offset=False,
            ratio=0.5,
            mode=1,
            prob=0.7,
        ),
        dict(
            type="MultiViewsImgTransformWrapper",
            transforms=[
                dict(type="PILToTensor"),
                dict(type="BgrToYuv444", rgb_input=True),  # convert to YUV444, then normalize
                dict(type="Normalize", mean=128, std=128),
            ],
        ),
    ],
    with_bev_bboxes=False,
    with_ego_bboxes=False,
    with_bev_mask=False,
    with_lidar_bboxes=True,
    need_lidar=True,
    num_split=2,
)

# Training data loader: distributed streaming sampler keeps sequence-level
# augmentation consistent across views of the same sample.
data_loader = dict(
    type=torch.utils.data.DataLoader,
    dataset=train_dataset,
    batch_sampler=dict(
        type="DistStreamBatchSampler",
        batch_size=batch_size_per_gpu,
        dataset=train_dataset,
        keep_consistent_seq_aug=True,
        skip_prob=0.0,
        sequence_flip_prob=0.0,
    ),
    num_workers=dataloader_workers,
    pin_memory=True,
    collate_fn=collate_nuscenes,
)

# Validation dataset: deterministic resize + crop, no augmentation.
val_dataset = dict(
    type="NuscenesBevDataset",
    data_path=os.path.join(data_rootdir, "val_lmdb"),
    transforms=[
        dict(type="MultiViewsImgResize", size=(396, 704)),
        dict(type="MultiViewsImgCrop", size=(256, 704)),
        dict(
            type="MultiViewsImgTransformWrapper",
            transforms=[
                dict(type="PILToTensor"),
                dict(type="BgrToYuv444", rgb_input=True),
                dict(type="Normalize", mean=128, std=128),
            ],
        ),
    ],
    with_bev_bboxes=False,
    with_ego_bboxes=False,
    with_bev_mask=False,
    with_lidar_bboxes=True,
)
# Validation loader: batch size 1, no shuffling (temporal order matters).
val_data_loader = dict(
    type=torch.utils.data.DataLoader,
    dataset=val_dataset,
    num_workers=dataloader_workers,
    pin_memory=True,
    collate_fn=collate_nuscenes,
    batch_size=1,
    shuffle=False,
)


def loss_collector(outputs: dict) -> list:
    """Collect all loss tensors from a model output dict.

    Args:
        outputs: mapping from loss name to loss value.

    Returns:
        list: the loss values, in the dict's iteration order.
    """
    # Idiomatic equivalent of appending every value in a loop.
    return list(outputs.values())


def update_loss(metrics, batch, model_outs):
    """Feed the model outputs into every loss metric.

    ``batch`` is accepted for interface compatibility but is unused here.
    """
    for metric_obj in metrics:
        metric_obj.update(model_outs)


# Callback that pushes per-step losses into the training metrics for logging.
loss_show_update = dict(
    type="MetricUpdater",
    metric_update_func=update_loss,
    step_log_freq=50,
    epoch_log_freq=1,
    log_prefix="loss_" + task_name,
)

# Training batch processor: AMP with a gradient scaler, gradients enabled.
batch_processor = dict(
    type="MultiBatchProcessor",
    need_grad_update=True,
    loss_collector=loss_collector,
    grad_scaler=torch.cuda.amp.GradScaler(init_scale=32.0),
    enable_amp=True,
)

# Validation batch processor: forward only, no gradient updates.
val_batch_processor = dict(
    type="MultiBatchProcessor",
    need_grad_update=False,
)


def update_metric(metrics, batch, model_outs):
    """Feed the input batch and model outputs into every eval metric."""
    for metric_obj in metrics:
        metric_obj.update(batch, model_outs)


# Pushes validation batches/outputs into the evaluation metrics.
val_metric_updater = dict(
    type="MetricUpdater",
    metric_update_func=update_metric,
    step_log_freq=10000,
    epoch_log_freq=1,
    log_prefix="Validation " + task_name,
)

# Periodic throughput/statistics logging.
stat_callback = dict(
    type="StatsMonitor",
    log_freq=500,
    batch_size=batch_size_per_gpu,
)

# Gradient clipping by global L2 norm.
grad_callback = dict(
    type="GradClip",
    max_norm=25,
    norm_type=2,
)

# Checkpointing every 5 epochs worth of steps; prefix encodes the stage.
ckpt_callback = dict(
    type="Checkpoint",
    save_dir=ckpt_dir,
    name_prefix=training_step + "-",
    interval_by="step",
    save_interval=num_steps_per_epoch * 5,
    strict_match=False,
    mode="max",  # "best" checkpoint tracks the maximum of the monitored metric
)

# In-training validation, run every 5 epochs worth of steps and at train end.
val_callback = dict(
    type="Validation",
    data_loader=val_data_loader,
    batch_processor=val_batch_processor,
    callbacks=[val_metric_updater],
    val_model=None,
    init_with_train_model=False,
    val_interval=num_steps_per_epoch * 5,
    interval_by="step",
    val_on_train_end=True,
    log_interval=500,
)

# Official nuScenes detection metric (mini split) driven from lidar frame.
val_nuscenes_metric = dict(
    type="NuscenesMetric",
    data_root=meta_rootdir,
    use_lidar=True,
    trans_lidar_dim=True,
    trans_lidar_rot=False,
    use_ddp=False,
    lidar_key="sensor2ego",
    version="v1.0-mini",
    save_prefix="./WORKSPACE/results" + task_name,
)

# Stage 1: float training with DDP, AdamW + cosine LR, EMA, and grad clipping.
float_trainer = dict(
    type="distributed_data_parallel_trainer",  # DDP training
    # type="Trainer", ## single-process alternative
    model=model,
    model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(
                type="LoadCheckpoint",
                # checkpoint_path=os.path.join(
                #     "./tmp_models/henet_tinym_imagenet/float-checkpoint-best.pth.tar",  # noqa: E501
                # ),
                checkpoint_path=None,  # no ImageNet pretrain loaded by default
                allow_miss=True,
                ignore_extra=True,
                verbose=True,
            ),
        ],
    ),
    data_loader=data_loader,
    optimizer=dict(
        type=torch.optim.AdamW,
        params={
            "backbone": dict(lr=3e-4),  # per-module LR override for the backbone
        },
        eps=1e-8,
        betas=(0.9, 0.999),
        lr=3e-4,
        weight_decay=0.001,
    ),  # optimizer settings
    batch_processor=batch_processor,  # how each iteration is processed
    num_steps=num_steps,
    stop_by="step",
    # num_epochs=100,
    device=None,
    callbacks=[
        stat_callback,
        loss_show_update,
        dict(type="ExponentialMovingAverage"),
        grad_callback,
        dict(
            type="CosineAnnealingLrUpdater",
            warmup_len=500,
            warmup_by="step",
            warmup_lr_ratio=1.0 / 3,
            step_log_interval=500,
            update_by="step",
            min_lr_ratio=1e-3,
        ),
        val_callback,
        ckpt_callback,
    ],
    train_metrics=dict(
        type="LossShow",
    ),
    sync_bn=True,
    val_metrics=[val_nuscenes_metric],  # metrics printed during validation
)

# Deformable-aggregation sub-modules repeat at head.layers.{3,10,...,38}
# (stride 7); cls/quality branches repeat at head.layers.{6,13,...,34}.
_deform_layer_ids = list(range(3, 39, 7))
_refine_layer_ids = list(range(6, 35, 7))

# Ops promoted to int16 with fixed thresholds (see get_qconfig_setter).
int16_max_point_muls = [
    "head.layers.%d.point_mul" % idx for idx in _deform_layer_ids
]
int16_max_reciprocal_op = [
    "head.layers.%d.reciprocal_op" % idx for idx in _deform_layer_ids
]

int16_point_cat = [
    "head.layers.%d.point_cat" % idx for idx in _deform_layer_ids
]

# Ops kept in float32: all cls_layers first, then all quality_layers.
float32_ops = [
    "head.layers.%d.%s" % (idx, branch)
    for branch in ("cls_layers", "quality_layers")
    for idx in _refine_layer_ids
]


# Directory with quant-analysis sensitivity tables; for each table, take
# the top-N most sensitive ops.
table_dir = "./tmp_models/bev_sparse_henet_tinym_nuscenes/quant_analysis/"
sensitive_setting_dict = {
    "output_quality_L1_sensitive_ops.pt": 20,
    "output_prediction_L1_sensitive_ops.pt": 20,
    "output_classification_L1_sensitive_ops.pt": 20,
}


def get_qconfig_setter(is_calibration=True):
    """Build the quantization qconfig setter for calibration / QAT.

    Args:
        is_calibration: if True use an MSE observer (calibration stage),
            otherwise a MinMax observer (QAT stage).

    Returns:
        QconfigSetter: its templates are applied in listed order; later
        templates override earlier ones for the modules they match.
    """
    qconfig_setter = QconfigSetter(
        reference_qconfig=get_qconfig(
            observer=(
                observer_v2.MSEObserver
                if is_calibration is True
                # MinMax suits QAT input/output/weight and calibration weights.
                else observer_v2.MinMaxObserver
            ),  # observer nodes are inserted according to this choice
            # Fixing activation scales after calibration (fix_scale=True) can
            # improve QAT accuracy on some tasks (mainly image classification);
            # kept dynamic here.
            fix_scale=False,
        ),
        templates=[
            # 1. Base configuration.
            # Global int8 activations first, as a perf upper-bound test.
            ModuleNameTemplate({"": qint8}),  # global feat int8; weights default to int16
            # ModuleNameTemplate assigns dtypes (qint8/qint16/float16/float32)
            # or quantization thresholds per module name.
            #
            MatmulDtypeTemplate(
                input_dtypes=[qint8, qint8],
                prefix=None,
            ),  # Matmul ops get single/dual int8 inputs (by name or prefix)
            # 2. Per-op overrides based on the quant-analysis debug results.
            # Promote sensitive conv weight inputs to int16 as needed.
            ConvDtypeTemplate(
                input_dtype=qint8,
                weight_dtype=qint8,
                prefix=None,  # e.g. prefix=["backbone.conv1.conv1_3.conv"]
            ),  # conv/matmul-like ops configured with all-int8 inputs
            # # Promote sensitive matmul inputs to int16:
            # MatmulDtypeTemplate(
            #     input_dtypes=[qint16,qint8],
            #     prefix=["encoder.encoder.0.layers.0.self_attn.matmul"]
            # ),
            ModuleNameTemplate(
                {
                    m: {"dtype": qint16, "threshold": 1.1}
                    for m in int16_max_point_muls
                },  # int16 with a fixed threshold (fixed-scale style 1)
            ),  # point_mul ops of head layers 3..38, stride 7
            ModuleNameTemplate(
                {
                    m: {"dtype": qint16, "threshold": 11}
                    for m in int16_max_reciprocal_op
                },  # int16 with a fixed threshold (fixed-scale style 1)
            ),
            ModuleNameTemplate(
                {
                    m: {"dtype": qint16, "threshold": 60}
                    for m in int16_point_cat
                }
            ),
            # Hand-picked per-module dtype/threshold overrides.
            ModuleNameTemplate(
                {
                    "backbone.quant": {"dtype": qint8, "threshold": 1.0},  # qint8 with fixed scale
                    "head.fc_before": {"dtype": qint16, "threshold": 5.0},
                    "head.fc_after": {"dtype": qint16, "threshold": 50.0},
                    "head.instance_bank.anchor_quant_stub": {
                        "dtype": qint16,
                        "threshold": 60,
                    },
                    "head.instance_bank.anchor_cat": {"output": qint16},
                    "head.layers.41.layers.11.scale_quant_stub": {
                        "output": qint16
                    },
                },
            ),
            ModuleNameTemplate(
                {m: None for m in float32_ops},  # None => leave these ops unquantized
                freeze=True,
            ),  # cls/quality branches stay fp32
            # Sensitivity-driven overrides on specific inputs/outputs.
            SensitivityTemplate(
                [
                    ("head.fc_before", "output"),
                    ("head.fc_after", "output"),
                    ("head.instance_bank.anchor_quant_stub", "output"),
                    ("head.instance_bank.anchor_cat", "output"),
                    ("head.layers.41.layers.11.scale_quant_stub", "output"),
                    ("head.layers.41.layers.11.mul", "output"),
                    ("head.layers.41.add2", "output"),
                    ("head.layers.38.camera_encoder.0", "input"),
                    ("head.layers.31.camera_encoder.0", "input"),
                    ("head.layers.24.camera_encoder.0", "input"),
                    ("head.layers.17.camera_encoder.0", "input"),
                    ("head.layers.3.camera_encoder.0", "input"),
                    ("head.layers.10.camera_encoder.0", "input"),
                ],
                topk_or_ratio=1.0,
            ),
        ],  # the templates above take effect in listed order
        enable_optimize=True,  # use the default optimization passes (default True)
        save_dir="./qconfig_setting",
        custom_qconfig_mapping=None,
    )
    return qconfig_setter


calibration_qconfig_setter = get_qconfig_setter(True)  # MSE observer
qat_qconfig_setter = get_qconfig_setter(False)  # MinMax observer
example_inputs = {"img": torch.randn((1,) + data_shape)}
# Note: The transforms of the dataset during calibration can be
# consistent with that during training or validation, or customized.
# Default used `val_batch_processor`.
# Calibration reuses the training loader but with validation transforms.
calibration_data_loader = copy.deepcopy(data_loader)
calibration_data_loader["dataset"]["transforms"] = val_data_loader["dataset"]["transforms"]
calibration_batch_processor = copy.deepcopy(val_batch_processor)
calibration_step = 12

# Stage 2: post-training calibration starting from the best float checkpoint.
calibration_trainer = dict(
    type="Calibrator",
    model=model,
    skip_step=2,  # number of leading steps skipped before collecting stats
    model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "float-checkpoint-best.pth.tar"
                ),
                load_ema_model=True,
                allow_miss=True,
                ignore_extra=True,
                verbose=True,
            ),
            dict(
                type="Float2Calibration",
                convert_mode=convert_mode,
                example_inputs=eval_trace_inputs,
                qconfig_setter=calibration_qconfig_setter,
            ),
            dict(type="DisableSyncConverter"),
        ],
    ),
    data_loader=calibration_data_loader,
    batch_processor=calibration_batch_processor,
    num_steps=calibration_step,
    device=None,
    callbacks=[
        val_callback,
        ckpt_callback,
    ],
    # NOTE(review): 12 / 10 == 1.2 (float) — confirm the logger accepts a
    # non-integer interval, otherwise use calibration_step // 10.
    log_interval=calibration_step / 10,
    val_metrics=[val_nuscenes_metric],
)

# qat_qconfig_setter = (default_qat_fixed_act_qconfig_setter,)
# Stage 3: QAT fine-tuning from the calibration checkpoint; denoising
# queries are disabled for QAT.
qat_model = copy.deepcopy(model)
qat_model["head"]["enable_dn"] = False
qat_trainer = dict(
    type="distributed_data_parallel_trainer",
    model=qat_model,
    model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(
                type="Float2QAT",
                convert_mode=convert_mode,
                example_inputs=train_trace_inputs,
                state="train",
                qconfig_setter=qat_qconfig_setter,
            ),
            dict(
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "calibration-checkpoint-last.pth.tar"
                ),
                allow_miss=True,
                ignore_extra=True,
            ),
        ],
    ),
    data_loader=data_loader,
    optimizer=dict(
        type=torch.optim.AdamW,
        eps=1e-8,
        betas=(0.9, 0.999),
        params={
            "backbone": dict(lr=3e-6),  # much smaller LR than float training
        },
        lr=3e-6,
        weight_decay=0.001,
    ),
    batch_processor=batch_processor,
    # NOTE(review): float step count (num_steps * 0.1) — confirm the trainer
    # casts this to int; otherwise wrap in int(...).
    num_steps=num_steps * 0.1,
    stop_by="step",
    device=None,
    callbacks=[
        stat_callback,
        loss_show_update,
        dict(type="ExponentialMovingAverage", base_steps=50000),
        grad_callback,
        dict(
            type="StepDecayLrUpdater",
            lr_decay_id=[int(num_steps * 0.1 * 0.6)],
            step_log_interval=500,
        ),
        val_callback,
        ckpt_callback,
    ],
    train_metrics=dict(
        type="LossShow",
    ),
    val_metrics=[val_nuscenes_metric],
)


# HBM compilation settings for the deployed model.
compile_dir = os.path.join(ckpt_dir, "compile")
compile_cfg = dict(
    march=march,
    name=task_name + "_model",
    out_dir=compile_dir,
    hbm=os.path.join(compile_dir, "spars4D_01.hbm"),
    layer_details=True,
    debug=True,  # True makes the compiler emit per-layer details
    input_source="pyramid, ddr, ddr, ddr, ddr, ddr",
    opt="O2",
    split_dim=dict(
        inputs={
            # NOTE(review): presumably splits input 0 along dim 0 over [0, 6)
            # (the 6 camera views) — confirm against the compiler docs.
            "0": [0, 6],
        }
    ),
)

# predictor
float_predictor = dict(
    type="Predictor",
    model=model,
    model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        converters=[
            dict(
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(ckpt_dir, "float-checkpoint-best.pth.tar"),
                load_ema_model=True,
                allow_miss=True,
                ignore_extra=True,
            ),
        ],
    ),
    data_loader=[val_data_loader],
    batch_processor=val_batch_processor,
    device=None,
    metrics=[val_nuscenes_metric],
    callbacks=[
        val_metric_updater,
    ],
    log_interval=50,
)
calibration_checkpoint_path = os.path.join(
    ckpt_dir, "calibration-checkpoint-last.pth.tar"
)

# Evaluates the calibrated (fake-quant) model; conversion happens before
# loading so the checkpoint keys match the calibration graph.
calibration_predictor = dict(
    type="Predictor",
    model=model,
    model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(
                type="Float2Calibration",
                convert_mode=convert_mode,
                example_inputs=eval_trace_inputs,
                qconfig_setter=calibration_qconfig_setter,
            ),
            dict(
                type="LoadCheckpoint",
                checkpoint_path=calibration_checkpoint_path,
                ignore_extra=True,
                allow_miss=True,
                verbose=True,
            ),
        ],
    ),
    data_loader=[val_data_loader],
    batch_processor=val_batch_processor,
    device=None,
    metrics=[val_nuscenes_metric],
    callbacks=[
        val_metric_updater,
    ],
    log_interval=50,
)

# Evaluates the QAT model from the best QAT checkpoint.
qat_predictor = dict(
    type="Predictor",
    model=model,
    model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(
                type="Float2QAT",
                convert_mode=convert_mode,
                example_inputs=eval_trace_inputs,
                state="val",
                qconfig_setter=qat_qconfig_setter,
            ),
            dict(
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "qat-checkpoint-best.pth.tar"
                ),
                load_ema_model=True,
                allow_miss=True,
                ignore_extra=True,
            ),
            dict(type="SetQuantRoundingMode"),
        ],
    ),
    data_loader=[val_data_loader],
    batch_processor=val_batch_processor,
    device=None,
    metrics=[val_nuscenes_metric],
    callbacks=[
        val_metric_updater,
    ],
    log_interval=50,
)

# Conversion pipeline used when compiling the deploy model (no checkpoint).
deploy_model_convert_pipeline = dict(
    type="ModelConvertPipeline",
    qat_mode="fuse_bn",
    converters=[
        dict(
            type="Float2QAT",  # convert the model from float to QAT
            # type="QAT2Quantize", ## convert the model from QAT to quantized
            convert_mode=convert_mode,
            example_inputs=eval_trace_inputs,
            state="val",
            qconfig_setter=qat_qconfig_setter,
        ),
    ],
)

# Conversion pipeline used when exporting: convert to QAT, then load the
# best QAT checkpoint.
export_model_convert_pipeline = dict(
    type="ModelConvertPipeline",
    qat_mode="fuse_bn",
    converters=[
        dict(
            type="Float2QAT",
            convert_mode=convert_mode,  # convert_mode = "jit-strip"
            example_inputs=eval_trace_inputs,
            state="val",
            qconfig_setter=qat_qconfig_setter,
        ),
        dict(
            type="LoadCheckpoint",
            checkpoint_path=os.path.join(
                ckpt_dir, "qat-checkpoint-best.pth.tar"
            ),
            ignore_extra=True,
            allow_miss=True,
            verbose=True,
        ),
    ],
)

output_names = ["classification", "prediction", "quality", "feature"]
# Zero-filled instance-bank state fed on the first frame (no history yet).
first_frame_input = {
    "cached_anchor": torch.zeros((1, num_anchors, 11)),
    "cached_feature": torch.zeros((1, num_anchors, 256)),
    "cached_confidence": torch.zeros((1, num_anchors)),
    "mask": torch.zeros((1)).bool(),
}

# Runs the compiled HBIR model with the temporal memory bank on the host.
hbir_infer_model = dict(
    type="SparseBEVOEIrInfer",
    first_frame_input=first_frame_input,
    projection_mat_key="lidar2img",
    global_mat_key="lidar2global",
    ir_model=dict(
        type="HbirModule",
        # model_path=os.path.join(ckpt_dir, "quantized.bc"),
        # model_path="./hbir_output_nuscenes_full_260312/nv12_quantized_rgb.bc",
        model_path="./hbir_output_nuscenes_full_260312/nv12_quantized_yuv444.bc",
        # model_path="./hbir_output_nuscenes_full_260312/qat.quant.bc",
    ),
    decoder=dict(type="SparseBEVOEDecoder"),
    use_memory_bank=True,
    confidence_decay=0.6,
    num_temp_instances=temp_anchors,
    num_memory_instances=num_anchors,
)

# Integer-model inference uses the validation loader with batch size 1
# (required by the temporal memory bank).
int_infer_data_loader = copy.deepcopy(val_data_loader)
int_infer_data_loader["batch_size"] = 1


int_infer_predictor = dict(
    type="Predictor",
    model=hbir_infer_model,
    data_loader=int_infer_data_loader,
    batch_processor=val_batch_processor,
    device=None,
    metrics=[val_nuscenes_metric],
    callbacks=[
        val_metric_updater,
    ],
    log_interval=1,
)

# Same deterministic preprocessing as validation, applied to raw images.
infer_transforms = [
    dict(type="MultiViewsImgResize", size=(396, 704)),
    dict(type="MultiViewsImgCrop", size=(256, 704)),
    dict(
        type="MultiViewsImgTransformWrapper",
        transforms=[
            dict(type="PILToTensor"),
            dict(type="BgrToYuv444", rgb_input=True),
            dict(type="Normalize", mean=128, std=128),
        ],
    ),
]
# Loader over raw nuScenes images (not LMDB) for BPU alignment checks.
align_bpu_data_loader = dict(
    type=torch.utils.data.DataLoader,
    dataset=dict(
        type="NuscenesFromImage",
        src_data_dir="./tmp_orig_data/nuscenes",
        version="v1.0-mini",
        split_name="val",
        transforms=infer_transforms,
        with_bev_bboxes=False,
        with_ego_bboxes=False,
        with_bev_mask=False,
        with_lidar_bboxes=True,
    ),
    batch_size=1,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    collate_fn=collate_nuscenes,
)

# Runs the HBIR model on raw images to verify host/BPU output alignment.
align_bpu_predictor = dict(
    type="Predictor",
    model=hbir_infer_model,
    data_loader=align_bpu_data_loader,
    metrics=[val_nuscenes_metric],
    callbacks=[
        val_metric_updater,
    ],
    log_interval=1,
    batch_processor=dict(type="BasicBatchProcessor", need_grad_update=False),
)
# QuantAnalysis automatically compares two models: it locates abnormal or
# quantization-sensitive ops in the quantized model by diffing against the
# float baseline.
quant_analysis_solver = dict(
    type="QuantAnalysis",
    model=copy.deepcopy(model),
    device_id=2,
    dataloader=copy.deepcopy(calibration_data_loader),
    num_steps=100,
    baseline_model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        converters=[
            dict(
                type="LoadCheckpoint",
                checkpoint_path="configs/bev/tmp_models/bev_sparse_henet_tinym_nuscenes_full_260308/float-checkpoint-best.pth.tar",
                load_ema_model=True,
                allow_miss=True,
                ignore_extra=True,
            ),
        ],
    ),  # the baseline (float) model for the comparison
    analysis_model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        converters=[
            dict(
                type="Float2Calibration",
                convert_mode=convert_mode,
                example_inputs=eval_trace_inputs,
                qconfig_setter=calibration_qconfig_setter,
            ),
            dict(
                type="LoadCheckpoint",
                checkpoint_path="configs/bev/tmp_models/bev_sparse_henet_tinym_nuscenes_full_260308/calibration-checkpoint-last.pth.tar",
                ignore_extra=True,
                allow_miss=True,
                verbose=True,
            ),
        ],
    ),
    analysis_model_type="fake_quant",
)  # compares float weights against calibration (fake-quant) weights


def resize_homo(homo, scale):
    """Apply an image resize to a homogeneous transform.

    `scale` is (h_scale, w_scale); the x row is scaled by the width factor
    and the y row by the height factor. Broadcasts over batched `homo`.
    """
    scaling = np.eye(4)
    scaling[0, 0], scaling[1, 1] = scale[1], scale[0]
    return scaling @ homo


def crop_homo(homo, offset):
    """Apply an image crop offset (x, y) to a homogeneous transform.

    Shifts the principal point by subtracting the crop origin from the
    x/y translation entries. Broadcasts over batched `homo`.
    """
    translation = np.eye(4)
    translation[:2, 2] = (-offset[0], -offset[1])
    return translation @ homo

# Pipeline: read the image with cv2 (BGR) --> convert to RGB --> wrap as a
# PIL image --> convert to a torch tensor (C, H, W);
# resize to the requested size --> add a batch dimension (B, C, H, W);
# crop keeping the bottom rows and a horizontally-centered left edge;
# return the processed tensor plus the original (H, W) size.
def process_img(img_path, resize_size, crop_size):
    orig_img = cv2.imread(img_path)
    # In-place BGR -> RGB conversion (3rd arg is the destination buffer).
    cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB, orig_img)
    orig_img = Image.fromarray(orig_img)
    orig_img = pil_to_tensor(orig_img)  # uint8 tensor with values in [0, 255]
    resize_hw = (
        int(resize_size[0]),
        int(resize_size[1]),
    )

    orig_shape = (orig_img.shape[1], orig_img.shape[2])
    resized_img = resize(orig_img, resize_hw).unsqueeze(0)  # e.g. (900, 1600) -> (396, 704)
    # Crop: drop `top` rows (keeps the bottom of the image) and start at a
    # centered `left` column. NOTE(review): there is no right-edge trim, so
    # the output width is resize_hw[1] - left; it equals crop_size[1] only
    # when the widths match (as with 704/704 here) — confirm for other sizes.
    top = int(resize_hw[0] - crop_size[0])
    left = int((resize_hw[1] - crop_size[1]) / 2)
    resized_img = resized_img[:, :, top:, left:]

    return resized_img, orig_shape  # e.g. cropped to (256, 704)


def prepare_inputs(infer_inputs):
    """Collect per-frame inference inputs from a directory tree.

    Each sorted subdirectory of `infer_inputs` (e.g. frame0, frame1, ...)
    contributes one sample dict with:
      - "imgs": sorted paths of the frame's .jpg camera images
      - "timestamp": (1,) array loaded from timestamp.npy
      - "lidar2global": (1, 4, 4) array loaded from lidar2global.npy
      - "lidar2img": array loaded from lidar2img.npy (per-camera homographies)
    """
    samples = []
    for frame in sorted(os.listdir(infer_inputs)):
        frame_dir = os.path.join(infer_inputs, frame)
        jpg_names = sorted(
            name for name in os.listdir(frame_dir) if name.endswith(".jpg")
        )
        samples.append(
            {
                "imgs": [os.path.join(frame_dir, name) for name in jpg_names],
                "timestamp": np.load(
                    os.path.join(frame_dir, "timestamp.npy")
                ).reshape(1),
                "lidar2global": np.load(
                    os.path.join(frame_dir, "lidar2global.npy")
                ).reshape(1, 4, 4),
                "lidar2img": np.load(os.path.join(frame_dir, "lidar2img.npy")),
            }
        )
    return samples

# Turn one prepared sample (image paths plus timestamp / lidar2global /
# lidar2img arrays from `prepare_inputs`) into model-ready tensors:
# resize+crop each image, convert to YUV444, normalize, and rescale the
# per-camera homographies to match the image resize/crop.
def process_inputs(data, transforms=None):
    resize_size = resize_shape[1:]  # (396, 704), from resize_shape = (3, 396, 704)
    input_size = val_data_shape[1:]  # (256, 704), from val_data_shape = (3, 256, 704)

    orig_imgs = []  # entries are {"name": camera_index, "img": image tensor}
    for i, img_path in enumerate(data["imgs"]):
        img, orig_shape = process_img(img_path, resize_size, input_size)
        orig_imgs.append({"name": i, "img": img})
    # NOTE(review): process_img converts to RGB, yet the converter below is
    # named bgr_to_yuv444 -- confirm the intended channel order matches the
    # training pipeline.
    input_imgs = []
    for orig_img in orig_imgs:
        input_img = horizon.nn.functional.bgr_to_yuv444(orig_img["img"], True)
        input_imgs.append(input_img)

    ori_images = []
    for orig_img in orig_imgs:
        ori_images.append(orig_img['img'])

    orig_images = torch.cat(ori_images)
    input_imgs = torch.cat(input_imgs)
    input_imgs = (input_imgs - 128.0) / 128.0  # normalize YUV444 to roughly [-1, 1]

    homo = data["lidar2img"]  # (6, 4, 4) float64, one homography per camera
    # Apply the same resize/crop geometry to the homographies. `orig_shape`
    # is the value left over from the last loop iteration -- assumes all six
    # camera images share one original size.
    top = int(resize_size[0] - input_size[0])
    left = int((resize_size[1] - input_size[1]) / 2)

    scale = (resize_size[0] / orig_shape[0], resize_size[1] / orig_shape[1])  # (h, w) scale factors
    homo = resize_homo(homo, scale)  # scale homography by the image resize
    homo = crop_homo(homo, (left, top))  # shift homography by the crop offset; still float64

    model_input = {
        "img": input_imgs,
        # "img": orig_images,  # to feed the python model nv12-style images use
        # orig_images directly; for plain float32 image input keep input_imgs
        "lidar2img": torch.tensor(homo),
        "lidar2global": torch.tensor(data["lidar2global"]),
        "timestamp": torch.tensor(data["timestamp"]),
    }
    if transforms is not None:
        model_input = transforms(model_input)

    vis_inputs = {}
    vis_inputs["img"] = orig_imgs
    vis_inputs["meta"] = {"lidar2img": homo}


    return model_input, vis_inputs


def process_outputs(model_outs, viz_func, vis_inputs):

    outs = torch.cat(
        [
            model_outs[0]["bboxes"][..., :9].view(1, -1, 9),
            model_outs[0]["scores"].view(1, -1, 1),
            model_outs[0]["labels"].view(1, -1, 1),
        ],
        dim=-1,
    )
    outs[..., 3], outs[..., 4] = outs[..., 4], outs[..., 3]
    preds = {"lidar_det": outs}
    viz_func(vis_inputs["img"], preds, vis_inputs["meta"])
    return None


# Dataset for single-sample inference: reuse the int-infer dataset config but
# drop its transforms (raw samples are handled by process_inputs instead).
single_infer_dataset = copy.deepcopy(int_infer_data_loader["dataset"])
single_infer_dataset["transforms"] = None


def inputs_save_func(data, save_path):
    """Dump sampled inference inputs under `save_path`, one dir per frame.

    Any existing `save_path` is removed first. For each sample, the images
    are saved as JPEG named "img{i}_{basename}", and the lidar2global,
    lidar2img and timestamp arrays are written as .npy files.
    """
    if os.path.isdir(save_path):
        shutil.rmtree(save_path)
    os.makedirs(save_path, exist_ok=True)
    for frame_idx, sample in enumerate(data):
        frame_dir = os.path.join(save_path, f"frame{frame_idx}")
        os.makedirs(frame_dir, exist_ok=True)
        pairs = zip(sample["img_name"], sample["img"])
        for img_idx, (img_name, image) in enumerate(pairs):
            out_name = f"img{img_idx}_{os.path.basename(img_name)}"
            image.save(os.path.join(frame_dir, out_name), "JPEG")

        # Persist the calibration/meta arrays alongside the images.
        for key in ("lidar2global", "lidar2img", "timestamp"):
            np.save(
                os.path.join(frame_dir, f"{key}.npy"),
                np.array(sample[key]),
            )


# End-to-end inference config: generate sample inputs from the dataset (via
# inputs_save_func), run the HBIR model, and visualize the detections.
infer_cfg = {
    "model": hbir_infer_model,
    "input_path": f"./configs/bev/tmp_models/{task_name}",
    "gen_inputs_cfg": {
        "dataset": single_infer_dataset,
        "sample_idx": [0, 3],
        "inputs_save_func": inputs_save_func,
    },
    "prepare_inputs": prepare_inputs,
    "process_inputs": process_inputs,
    "viz_func": {"type": "NuscenesViz", "is_plot": True},
    "process_outputs": process_outputs,
}

# ONNX export config: export the deploy model at the float stage, reusing the
# float predictor's convert pipeline.
onnx_cfg = dict(
    model=deploy_model,
    inputs=eval_trace_inputs,
    # To export the QAT model instead, switch `stage` to "qat" and enable the
    # commented-out QAT convert pipeline below.
    # stage="qat",
    stage="float",
    model_convert_pipeline=float_predictor["model_convert_pipeline"],
    # model_convert_pipeline=dict(
    #     type="ModelConvertPipeline",
    #     qat_mode="fuse_bn",
    #     converters=[
    #         dict(
    #             type="Float2QAT",
    #             convert_mode=convert_mode,
    #             example_inputs=eval_trace_inputs,
    #             qconfig_setter=qat_qconfig_setter,
    #             state="val",
    #         ),
    #         dict(
    #             type="LoadCheckpoint",
    #             checkpoint_path=os.path.join(
    #                 ckpt_dir, "qat-checkpoint-best.pth.tar"
    #             ),
    #             ignore_extra=True,
    #             allow_miss=True,
    #             verbose=True,
    #         ),
    #     ],
    # ),
)

# Op-counting configuration: collect calops statistics via forward hooks.
calops_cfg = {"method": "hook"}
