import torch

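# The custom template is ModuleNameQconfigSetter: it takes module names and the
# matching custom qconfig, and is typically used for special requirements such
# as setting a fixed scale or configuring linear weights as int16.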
from horizon_plugin_pytorch.quantization.qconfig_template import (
    calibration_8bit_weight_16bit_act_qconfig_setter, 
    ModuleNameQconfigSetter,
)

from horizon_plugin_pytorch.quantization.qconfig_setter import (
    ModuleNameTemplate,
    MatmulDtypeTemplate,
    QconfigSetter,
    ConvDtypeTemplate,
    SensitivityTemplate,
)
from horizon_plugin_pytorch.dtype import qint8, qint16
from horizon_plugin_pytorch.quantization import get_qconfig, observer_v2

import copy
import os
from hat.data.collates.nusc_collates import collate_nuscenes
from hat.data.datasets.nuscenes_dataset import CLASSES

# Training stage name, read from the HAT_TRAINING_STEP environment variable
# (defaults to "float"); used further down as the checkpoint name prefix.
training_step = os.environ.get("HAT_TRAINING_STEP", "float")

# Batch size handed to the training batch sampler; also enters the
# steps-per-epoch computation.
batch_size_per_gpu = 2

# Worker count passed to both the training and validation DataLoader configs.
dataloader_workers = 0
# Shared hyper-parameters referenced by the `model` dict.
embed_dims = 256
num_depth_layers = 3
# Anchor count: configures the instance bank and sizes the trace inputs.
num_anchors = 384
# Passed to the instance bank as `num_temp_instances`.
temp_anchors = 128
# With 0 single-frame decoders, all `num_decoder` stages use the temporal
# operation sequence when `operation_order` is assembled in the model head.
num_single_frame_decoder = 0 
num_decoder = 6
drop_out = 0.1
num_groups = 8
num_levels = 1
num_classes = 10

# Its length enters the steps-per-epoch computation.
device_ids = [0]

# Root of the LMDB datasets ("train_lmdb" / "val_lmdb" are joined onto it).
# NOTE(review): machine-specific absolute path.
data_rootdir = "/home/ws11126/UISEE/dataset/J6/"
# Anchor file path handed to the instance bank.
anchor_file = "./nuscenes_kmeans900.npy"

# Passed to the nuScenes metric as `data_root`.
meta_rootdir = "/home/ws11126/UISEE/dataset/J6"

task_name = "sparse4D_resnet50_fpn_float_nuscenes"
# Checkpoints are saved to and loaded from this directory.
ckpt_dir = "./tmp_models/%s" % task_name

# Conversion mode forwarded to the Float2Calibration converters.
convert_mode = "jit-strip"

# Not referenced elsewhere in the visible file; presumably consumed by the
# HAT launcher -- TODO confirm.
cudnn_benchmark = False
seed = None

# Model description: a "Sparse4D" model assembled from a ResNet50 backbone,
# an FPN neck, an auxiliary depth branch and a "SparseBEVHead" head. Every
# nested dict names its component through the "type" key.
model = dict(
    type="Sparse4D",
    compiler_model=False,
    backbone=dict(
        type="ResNet50",
        num_classes=num_classes,
        bn_kwargs={},
        include_top=False,
        stride_change=True,
    ),
    neck=dict(
        type="FPN",
        # Five input strides/channels are reduced to four output levels.
        in_strides=[2, 4, 8, 16, 32],
        in_channels=[64, 256, 512, 1024, 2048],
        # fix_out_channel=256,
        out_strides=[4, 8, 16, 32],
        out_channels = [256,256,256,256],
    ),
    depth_branch=dict(  # for auxiliary supervision only
        type="DenseDepthNet",
        embed_dims=embed_dims,
        num_depth_layers=num_depth_layers,
        loss_weight=0.2,
    ),
    head=dict(
        type="SparseBEVHead",
        enable_dn=True,
        level_index=[2],
        cls_threshold_to_reg=0.05,
        instance_bank=dict(
            type="MemoryBank",
            # Both the anchor count and the memory size use num_anchors.
            num_anchor=num_anchors,
            embed_dims=embed_dims,
            num_memory_instances=num_anchors,
            anchor=anchor_file,
            num_temp_instances=temp_anchors,
            confidence_decay=0.6,
        ),
        anchor_encoder=dict(
            type="SparseBEVEncoder",
            pos_embed_dims=128,
            size_embed_dims=32,
            yaw_embed_dims=32,
            vel_embed_dims=64,
            vel_dims=3,
        ),
        num_single_frame_decoder=num_single_frame_decoder,
        # Flat operation list: a 4-op single-frame sequence repeated
        # num_single_frame_decoder times, followed by a 7-op temporal
        # sequence repeated for the remaining decoders. With the module-level
        # values (0 and 6) this is 42 entries, all from the temporal sequence.
        operation_order=[
            "deformable",
            "ffn",
            "norm",
            "refine",
        ]
        * num_single_frame_decoder
        + [
            "temp_interaction",
            "interaction",
            "norm",
            "deformable",
            "ffn",
            "norm",
            "refine",
        ]
        * (num_decoder - num_single_frame_decoder),
        ffn=dict(
            type="AsymmetricFFN",
            # Input is twice embed_dims; presumably because the deformable
            # stage uses residual_mode="cat" -- TODO confirm.
            in_channels=embed_dims * 2,
            pre_norm=True,
            embed_dims=embed_dims,
            feedforward_channels=embed_dims * 4,
            num_fcs=2,
            ffn_drop=drop_out,
        ),
        deformable_model=dict(
            type="DeformableFeatureAggregation",
            embed_dims=embed_dims,
            num_groups=num_groups,
            num_levels=num_levels,
            num_cams=6,
            attn_drop=0.15,
            use_camera_embed=True,
            residual_mode="cat",
            kps_generator=dict(
                type="SparseBEVKeyPointsGenerator",
                num_pts=8,
            ),
        ),
        refine_layer=dict(
            type="SparseBEVRefinementModule",
            embed_dims=embed_dims,
            num_cls=num_classes,
            refine_yaw=True,
        ),
        target=dict(
            type="SparseBEVTarget",
            num_dn_groups=5,
            num_temp_dn_groups=3,
            # 10 entries: 2.0 for the first three, 0.5 for the other seven.
            dn_noise_scale=[2.0] * 3 + [0.5] * 7,
            max_dn_gt=32,
            add_neg_dn=True,
            cls_weight=2.0,
            box_weight=0.25,
            # 10 entries; the last four weights are zero.
            reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
            # Per-class override of the 10 regression weights, keyed by the
            # class index of "traffic_cone" in CLASSES.
            cls_wise_reg_weights={
                CLASSES.index("traffic_cone"): [
                    2.0,
                    2.0,
                    2.0,
                    1.0,
                    1.0,
                    1.0,
                    0.0,
                    0.0,
                    1.0,
                    1.0,
                ],
            },
        ),
        # Class indices (here only "barrier") passed as cls_allow_reverse.
        cls_allow_reverse=[CLASSES.index("barrier")],
        loss_cls=dict(
            type="FocalLoss",
            loss_name="cls",
            # One more than the number of object classes.
            num_classes=num_classes + 1,
            gamma=2.0,
            alpha=0.25,
            loss_weight=2.0,
        ),
        loss_reg=dict(type="L1Loss", loss_weight=0.25),
        loss_cns=dict(type="CrossEntropyLoss", use_sigmoid=True),
        loss_yns=dict(type="GaussianFocalLoss"),
        decoder=dict(type="SparseBEVDecoder"),
        # Head-level weights (10 entries), distinct from target.reg_weights.
        reg_weights=[2.0] * 3 + [1.0] * 7,
    ),
)

# The string block below is disabled code kept for reference. Its embedded
# note says the two .pt files were produced with the debug tool.
'''
# 这两个pt文件是通过debug工具得到的
table1 = torch.load("output_0-0_L1_sensitive_ops.pt")
table2 = torch.load("output_0-1_L1_sensitive_ops.pt")
'''

# Module-name lists consumed by get_qconfig_setter().
# range(3, 39, 7) yields 3, 10, 17, 24, 31, 38. Assuming head.layers is
# indexed like model["head"]["operation_order"] (TODO confirm), these are the
# six "deformable" entries of the 42-op sequence.
int16_max_point_muls = [f"head.layers.{layer}.point_mul" for layer in range(3, 39, 7)]

int16_max_reciprocal_op = [f"head.layers.{layer}.reciprocal_op" for layer in range(3, 39, 7)]

int16_point_cat = [f"head.layers.{layer}.point_cat" for layer in range(3, 39, 7)]

# range(6, 35, 7) yields 6, 13, 20, 27, 34: under the same assumption, the
# "refine" entries except the last one (index 41), which is not listed here.
float32_ops = [
    f"head.layers.{layer}.cls_layers" for layer in range(6, 35, 7)
] + [f"head.layers.{layer}.quality_layers" for layer in range(6, 35, 7)]

def get_qconfig_setter(is_calibration=True):
    """Build the QconfigSetter used when converting the float model.

    Args:
        is_calibration: When this is exactly ``True`` the reference qconfig
            observes with ``observer_v2.MSEObserver``; any other value
            selects ``observer_v2.MinMaxObserver``.

    Returns:
        A ``QconfigSetter`` holding the reference qconfig and the ordered
        list of templates defined below.
    """

    def _int16_with_threshold(module_names, threshold):
        # One independent {"dtype", "threshold"} dict per module name.
        return {
            name: {"dtype": qint16, "threshold": threshold}
            for name in module_names
        }

    if is_calibration is True:
        observer_cls = observer_v2.MSEObserver
    else:
        observer_cls = observer_v2.MinMaxObserver

    # Reference qconfig built through the get_qconfig interface.
    reference = get_qconfig(observer=observer_cls, fix_scale=False)

    # Baseline: int8 for the root module name, matmul inputs, conv inputs
    # and conv weights.
    baseline_templates = [
        ModuleNameTemplate({"": qint8}),
        MatmulDtypeTemplate(
            input_dtypes=[qint8, qint8],
            prefix=None,
        ),
        ConvDtypeTemplate(
            input_dtype=qint8,
            weight_dtype=qint8,
            prefix=None,
        ),
    ]

    # Per-layer ops raised to int16 with an explicit threshold (presumably
    # pinning the quantization scale -- TODO confirm against the plugin docs).
    per_layer_int16_templates = [
        ModuleNameTemplate(_int16_with_threshold(int16_max_point_muls, 1.1)),
        ModuleNameTemplate(_int16_with_threshold(int16_max_reciprocal_op, 11)),
        ModuleNameTemplate(_int16_with_threshold(int16_point_cat, 60)),
    ]

    # Hand-picked modules with explicit dtype/threshold or output dtype.
    named_module_template = ModuleNameTemplate(
        {
            "backbone.quant": {"dtype": qint8, "threshold": 1.0},
            "head.fc_before": {"dtype": qint16, "threshold": 5.0},
            "head.fc_after": {"dtype": qint16, "threshold": 50.0},
            "head.instance_bank.anchor_quant_stub": {
                "dtype": qint16,
                "threshold": 60,
            },
            "head.instance_bank.anchor_cat": {"output": qint16},
            "head.layers.41.layers.11.scale_quant_stub": {
                "output": qint16
            },
        },
    )

    # Modules mapped to None and frozen; the original note labels these as
    # fp32 features, i.e. presumably left unquantized -- TODO confirm.
    float_template = ModuleNameTemplate(
        {name: None for name in float32_ops},
        freeze=True,
    )

    sensitivity_template = SensitivityTemplate(
        [
            ("head.fc_before", "output"),
            ("head.fc_after", "output"),
            ("head.instance_bank.anchor_quant_stub", "output"),
            ("head.instance_bank.anchor_cat", "output"),
            ("head.layers.41.layers.11.scale_quant_stub", "output"),
            ("head.layers.41.layers.11.mul", "output"),
            ("head.layers.41.add2", "output"),
            ("head.layers.38.camera_encoder.0", "input"),
            ("head.layers.31.camera_encoder.0", "input"),
            ("head.layers.24.camera_encoder.0", "input"),
            ("head.layers.17.camera_encoder.0", "input"),
            ("head.layers.3.camera_encoder.0", "input"),
            ("head.layers.10.camera_encoder.0", "input"),
        ],
        topk_or_ratio=1.0,
    )

    # Template order is preserved exactly as configured.
    ordered_templates = (
        baseline_templates
        + per_layer_int16_templates
        + [named_module_template, float_template, sensitivity_template]
    )

    return QconfigSetter(
        reference_qconfig=reference,
        templates=ordered_templates,
        enable_optimize=True,
        save_dir="./qconfig_setting",
        custom_qconfig_mapping=None,
    )



# Training dataset config: reads "train_lmdb" under data_rootdir and applies
# the augmentation transforms below in list order.
train_dataset = dict(
    type="NuscenesBevDataset",
    data_path=os.path.join(data_rootdir, "train_lmdb"),
    transforms=[
        dict(type="MultiViewsImgResize", scales=(0.40, 0.47)),
        dict(type="MultiViewsImgCrop", size=(256, 704), random=False),
        dict(type="MultiViewsImgFlip"),
        dict(type="MultiViewsImgRotate", rot=(-5.4, 5.4)),
        # Per the original author's note: helps with part of the accuracy drop.
        dict(type="BevBBoxRotation", rotation_3d_range=(-0.3925, 0.3925)),
        dict(type="MultiViewsPhotoMetricDistortion"),
        dict(
            type="MultiViewsGridMask",
            use_h=True,
            use_w=True,
            rotate=1,
            offset=False,
            ratio=0.5,
            mode=1,
            prob=0.7,
        ),
        # Final tensor conversion, colour-space conversion and normalisation
        # with mean=128 / std=128.
        dict(
            type="MultiViewsImgTransformWrapper",
            transforms=[
                dict(type="PILToTensor"),
                dict(type="BgrToYuv444", rgb_input=True),
                dict(type="Normalize", mean=128, std=128),
            ],
        ),
    ],
    with_bev_bboxes=False,
    with_ego_bboxes=False,
    with_bev_mask=False,
    with_lidar_bboxes=True,
    need_lidar=True,
    num_split=2,
)

# Training DataLoader config. The same train_dataset dict object is
# referenced by both the loader and its batch sampler.
data_loader = dict(
    type=torch.utils.data.DataLoader,
    dataset=train_dataset,
    batch_sampler=dict(
        type="DistStreamBatchSampler",
        batch_size=batch_size_per_gpu,
        dataset=train_dataset,
        keep_consistent_seq_aug=True,
        skip_prob=0.0,
        sequence_flip_prob=0.0,
    ),
    num_workers=dataloader_workers,
    pin_memory=True,
    collate_fn=collate_nuscenes,
)

# Validation dataset config: reads "val_lmdb" under data_rootdir. It uses a
# fixed resize and crop with no augmentation, then the same tensor/YUV/
# normalisation wrapper as training.
val_dataset = dict(
    type="NuscenesBevDataset",
    data_path=os.path.join(data_rootdir, "val_lmdb"),
    transforms=[
        dict(type="MultiViewsImgResize", size=(396, 704)),
        dict(type="MultiViewsImgCrop", size=(256, 704)),
        dict(
            type="MultiViewsImgTransformWrapper",
            transforms=[
                dict(type="PILToTensor"),
                dict(type="BgrToYuv444", rgb_input=True),
                dict(type="Normalize", mean=128, std=128),
            ],
        ),
    ],
    with_bev_bboxes=False,
    with_ego_bboxes=False,
    with_bev_mask=False,
    with_lidar_bboxes=True,
)

# Validation DataLoader config: batch size 1, no shuffling.
val_data_loader = dict(
    type=torch.utils.data.DataLoader,
    dataset=val_dataset,
    num_workers=dataloader_workers,
    pin_memory=True,
    collate_fn=collate_nuscenes,
    batch_size=1,
    shuffle=False,
)

# Batch processor config for validation; gradient updates are disabled.
val_batch_processor = dict(
    type="MultiBatchProcessor",
    need_grad_update=False,
)

def update_metric(metrics, batch, model_outs):
    """Feed one batch and its model outputs to every metric.

    Args:
        metrics: Iterable of objects exposing ``update(batch, model_outs)``.
        batch: The input batch, forwarded unchanged.
        model_outs: The model outputs for ``batch``, forwarded unchanged.
    """
    metric_iter = iter(metrics)
    while True:
        try:
            current = next(metric_iter)
        except StopIteration:
            break
        current.update(batch, model_outs)

# Callback config that routes validation batches into update_metric.
val_metric_updater = dict(
    type="MetricUpdater",
    metric_update_func=update_metric,
    step_log_freq=10000,
    epoch_log_freq=1,
    log_prefix="Validation " + task_name,
)

# Steps per epoch = 28130 // (number of devices * per-GPU batch size).
# 28130 is presumably the training sample count -- TODO confirm.
num_steps_per_epoch = int(28130 // (len(device_ids) * batch_size_per_gpu))

# Validation callback config: interval is counted in steps and equals five
# epochs' worth of steps; validation also runs when training ends.
val_callback = dict(
    type="Validation",
    data_loader=val_data_loader,
    batch_processor=val_batch_processor,
    callbacks=[val_metric_updater],
    val_model=None,
    init_with_train_model=False,
    val_interval=num_steps_per_epoch * 5,
    interval_by="step",
    val_on_train_end=True,
    log_interval=500,
)

# Qconfig setter in calibration mode (selects the MSE observer).
calibration_qconfig_setter = get_qconfig_setter(True)

# Calibration reuses the training loader layout (deep-copied so the training
# config is untouched) but swaps in the validation transforms. The transforms
# list is assigned by reference, so it is the same list object as
# val_dataset["transforms"], not a copy.
calibration_data_loader = copy.deepcopy(data_loader)
calibration_data_loader["dataset"]["transforms"] = val_data_loader["dataset"]["transforms"]

calibration_batch_processor = copy.deepcopy(val_batch_processor)
# Number of calibration steps passed to the Calibrator as num_steps.
calibration_step = 12

def get_eval_trace_input():
    """Return random example inputs used to trace the model for conversion.

    The tensor values are placeholders; only shapes and dtypes are chosen
    deliberately. The anchor dimension follows the module-level
    ``num_anchors`` and the cached feature width follows ``embed_dims``, so
    the example stays in step with the model configuration.

    Returns:
        dict mapping input names to ``torch.Tensor`` objects.
    """
    inputs = {
        "img": torch.randn((6, 3, 256, 704)),
        "projection_mat": torch.randn((6, 4, 4)),
        "cached_anchor": torch.randn((1, num_anchors, 11)),
        # Width tied to embed_dims instead of a hard-coded 256.
        "cached_feature": torch.randn((1, num_anchors, embed_dims)),
        "cached_confidence": torch.randn((1, num_anchors)),
        # Single-element boolean mask, all True.
        "mask": torch.ones((1)).bool(),
        "timestamp": torch.randn((1)),
        "lidar2global": torch.randn((1, 4, 4)),
        "lidar2img": torch.randn((6, 4, 4)),
    }
    return inputs


# Example inputs created once at import time; passed to the
# Float2Calibration converters as example_inputs.
eval_trace_inputs = get_eval_trace_input()



# Checkpoint callback config: saves into ckpt_dir every five epochs' worth of
# steps, with file names prefixed by the training stage name.
ckpt_callback = dict(
    type="Checkpoint",
    save_dir=ckpt_dir,
    name_prefix=training_step + "-",
    interval_by="step",
    save_interval=num_steps_per_epoch * 5,
    strict_match=False,
    mode="max",
)

# nuScenes metric config, evaluated against the metadata under meta_rootdir.
val_nuscenes_metric = dict(
    type="NuscenesMetric",
    data_root=meta_rootdir,
    use_lidar=True,
    trans_lidar_dim=True,
    trans_lidar_rot=False,
    use_ddp=False,
    lidar_key="sensor2ego",
    # NOTE(review): the mini split is selected here, while the
    # steps-per-epoch constant (28130) looks like a full-split count -- confirm.
    version="v1.0-mini",
    # NOTE(review): no path separator between "results" and task_name, so the
    # value is "./WORKSPACE/results" directly followed by the task name.
    # Confirm this is meant as a file-name prefix, not a directory.
    save_prefix="./WORKSPACE/results" + task_name,
)

# Calibrator config: loads the float checkpoint, converts the model to its
# calibration form, then runs `calibration_step` steps over the calibration
# loader with validation and checkpoint callbacks attached.
calibration_trainer = dict(
    type="Calibrator",
    model=model,
    skip_step=2,
    model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        # Converter order matters: the float weights are loaded first, then
        # the model is converted for calibration.
        converters=[
            dict(
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "float-checkpoint-last.pth.tar"
                ),
                load_ema_model=True,
                allow_miss=True,
                ignore_extra=True,
                verbose=True,
            ),
            dict(
                type="Float2Calibration",
                convert_mode=convert_mode,
                example_inputs=eval_trace_inputs,
                qconfig_setter=calibration_qconfig_setter,
            ),
            dict(type="DisableSyncConverter"),
        ],
    ),
    data_loader=calibration_data_loader,
    batch_processor=calibration_batch_processor,
    num_steps=calibration_step,
    device=None,
    callbacks=[
        val_callback,
        ckpt_callback,
    ],
    # Integer interval of at least 1. The previous `calibration_step / 10`
    # produced the float 1.2, which no integer step count divides exactly,
    # so a step-modulo logging check would never fire.
    log_interval=max(1, calibration_step // 10),
    val_metrics=[val_nuscenes_metric],
)


# Path of the checkpoint evaluated by the calibration predictor below
# ("calibration-checkpoint-last.pth.tar" inside ckpt_dir).
calibration_checkpoint_path = os.path.join(
    ckpt_dir, "calibration-checkpoint-last.pth.tar"
)

# Predictor config for evaluating the calibrated model on the validation
# loader with the nuScenes metric.
calibration_predictor = dict(
    type="Predictor",
    model=model,
    model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        # Unlike the calibration trainer, the model is converted first and
        # the calibration checkpoint is loaded into the converted model.
        converters=[
            dict(
                type="Float2Calibration",
                convert_mode=convert_mode,
                example_inputs=eval_trace_inputs,
                qconfig_setter=calibration_qconfig_setter,
            ),
            dict(
                type="LoadCheckpoint",
                checkpoint_path=calibration_checkpoint_path,
                ignore_extra=True,
                allow_miss=True,
                verbose=True,
            ),
        ],
    ),
    data_loader=[val_data_loader],
    batch_processor=val_batch_processor,
    device=None,
    metrics=[val_nuscenes_metric],
    callbacks=[
        val_metric_updater,
    ],
    log_interval=50,
)