import argparse
import os

import copy
import torch
import torch.nn as nn
import torch.quantization
import torchvision.transforms as transforms
from torch import Tensor
from torch.quantization import DeQuantStub
from torchvision.models.mobilenetv2 import (
    MobileNetV2,
)
from horizon_plugin_pytorch.march import March, set_march
from horizon_plugin_pytorch.quantization import (
    QuantStub,
    convert_fx,
    prepare_qat_fx,
)
from horizon_plugin_pytorch.quantization.qconfig import (
    default_calib_8bit_fake_quant_qconfig,
    default_qat_8bit_fake_quant_qconfig,
    default_qat_8bit_weight_32bit_out_fake_quant_qconfig,
    default_calib_8bit_weight_32bit_out_fake_quant_qconfig,
)
from typing import Optional, Callable, List, Tuple
from horizon_plugin_pytorch.nn.quantized import FloatFunctional
from common import *
from yolov8_utils.modules import *

# Seed PyTorch's global RNG with a fixed value so that weight initialization
# of the models constructed below is repeatable from run to run.
torch.manual_seed(191009)


import sys
import os
import torch
import ultralytics
import ultralytics.nn.modules 
# Register the installed `ultralytics` package under the module name
# 'ultralytics.yolo' as well, so imports of that name resolve to it.
# NOTE(review): presumably needed to unpickle checkpoints saved against an
# older Ultralytics package layout — confirm.
sys.modules['ultralytics.yolo'] = ultralytics
print("✅ Patch 1 Applied: Redirected 'ultralytics.yolo' -> 'ultralytics'")
# Import the custom X3 operator classes. If the package cannot be imported,
# append this file's directory to sys.path and retry once.
try:
    from yolov8_utils.modules import (
        X3Conv, X3SGConv, X3SPPF, X3VarGBlock, X3C2f, 
        X3ConvTranspose, X3Detect, X3Proto, X3Segment, X3Pose,
        Concat
    )
except ImportError:
    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
    from yolov8_utils.modules import (
        X3Conv, X3SGConv, X3SPPF, X3VarGBlock, X3C2f, 
        X3ConvTranspose, X3Detect, X3Proto, X3Segment, X3Pose,
        Concat
    )

# Classes that get exposed as attributes of `ultralytics.nn.modules` below.
patch_classes = [
    X3Conv, X3SGConv, X3SPPF, X3VarGBlock, X3C2f, 
    X3ConvTranspose, X3Detect, X3Proto, X3Segment, X3Pose,
    Concat
]

print("🛠️ Applying Monkey Patch for Custom X3 Operators...")
# Set each class on `ultralytics.nn.modules` under its own class name; an
# existing attribute with the same name (e.g. `Concat`) is overwritten.
# NOTE(review): presumably so that torch.load can resolve pickled references
# such as `ultralytics.nn.modules.X3Conv` — confirm.
for cls in patch_classes:
    setattr(ultralytics.nn.modules, cls.__name__, cls)
    print(f"  👉 Patched: ultralytics.nn.modules.{cls.__name__}")

print("✅ Monkey Patch Complete. Ready to load custom weights.\n")



def init_yolo_weights(model):
    """Apply YOLO-style initialization tweaks to ``model`` in place.

    Two things happen:

    1. Every module whose exact type is ``nn.BatchNorm2d`` gets ``eps=1e-3``
       and ``momentum=0.03``; every module whose exact type is one of
       ``nn.Hardswish``, ``nn.LeakyReLU``, ``nn.ReLU``, ``nn.ReLU6`` or
       ``nn.SiLU`` is switched to ``inplace=True``.
    2. If the model exposes ``model.head.segment_head``, every ``nn.Conv2d``
       inside it whose ``out_channels`` equals ``segment_head.nc`` has its
       bias filled with ``-log((1 - p) / p)`` for ``p = 0.01`` (about -4.6),
       i.e. the logit of a 1% prior probability.

    The bias step is best-effort: problems are reported on stdout and never
    raised. Convolutions without a bias are skipped individually so that one
    of them cannot prevent the remaining layers from being initialized.

    Args:
        model: The ``nn.Module`` to modify in place.

    Returns:
        None.
    """
    import math

    print("🛠️ Executing YOLO-style weight initialization...")

    activation_types = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
    for m in model.modules():
        # Exact type comparison on purpose: subclasses are left untouched.
        t = type(m)
        if t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in activation_types:
            m.inplace = True

    # Bias value for a prior probability of 0.01 (approximately -4.6).
    prior_prob = 0.01
    prior_bias = -math.log((1 - prior_prob) / prior_prob)

    # Deliberately broad: this step must never abort the caller.
    try:
        # Only models laid out as model.head.segment_head are handled.
        if hasattr(model, 'head') and hasattr(model.head, 'segment_head'):
            seg_head = model.head.segment_head
            nc = seg_head.nc

            init_count = 0
            for name, m in seg_head.named_modules():
                # Candidate layers are convolutions with exactly `nc` outputs.
                if not (isinstance(m, nn.Conv2d) and m.out_channels == nc):
                    continue
                if m.bias is None:
                    # Nothing to initialize; keep going with the next layer.
                    continue
                # Fill in place so the Parameter object (and any reference to
                # it) stays the same, then make sure it is trainable.
                with torch.no_grad():
                    m.bias.fill_(prior_bias)
                m.bias.requires_grad_(True)
                init_count += 1
                print(f"  ✅ Initialized bias for layer: {name}")

            if init_count == 0:
                print("  ⚠️ Warning: No classification layer found to initialize bias!")
            else:
                print(f"  ✅ Successfully initialized {init_count} detection layers.")

    except Exception as e:
        print(f"  ⚠️ Initialization skipped or failed: {e}")


