import argparse
import os

import copy
import torch
import torch.nn as nn
import torch.quantization
import torchvision.transforms as transforms
from torch import Tensor
from torch.quantization import DeQuantStub
from torchvision.models.mobilenetv2 import (
    MobileNetV2,
)
from horizon_plugin_pytorch.march import March, set_march
from horizon_plugin_pytorch.quantization import (
    QuantStub,
    convert_fx,
    prepare_qat_fx,
)
from horizon_plugin_pytorch.quantization.qconfig import (
    default_calib_8bit_fake_quant_qconfig,
    default_qat_8bit_fake_quant_qconfig,
    default_qat_8bit_weight_32bit_out_fake_quant_qconfig,
    default_calib_8bit_weight_32bit_out_fake_quant_qconfig,
)
from typing import Optional, Callable, List, Tuple
from horizon_plugin_pytorch.nn.quantized import FloatFunctional
from common import *

# --- Ultralytics 工具 ---

# import ultralytics.utils.checks
# def dummy_check_font(*args, **kwargs): pass 
# ultralytics.utils.checks.check_font = dummy_check_font

# from ultralytics import YOLO

# from ultralytics.nn.modules.block import C2f, SPPF
# from ultralytics.nn.modules.conv import Concat
# from ultralytics.nn.modules.head import Detect, Segment

# from ultralytics.data import build_dataloader
# from ultralytics.utils.loss import v8SegmentationLoss
# from ultralytics.cfg import get_cfg
# from ultralytics.data.utils import check_det_dataset
from yolov8_utils.modules import *

# from ultralytics.models.yolo.segment import SegmentationTrainer
# from ultralytics.models.yolo.detect import DetectionTrainer as SegmentationTrainer
# Fix the global torch RNG seed so that torch random draws in this process
# (e.g. default parameter initialisation of the models built below) are
# repeatable from run to run.
torch.manual_seed(seed=191009)


##############################################################################
# First, we make the necessary modifications to the float YOLOv8 model:
# 1. Insert a QuantStub before the first layer and a DeQuantStub after the
#    last layer.
# Operator replacement and fusion are then carried out automatically by the
# FX-based prepare step (^_^).
##############################################################################

