import os
import torch
import torch.nn as nn
import torchvision
from hbdk4.compiler.torch import export
from horizon_plugin_pytorch import set_march, March
from hbdk4.compiler import load, convert, compile, save, hbm_perf
from hbdk4.compiler.hbm_tools import hbm_extract_desc
import time

import logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

## This script does not convert the model input to nv12; it only splits the
## 6-image batch into single images, keeps input/output dtypes unchanged, and
## re-verifies that the model outputs stay aligned.
def remove_op_by_ioname(func, io_name=None):
    """Remove the op attached to the model I/O location named *io_name*.

    Walks every output and input location of *func*; when the op attached to
    a location has *io_name* as the name of its first input or first output,
    the op is detached from the graph.

    Args:
        func: hbir function object exposing ``inputs``/``outputs`` location
            lists (as returned by ``hbdk4.compiler.load(...).functions[i]``).
        io_name: name of the I/O whose attached op should be removed.

    Raises:
        ValueError: if the requested I/O is marked unremovable, or if the
            removal itself fails (with the toolchain's diagnostic message).
    """
    for loc in func.outputs + func.inputs:
        if not loc.is_removable[0]:
            # Only error out if the caller explicitly targeted this location;
            # other unremovable locations are simply skipped.
            if io_name == loc.name:
                raise ValueError(
                    f"Failed when deleting {io_name}, which is unremovable")
            continue
        attached_op = loc.get_attached_op[0]
        removed = None  # None = op at this location did not match io_name
        output_name = attached_op.outputs[0].name
        input_name = attached_op.inputs[0].name

        if io_name in (output_name, input_name):
            removed, diagnostic = loc.remove_attached_op()
        if removed is True:
            print(f"Remove node {io_name} successfully", flush=True)
        if removed is False:
            raise ValueError(
                f"Failed when deleting {attached_op.name} operator,"
                f" error: {diagnostic}")

start = time.time()

# Path of the QAT hbir model exported from training.
bc_model_path = "hbir_output_nuscenes_full_260312/qat.bc"

# Load the original qat.bc hbir model and take its first function (graph).
qat_model = load(bc_model_path)
func = qat_model.functions[0]

# The model input has shape (6, 3, 256, 704): six camera images concatenated
# along the batch dimension. The Horizon BPU image-preprocessing ops do not
# support batch > 1, so split along dim 0 into six independent
# (1, 3, 256, 704) images and apply the same preprocessing chain to each,
# matching the preprocessing used at training time.
nodes = func.inputs[0].insert_split(dim=0)  # -> img_0 .. img_5
for node in reversed(nodes):
    # Reorder dims with permutation [0, 3, 1, 2] (NHWC -> NCHW, assuming the
    # deployed input layout is NHWC — TODO confirm against the caller).
    node = node.insert_transpose(permutes=[0, 3, 1, 2])
    # Color-space conversion (YUV BT.601 full-range -> RGB), divide by 255,
    # then normalize with the ImageNet mean/std used during training.
    node = node.insert_image_preprocess(
        mode="yuvbt601full2rgb",
        divisor=255,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    # node = node.insert_image_convert(mode="nv12")  # convert to BPU nv12 input (disabled)

# Convert the QAT model to a fixed-point (quantized) model.
advice_path = bc_model_path.replace("bc", "")
quantized_model = convert(qat_model, March.NASH_M, advice=True,
                          advice_path=advice_path, enable_vpu=True)
quant_model_path = bc_model_path.replace("bc", "quant_0410_6.bc")
save(quantized_model, quant_model_path)

# # Optionally strip quantize/dequantize nodes from the model I/O.
# quantized_model[0].remove_io_op(op_types = ["Quantize"])

# Compile the quantized model to an .hbm and run the performance report.
hbm_model_path = bc_model_path.replace("qat.bc", "quant_0410_6.hbm")
# `jobs` is only consumed by the debug compile invocation kept below.
jobs = min(os.cpu_count(), max(1, os.cpu_count() // 2 + 1))
# hbm_model = compile(quantized_model, hbm_model_path, "nash-m", opt=2, jobs=jobs, debug=True)
hbm_model = compile(quantized_model, hbm_model_path, "nash-m", opt=1, jobs=16)
hbm_perf(hbm_model_path, "./perf_result")

end = time.time()

# BUG FIX: the original passed the format string and the value as two
# separate print() arguments, so "%d" was never substituted into the message.
print("Elapsed Time %d s" % (end - start))