def _has_weights(obj):
    """Return True if ``obj`` exposes ``state_dict`` or is a non-empty dict."""
    return hasattr(obj, 'state_dict') or (isinstance(obj, dict) and len(obj) > 0)


def _extract_pretrained_state_dict(ckpt):
    """Return the name -> tensor mapping stored in a loaded checkpoint.

    For a dict checkpoint, ``ckpt['ema']`` is preferred, then
    ``ckpt['model']``; if neither holds weights the dict itself is treated as
    a state dict. A non-dict checkpoint is used as the weight source directly.

    Raises:
        TypeError: If no dict of parameters can be obtained.
    """
    source = ckpt
    if isinstance(ckpt, dict):
        if _has_weights(ckpt.get('ema')):
            print("  👉 Loading EMA weights (Best practice)...")
            source = ckpt['ema']
        elif _has_weights(ckpt.get('model')):
            source = ckpt['model']

    # Model objects are converted to float32 before their state is read.
    if hasattr(source, 'float'):
        source = source.float()
    if hasattr(source, 'state_dict'):
        source = source.state_dict()

    if not isinstance(source, dict):
        raise TypeError(
            f"unsupported checkpoint content of type {type(source).__name__}"
        )
    return source


def _positional_params(state_dict):
    """Return ``(key, tensor)`` pairs in order, without skipped entries.

    Entries that are not tensors, and keys containing ``num_batches_tracked``
    or ``anchor``, are left out.
    """
    return [
        (k, v)
        for k, v in state_dict.items()
        if isinstance(v, torch.Tensor)
        and 'num_batches_tracked' not in k
        and 'anchor' not in k
    ]


def load_custom_yolo_weights(my_model, pt_path):
    """Load Ultralytics-style weights into ``my_model``, preferring EMA weights.

    Tensors are matched **by position, not by name**: the i-th retained
    source tensor is copied to the i-th retained target tensor when their
    shapes are equal; pairs with different shapes are reported and skipped.
    The source is converted to float32.

    Failures while reading or interpreting the checkpoint are reported on
    stdout and the function returns without touching the model.

    Args:
        my_model: Target ``nn.Module``; updated in place.
        pt_path: Path to the checkpoint file.

    Returns:
        None.
    """
    print(f"📥 Loading weights from {pt_path}...")
    # Broad on purpose: torch.load can fail in many ways (missing file,
    # unpickling errors, missing classes) and this loader is best-effort.
    try:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        ckpt = torch.load(pt_path, map_location='cpu')
        pretrained_dict = _extract_pretrained_state_dict(ckpt)
    except Exception as e:
        print(f"❌ Failed to load checkpoint: {e}")
        return

    my_model_dict = my_model.state_dict()
    pt_items = _positional_params(pretrained_dict)
    my_items = _positional_params(my_model_dict)

    print(f"  📊 Source params: {len(pt_items)} | Target params: {len(my_items)}")
    if len(pt_items) != len(my_items):
        print(
            "  ⚠️ Warning: parameter counts differ; tensors are matched by "
            "position, so entries after the first divergence may be skipped "
            "or misassigned."
        )

    # Start from the model's current values so unmatched entries are kept.
    matched_count = 0
    new_state_dict = my_model_dict.copy()
    # zip stops at the shorter list, i.e. only the common prefix is compared.
    for (src_k, src_v), (dst_k, dst_v) in zip(pt_items, my_items):
        if src_v.shape == dst_v.shape:
            new_state_dict[dst_k] = src_v.float()
            matched_count += 1
        else:
            print(f"  ⚠️ Shape mismatch: Src {src_k}{src_v.shape} != Dst {dst_k}{dst_v.shape}")

    my_model.load_state_dict(new_state_dict, strict=False)
    print(f"✅ Successfully loaded {matched_count} layers from pretrained weights.")