class YOLOv8Backbone(nn.Module):
    """YOLOv8 feature extractor assembled from X3* building blocks.

    ``forward`` runs all ten layers in order and returns a list with the
    outputs of layers 4, 6 and 9, in that order. According to the stride
    annotations on the downsampling convs below, these sit at strides 8, 16
    and 32. ``YOLOv8Head`` refers to them by layer index as P4 / P6 / P9.

    The submodule container is named ``layers`` and its order is significant:
    it determines the state_dict keys (``layers.<idx>.*``) that saved
    checkpoints are loaded against.
    """

    def __init__(self):
        super().__init__()
        blocks = [
            X3Conv(3, 16, 3, 2),        # 0 - P1/2
            X3Conv(16, 32, 3, 2),       # 1 - P2/4
            X3C2f(32, 32, 1, True),     # 2
            X3Conv(32, 64, 3, 2),       # 3 - P3/8
            X3C2f(64, 64, 2, True),     # 4  -> returned
            X3Conv(64, 128, 3, 2),      # 5 - P4/16
            X3C2f(128, 128, 2, True),   # 6  -> returned
            X3Conv(128, 256, 3, 2),     # 7 - P5/32
            X3C2f(256, 256, 1, True),   # 8
            X3SPPF(256, 256, 5),        # 9  -> returned
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Return ``[out_4, out_6, out_9]``: outputs of backbone layers 4, 6, 9."""
        # Indices of the layers whose outputs feed the head.
        tap_indices = (4, 6, 9)
        collected = []
        for idx, block in enumerate(self.layers):
            x = block(x)
            if idx in tap_indices:
                collected.append(x)
        return collected

class YOLOv8Head(nn.Module):
    """YOLOv8 neck (top-down then bottom-up fusion) plus segmentation head.

    ``forward`` takes the backbone's feature list and uses only its last
    three entries. They are named here by backbone layer index: P4, P6 and
    P9 are the outputs of backbone layers 4, 6 and 9.

    The input widths hard-coded below assume those features carry 64 / 128 /
    256 channels, e.g. ``c2f1`` takes 384 = 256 (upsampled P9) + 128 (P6).

    Args:
        nc: forwarded as the first argument of ``X3Segment`` (presumably the
            number of classes -- confirm against X3Segment).
    """

    def __init__(self, nc=9):
        super().__init__()

        # ---- top-down path: P9 -> (+P6) -> (+P4) ----
        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.concat1 = Concat(1)
        self.c2f1 = X3C2f(384, 128, 1)

        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.concat2 = Concat(1)
        self.c2f2 = X3C2f(192, 64, 1)

        # ---- bottom-up path: re-downsample and fuse with earlier maps ----
        self.down1 = X3Conv(64, 64, 3, 2)
        self.concat3 = Concat(1)
        self.c2f3 = X3C2f(192, 128, 1)

        self.down2 = X3Conv(128, 128, 3, 2)
        self.concat4 = Concat(1)
        self.c2f4 = X3C2f(384, 256, 1)

        self.segment_head = X3Segment(nc, 32, 256, (64, 128, 256))

        # NOTE(review): these stubs are not used by forward(). They are kept
        # registered, presumably so existing checkpoints' state_dict keys
        # still match.
        self.quant = QuantStub(scale=1 / 128)
        self.dequant = DeQuantStub()

    def forward(self, features):
        """Fuse the last three backbone features and run the segment head."""
        feat_4 = features[-3]
        feat_6 = features[-2]
        feat_9 = features[-1]

        # Top-down: upsample the deepest map and merge with shallower ones.
        mid = self.c2f1(self.concat1([self.up1(feat_9), feat_6]))
        small_stride = self.c2f2(self.concat2([self.up2(mid), feat_4]))

        # Bottom-up: downsample again, merging with the top-down results.
        medium_stride = self.c2f3(self.concat3([self.down1(small_stride), mid]))
        large_stride = self.c2f4(self.concat4([self.down2(medium_stride), feat_9]))

        return self.segment_head([small_stride, medium_stride, large_stride])

class YOLOv8Model(nn.Module):
    """Complete YOLOv8 network wrapped with quantization stubs.

    Data flow: ``quant`` -> ``backbone`` -> ``head`` -> ``dequant`` applied
    separately to every element the head returns. The result is always a
    tuple with one dequantized entry per head output.

    Args:
        nc: forwarded to ``YOLOv8Head``.
    """

    def __init__(self, nc=9):
        super().__init__()
        self.backbone = YOLOv8Backbone()
        self.head = YOLOv8Head(nc=nc)
        # Entry/exit points of the quantized region; the input scale is fixed
        # at 1/128.
        self.quant = QuantStub(scale=1 / 128)
        self.dequant = DeQuantStub()

    def forward(self, x):
        quantized_input = self.quant(x)
        head_outputs = self.head(self.backbone(quantized_input))
        # The head's outputs are iterated one by one, so each element must be
        # something DeQuantStub accepts (presumably a tensor -- confirm
        # against X3Segment's return structure).
        dequantized = []
        for item in head_outputs:
            dequantized.append(self.dequant(item))
        return tuple(dequantized)

# class FxQATReadyMobileNetV2(MobileNetV2):
#     def __init__(
#         self,
#         num_classes: int = 10,
#         width_mult: float = 0.5,
#         inverted_residual_setting: Optional[List[List[int]]] = None,
#         round_nearest: int = 8,
#     ):
#         super().__init__(
#             num_classes, width_mult, inverted_residual_setting, round_nearest
#         )
#         self.quant = QuantStub(scale=1 / 128)
#         self.dequant = DeQuantStub()

#     def forward(self, x: Tensor) -> Tensor:
#         x = self.quant(x)
#         x = super().forward(x)
#         x = self.dequant(x)

#         return x

def _resolve_ckpt_path(model_path: str, default_name: str) -> str:
    """Return the checkpoint file to use for a stage.

    If *model_path* is an existing file it is used as-is; otherwise it is
    treated as a directory and *default_name* is joined onto it.
    """
    if os.path.isfile(model_path):
        return model_path
    return os.path.join(model_path, default_name)


def _load_ckpt_if_present(
    model: nn.Module, ckpt_path: str, device: torch.device
) -> None:
    """Load *ckpt_path* into *model* if the file exists; warn and skip otherwise."""
    if not os.path.exists(ckpt_path):
        # Best-effort: the stage can still run from the current weights, but
        # make the skip visible instead of silently continuing.
        print(f"Warning: checkpoint not found, keeping current weights: {ckpt_path}")
        return
    model.load_state_dict(torch.load(ckpt_path, map_location=device))


def get_model_fx(
    stage: str,
    model_path: str,
    device: torch.device,
    march=March.BAYES,
) -> Optional[nn.Module]:
    """Build the YOLOv8 model for the given quantization *stage*.

    Args:
        stage: One of ``"float"``, ``"calib"``, ``"qat"``, ``"int_infer"``,
            ``"compile"``.
        model_path: Directory containing the stage checkpoints
            (``float-checkpoint.ckpt``, ``calib-checkpoint.ckpt``,
            ``qat-checkpoint.ckpt``), or a path to a single checkpoint file,
            which is then used directly as the checkpoint for this stage.
        device: Device the returned model is moved to.
        march: BPU architecture handed to ``set_march`` for every non-float
            stage.

    Returns:
        * ``"float"``: the float model (no checkpoint is loaded).
        * ``"calib"``: a calibration-prepared model, initialised from the
          float checkpoint when it exists.
        * ``"qat"``: a QAT-prepared model, initialised from the calib
          checkpoint when it exists.
        * ``"int_infer"``: the fixed-point model from ``convert_fx`` applied to
          the QAT model loaded with the QAT checkpoint.
        * ``"compile"``: ``None``. The stage is accepted but not implemented
          here.

    Raises:
        AssertionError: If *stage* is not a supported value.
        FileNotFoundError: For ``"int_infer"`` when the QAT checkpoint is
            missing (conversion needs trained QAT weights).
    """
    assert stage in ("float", "calib", "qat", "int_infer", "compile"), (
        f"Unsupported stage: {stage!r}"
    )
    model_kwargs = dict(nc=9)

    # 1. Instantiate the float model.
    float_model = YOLOv8Model(**model_kwargs).to(device)
    if stage == "float":
        return float_model

    # 2. The BPU architecture must be set before any prepare/convert call.
    set_march(march)

    # ------------------ Calib ------------------
    if stage == "calib":
        _load_ckpt_if_present(
            float_model,
            _resolve_ckpt_path(model_path, "float-checkpoint.ckpt"),
            device,
        )
        return prepare_qat_fx(
            float_model,
            {"": default_calib_8bit_fake_quant_qconfig},
        ).to(device)

    if stage == "compile":
        # Accepted by the stage check but not implemented in this function.
        return None

    # ------------------ QAT / Int_Infer ------------------
    # Both stages need the QAT graph structure; weights are loaded afterwards.
    qat_model = prepare_qat_fx(
        float_model,
        {"": default_qat_8bit_fake_quant_qconfig},
    ).to(device)

    if stage == "qat":
        # QAT training starts from the calibrated weights; the calib and QAT
        # prepared graphs are assumed to share state_dict keys.
        _load_ckpt_if_present(
            qat_model,
            _resolve_ckpt_path(model_path, "calib-checkpoint.ckpt"),
            device,
        )
        return qat_model

    # ------------------ Int_Infer (convert to fixed point) ------------------
    qat_ckpt_path = _resolve_ckpt_path(model_path, "qat-checkpoint.ckpt")
    if not os.path.exists(qat_ckpt_path):
        raise FileNotFoundError(f"未找到 QAT 权重文件: {qat_ckpt_path}，无法转定点。")

    print(f"Loading QAT checkpoint from {qat_ckpt_path}...")
    # Strict load: every QAT parameter/buffer (weights and fake-quant state)
    # must be present.
    qat_model.load_state_dict(
        torch.load(qat_ckpt_path, map_location=device), strict=True
    )

    print("Converting QAT model to Quantized (Fixed-Point) model...")
    return convert_fx(qat_model).to(device)




if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # === 1. CLI arguments (every attribute later read from `args` must be declared here) ===
    parser.add_argument(
        "--stage",
        type=str,
        default="int_infer",
        choices=("float", "calib", "qat", "int_infer", "compile"),
    )
    parser.add_argument("--model_path", type=str, default="/data6/liuziyi/yolov8_qat/qat_project/yolov8_multi_task")
    parser.add_argument("--data_yaml", type=str, default="/data6/liuziyi/yolov8_qat/qat_project/data.yaml")

    # Training / evaluation options. They must be declared, otherwise reading
    # e.g. args.train_batch_size raises AttributeError.
    parser.add_argument("--train_batch_size", type=int, default=36)
    parser.add_argument("--eval_batch_size", type=int, default=1)
    parser.add_argument("--calib_batch", type=int, default=1)
    parser.add_argument("--epoch_num", type=int, default=10)
    parser.add_argument("--device_id", type=int, default=0)
    parser.add_argument(
        "--march",
        type=str,
        default='March.BERNOULLI2',
        help="BPU architecture, e.g. 'March.BERNOULLI2', 'BERNOULLI2' or 'bernoulli2'.",
    )
    parser.add_argument("--imgsz", type=int, default=640)
    parser.add_argument("--workers", type=int, default=4)

    args = parser.parse_args()

    # argparse yields a plain string such as "March.BERNOULLI2". get_model_fx
    # passes march straight to set_march (its own default is the March
    # constant March.BAYES), so map the CLI name onto the March attribute.
    # Unknown names are passed through unchanged.
    march_name = args.march.split(".")[-1]
    march = getattr(March, march_name.upper(), args.march)

    # Fall back to CPU when a GPU id is requested but CUDA is unavailable,
    # instead of failing later in .to(device).
    use_cuda = args.device_id >= 0 and torch.cuda.is_available()
    if args.device_id >= 0 and not use_cuda:
        print("CUDA is not available; falling back to CPU.")
    device = torch.device(
        "cuda:{}".format(args.device_id) if use_cuda else "cpu"
    )

    # === 2. Build the model ===
    # For int_infer this is already the converted (fixed-point) model.
    model = get_model_fx(args.stage, args.model_path, device, march=march)

    # === 3. Run the task ===
    # NOTE(review): run_task is imported but never called; only int_infer
    # currently does any work below. TODO wire the other stages to it.
    from common import main as run_task

    if args.stage == "int_infer":
        print("Starting inference with Quantized Model...")
        model.eval()

        # If model_path is a checkpoint file, save next to it; otherwise
        # save into the given directory.
        if os.path.isfile(args.model_path):
            save_dir = os.path.dirname(args.model_path)
        else:
            save_dir = args.model_path

        os.makedirs(save_dir, exist_ok=True)

        save_path = os.path.join(save_dir, "quantized_model.pth")

        torch.save(model.state_dict(), save_path)
        print(f"定点模型已保存至: {save_path}")

        # (Optional) compile the model
        # from horizon_plugin_pytorch import compile_model
        # compile_model(model, [torch.randn(1, 3, 640, 640).to(device)], opt=2)
    else:
        print(
            f"Stage '{args.stage}': model built, but this script does not "
            "run any task for this stage."
        )