##############################################################################
# Model definition: a YOLOv8-style backbone and segmentation head built from
# the custom X3 modules.
# 1. For quantization, a QuantStub is applied to the network input and a
#    DeQuantStub to every output (see YOLOv8Model_fx).
# Operation replacement and fusion are carried out automatically by
# `prepare_qat_fx`.
##############################################################################

class YOLOv8Backbone(nn.Module):
    """Sequential stack of ten X3 layers that returns three intermediate maps.

    ``forward`` runs the layers in order and returns the outputs of the layers
    at indices 4, 6 and 9 as a list, in that order.
    """

    def __init__(self):
        super().__init__()
        # One row per layer as (class, positional args), in execution order.
        # The "Pn/s" tags are the stage/stride annotations of the original
        # layer list.
        layer_specs = (
            (X3Conv, (3, 16, 3, 2)),         # 0 - P1/2
            (X3Conv, (16, 32, 3, 2)),        # 1 - P2/4
            (X3C2f, (32, 32, 1, True)),      # 2
            (X3Conv, (32, 64, 3, 2)),        # 3 - P3/8
            (X3C2f, (64, 64, 2, True)),      # 4
            (X3Conv, (64, 128, 3, 2)),       # 5 - P4/16
            (X3C2f, (128, 128, 2, True)),    # 6
            (X3Conv, (128, 256, 3, 2)),      # 7 - P5/32
            (X3C2f, (256, 256, 1, True)),    # 8
            (X3SPPF, (256, 256, 5)),         # 9
        )
        self.layers = nn.ModuleList(
            [layer_cls(*layer_args) for layer_cls, layer_args in layer_specs]
        )

    def forward(self, x):
        """Run all layers on ``x`` and return the outputs of layers 4, 6, 9."""
        tap_indices = (4, 6, 9)
        tapped = []
        current = x
        for index, layer in enumerate(self.layers):
            current = layer(current)
            if index in tap_indices:
                # Keep this stage's feature map for the head.
                tapped.append(current)
        return tapped

class YOLOv8Head(nn.Module):
    """Feature-fusion neck followed by an ``X3Segment`` head.

    The last three entries of the backbone output are fused in a top-down
    pass (upsample + concat + C2f) and then a bottom-up pass (strided conv +
    concat + C2f); the three resulting maps are handed to ``segment_head``.

    Args:
        nc: Passed as the first argument to ``X3Segment``.
            NOTE(review): presumably the number of classes — confirm.
    """

    def __init__(self, nc=9):
        super().__init__()

        # --- top-down path, step 1: deepest map -> middle resolution ---
        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.concat1 = Concat(1)
        self.c2f1 = X3C2f(384, 128, 1)

        # --- top-down path, step 2: middle -> highest resolution ---
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.concat2 = Concat(1)
        self.c2f2 = X3C2f(192, 64, 1)

        # --- bottom-up path, step 1: highest -> middle resolution ---
        self.down1 = X3Conv(64, 64, 3, 2)
        self.concat3 = Concat(1)
        self.c2f3 = X3C2f(192, 128, 1)

        # --- bottom-up path, step 2: middle -> lowest resolution ---
        self.down2 = X3Conv(128, 128, 3, 2)
        self.concat4 = Concat(1)
        self.c2f4 = X3C2f(384, 256, 1)

        self.segment_head = X3Segment(nc, 32, 256, (64, 128, 256))
        # These stubs are registered but not called in forward().
        self.quant = QuantStub(scale=1 / 128)
        self.dequant = DeQuantStub()

    def forward(self, features):
        """Fuse the last three backbone feature maps and run the segment head.

        Args:
            features: Sequence of backbone outputs; only the last three
                entries are read.

        Returns:
            Whatever ``segment_head`` returns for the three fused maps.
        """
        # Named after the backbone layer index that produced each map.
        layer4_feat = features[-3]
        layer6_feat = features[-2]
        layer9_feat = features[-1]

        # Top-down: bring deeper features up and merge with shallower ones.
        mid_td = self.c2f1(self.concat1([self.up1(layer9_feat), layer6_feat]))
        small_out = self.c2f2(self.concat2([self.up2(mid_td), layer4_feat]))

        # Bottom-up: push the merged features back down.
        mid_out = self.c2f3(self.concat3([self.down1(small_out), mid_td]))
        large_out = self.c2f4(self.concat4([self.down2(mid_out), layer9_feat]))

        return self.segment_head([small_out, mid_out, large_out])

class YOLOv8Model_fx(nn.Module):
    """Backbone + head wrapped with quantization stubs.

    The input passes through ``quant`` before the backbone, and every head
    output passes through ``dequant`` before being returned.

    Args:
        nc: Forwarded to ``YOLOv8Head``.
    """

    def __init__(self, nc=9):
        super().__init__()
        self.backbone = YOLOv8Backbone()
        self.head = YOLOv8Head(nc=nc)
        self.quant = QuantStub(scale=1 / 128)
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Return a tuple with one dequantized tensor per head output."""
        quantized_input = self.quant(x)
        head_outputs = self.head(self.backbone(quantized_input))
        return tuple(map(self.dequant, head_outputs))

class YOLOv8Model_float(nn.Module):
    """Backbone + head without quantization stubs on the input or outputs.

    Args:
        nc: Forwarded to ``YOLOv8Head``.
    """

    def __init__(self, nc=9):
        super().__init__()
        self.backbone = YOLOv8Backbone()
        self.head = YOLOv8Head(nc=nc)

    def forward(self, x):
        """Return the head outputs for ``x`` as a tuple."""
        # NOTE(review): an earlier commented-out variant of this method
        # described the head output as a tuple of 10 tensors
        # (Score_S, Box_S, Coef_S, Score_M, ..., Proto) — confirm against
        # X3Segment.
        return tuple(self.head(self.backbone(x)))


def get_model_fx(
    stage: str,
    model_path: str,
    device: torch.device,
    march=March.BERNOULLI2,
) -> nn.Module:
    """Build the model that corresponds to one stage of the QAT pipeline.

    Stages and what is returned:

    * ``"float"``: a freshly constructed ``YOLOv8Model_float`` on ``device``.
      No checkpoint is loaded here.
    * ``"calib"``: the float checkpoint loaded into ``YOLOv8Model_fx`` and
      prepared with the calibration qconfig.
    * ``"qat"``: ``YOLOv8Model_fx`` prepared with the QAT qconfig. If a
      calibration checkpoint exists it is loaded into the prepared model;
      otherwise the float checkpoint is loaded before preparation.
    * ``"int_infer"`` / ``"compile"``: the result of ``convert_fx`` on the
      QAT model (if ``qat-checkpoint.ckpt`` exists) or else on the
      calibration model (if ``calib-checkpoint.ckpt`` exists).

    Args:
        stage: One of ``"float"``, ``"calib"``, ``"qat"``, ``"int_infer"``,
            ``"compile"``.
        model_path: Directory containing ``float-checkpoint.ckpt`` and,
            optionally, ``calib-checkpoint.ckpt`` / ``qat-checkpoint.ckpt``.
        device: Device the returned model is moved to; also used as
            ``map_location`` when reading checkpoints.
        march: Value passed to ``set_march`` before the model is prepared.

    Returns:
        The model for the requested stage.

    Raises:
        ValueError: If ``stage`` is not one of the supported values.
        FileNotFoundError: If the float checkpoint is missing for a non-float
            stage, or if neither a calib nor a qat checkpoint exists for
            ``"int_infer"`` / ``"compile"``.
    """
    valid_stages = ("float", "calib", "qat", "int_infer", "compile")
    if stage not in valid_stages:
        # Explicit raise instead of `assert`, which is removed under `-O`.
        raise ValueError(
            f"Unsupported stage {stage!r}; expected one of {valid_stages}."
        )
    model_kwargs = dict(nc=9)

    # Constructed before the stage check, exactly as before, so the amount of
    # random state consumed ahead of the float model does not change.
    float_model = YOLOv8Model_fx(**model_kwargs)

    if stage == "float":
        # Float training uses the stub-free model variant.
        float_model = YOLOv8Model_float(**model_kwargs).to(device)
        return float_model

    float_ckpt_path = os.path.join(model_path, "float-checkpoint.ckpt")
    if not os.path.exists(float_ckpt_path):
        raise FileNotFoundError(
            f"Float checkpoint not found: {float_ckpt_path}"
        )
    float_state_dict = torch.load(float_ckpt_path, map_location=device)

    # A global march indicating the target hardware version must be set
    # before prepare_qat_fx is called.
    set_march(march)

    # Preserve a clean float model; the calib and qat models are each
    # prepared from their own deep copy of it.
    ori_float_model = float_model
    float_model = copy.deepcopy(ori_float_model)

    float_model.load_state_dict(float_state_dict)
    # Op fusion is included in `prepare_qat_fx`. The qconfig can either be
    # passed as a dict here ("" is the global entry; a "module_name" entry
    # can override it per module) or be set as a module attribute in the
    # same manner as eager mode.
    calib_model = prepare_qat_fx(
        float_model,
        {
            "": default_calib_8bit_fake_quant_qconfig,
        },
    ).to(device)

    if stage == "calib":
        return calib_model

    calib_ckpt_path = os.path.join(model_path, "calib-checkpoint.ckpt")
    if os.path.exists(calib_ckpt_path):
        calib_state_dict = torch.load(calib_ckpt_path, map_location=device)
        # The calibration checkpoint supersedes the float weights below.
        float_state_dict = None
    else:
        calib_state_dict = None

    float_model = copy.deepcopy(ori_float_model)

    if float_state_dict is not None:
        float_model.load_state_dict(float_state_dict)
    qat_model = prepare_qat_fx(
        float_model,
        {
            "": default_qat_8bit_fake_quant_qconfig,
        },
    ).to(device)
    if calib_state_dict is not None:
        qat_model.load_state_dict(calib_state_dict)

    if stage == "qat":
        return qat_model

    qat_ckpt_path = os.path.join(model_path, "qat-checkpoint.ckpt")
    if os.path.exists(qat_ckpt_path):
        qat_model.load_state_dict(
            torch.load(qat_ckpt_path, map_location=device)
        )
    elif os.path.exists(calib_ckpt_path):
        calib_model.load_state_dict(
            torch.load(calib_ckpt_path, map_location=device)
        )
        # Both the qat model and the calib model can be converted to a
        # quantized model directly, so calibration, QAT training, or both
        # may have been run.
        qat_model = calib_model
    else:
        raise FileNotFoundError(
            "Do not find saved calib_model or qat_model ckpt "
            "to do int inference."
        )

    quantized_model = convert_fx(qat_model).to(device)

    return quantized_model



if __name__ == "__main__":
    # Command-line options as (flag, type, default, help) rows, registered in
    # exactly this order. Valid --stage values are:
    # 'float', 'calib', 'qat', 'int_infer', 'compile'.
    option_rows = (
        ("--stage", str, "compile", None),
        ("--model_path", str, "/data6/liuziyi/yolov8_qat/qat_project/yolov8_multi_task", None),
        ("--save_path", str, "/data6/liuziyi/yolov8_qat/qat_project/output/0119_qat_pyramid", "用来保存编译文件"),
        ("--data_yaml", str, "/data6/liuziyi/yolov8_qat/qat_project/data2.yaml", None),
        ("--train_batch_size", int, 32, None),
        ("--eval_batch_size", int, 1, None),
        ("--calib_batch", int, 1, None),
        ("--epoch_num", int, 10, None),
        ("--device_id", int, 0, None),
        ("--march", str, 'BERNOULLI2', None),
        ("--imgsz", int, 640, None),
        ("--workers", int, 4, None),
        ("--opt", str, "0", None),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in option_rows:
        parser.add_argument(
            flag, type=value_type, default=default_value, help=help_text
        )
    args = parser.parse_args()

    # A negative device id selects the CPU; otherwise the given CUDA device.
    device_name = "cpu" if args.device_id < 0 else "cuda:{}".format(args.device_id)
    device = torch.device(device_name)

    # NOTE: --march is forwarded to main() below, but the model itself is
    # built with get_model_fx's default march.
    model = get_model_fx(args.stage, args.model_path, device)

    if args.stage == "float":
        float_ckpt = os.path.join(args.model_path, "float-checkpoint.ckpt")
        if os.path.exists(float_ckpt):
            # Never re-initialize when a checkpoint exists, so a partially
            # trained model is not overwritten.
            print("\n [STARTUP] Found existing checkpoint. Skipping initialization to resume training.")
        else:
            print("\n [STARTUP] Detected Scratch Training. Applying YOLO Initialization...")
            init_yolo_weights(model)

            # Optionally start from weights trained with the official
            # framework (e.g. a best.pt file), if that file is present.
            my_trained_pt = "/data6/liuziyi/yolov8_qat/qat_project/yolov8_multi_task/best_xj3_1226.pt"
            if os.path.exists(my_trained_pt):
                print(" Found pretrained weights from official framework! Loading...")
                load_custom_yolo_weights(model, my_trained_pt)

    main(
        args,
        model,
        args.stage,
        args.data_yaml,
        args.model_path,
        args.train_batch_size,
        args.eval_batch_size,
        args.epoch_num,
        args.device_id,
        march=args.march,
        compile_opt=args.opt,
        save_path=args.save_path,
    )
