import torch
import sys
import os
import argparse
import torch.nn.functional as F
import torch.nn as nn
import ultralytics.utils.checks
def dummy_check_font(*args, **kwargs):
    """Accept any positional/keyword arguments and do nothing.

    Intended as a no-op stand-in for a font check; it always returns ``None``.
    """
    return None
# Rebind ``check_font`` on the Ultralytics checks module to the no-op
# ``dummy_check_font``, so calls made through that module attribute do nothing.
ultralytics.utils.checks.check_font = dummy_check_font

from ultralytics.models.yolo.segment import SegmentationValidator
# ==============================================================================
# 🚑【紧急补丁 V3】Ultralytics Mask 形状自动修复 + 残缺数据填充
#    修复：1. 强制重塑 2D -> 3D
#          2. 兼容 shape=[640]
#          3. 自动检测并修复“通道丢失” (25600 -> 819200)
# ==============================================================================
import ultralytics.utils.ops as ops_module
import ultralytics.models.yolo.segment.val as val_module
import torch

import copy
from horizon_plugin_pytorch.quantization import prepare_qat_fx
from horizon_plugin_pytorch.quantization.qconfig import (default_calib_8bit_fake_quant_qconfig,default_qat_8bit_fake_quant_qconfig)
from horizon_plugin_pytorch.march import set_march, March
# 1. Back up the original function. The hasattr guard keeps the first saved
#    reference, so running this code again after process_mask has been replaced
#    does not overwrite the backup with the patched version.
if not hasattr(ops_module, "_raw_process_mask"):
    ops_module._raw_process_mask = ops_module.process_mask

# 2. 定义终极强壮的新函数
def safe_process_mask(protos, masks_in, bboxes, shape, upsample=False):
    """Repair malformed mask prototypes, then call the original ``process_mask``.

    The wrapper normalises ``shape`` to an ``(H, W)`` pair of Python ints and,
    when ``protos`` is not a 3-D tensor with 32 leading channels, tries to
    bring it back to ``[32, h, w]`` before delegating to
    ``ops_module._raw_process_mask``.

    Args:
        protos: Prototype tensor. Anything other than a 3-D tensor whose first
            dimension is 32 goes through the repair path below.
        masks_in: Forwarded unchanged to the original function.
        bboxes: Forwarded unchanged to the original function.
        shape: Input image size. May be a 2-element sequence ``(H, W)``, a
            1-element sequence, or a single number; anything else (or anything
            that cannot be converted to ``int``) falls back to 640x640.
        upsample: Forwarded unchanged to the original function.

    Returns:
        Whatever ``ops_module._raw_process_mask`` returns for the repaired
        arguments.
    """
    fallback_side = 640

    # --- [Fix 1] Parse H, W robustly, always ending up with Python ints ---
    try:
        if hasattr(shape, '__len__'):
            if len(shape) == 2:
                raw_h, raw_w = shape
                H_in, W_in = int(raw_h), int(raw_w)
            elif len(shape) == 1:
                H_in = W_in = int(shape[0])
            else:
                H_in = W_in = fallback_side
        else:
            H_in = W_in = int(shape)
    except Exception:
        # Deliberate best-effort fallback; unlike a bare ``except`` this does
        # not swallow KeyboardInterrupt / SystemExit.
        H_in = W_in = fallback_side

    # Sanitised shape handed to the original function.
    safe_shape = (H_in, W_in)

    # --- [Fix 2 & 3] Intercept abnormal prototypes ---
    # Anything that is not the standard 3-D [32, H, W] layout gets repaired.
    if protos.dim() != 3 or protos.shape[0] != 32:
        # Target size: prototypes are taken to be 1/4 of the input resolution.
        target_h = H_in // 4
        target_w = W_in // 4
        expected_size = 32 * target_h * target_w  # e.g. 819200 for 640x640
        actual_size = protos.numel()              # e.g. 25600

        if actual_size == expected_size:
            # Case A: element count is right, only the shape is flattened.
            # reshape (not view) so non-contiguous inputs are handled too.
            protos = protos.reshape(32, target_h, target_w)

        elif actual_size * 32 == expected_size:
            # Case B: only 1/32 of the data is present; treat it as a single
            # channel and replicate it 32 times to get [32, H, W].
            print(f"⚠️ [Patch] 检测到通道丢失 (最后Batch)! 正在复制填充... {actual_size} -> {expected_size}")
            protos = protos.reshape(1, target_h, target_w).repeat(32, 1, 1)

        else:
            # Case C: some other size; try to infer a square 32-channel map
            # from the element count.
            side = int((actual_size / 32) ** 0.5)
            if 32 * side * side == actual_size:
                protos = protos.reshape(32, side, side)
            else:
                # Unrecoverable: report it and substitute an all-zero
                # prototype so the run continues instead of crashing.
                print(f"❌ [Patch] 无法修复 Proto: 实际{actual_size} vs 期望{expected_size}. 跳过此掩码.")
                # Keep dtype and device of the incoming tensor so downstream
                # ops see the same tensor type as before.
                protos = torch.zeros(
                    (32, target_h, target_w),
                    dtype=protos.dtype,
                    device=protos.device,
                )

    # Delegate to the original implementation.
    return ops_module._raw_process_mask(protos, masks_in, bboxes, safe_shape, upsample)

# 3. Apply the patch: rebind ``process_mask`` on both the ops module and the
#    segmentation validator module (``val_module``), then print a notice that
#    the patch is active.
ops_module.process_mask = safe_process_mask
val_module.process_mask = safe_process_mask
print("\n✅ [System] Ultralytics process_mask 补丁 V3 (残缺填充版) 已激活！\n")
# ==============================================================================
from ultralytics.utils.tal import make_anchors, dist2bbox
from ultralytics.nn.modules.block import DFL

# Import the project-local model definitions. If the import fails, print an
# error (the message asks for fx_mode_yolov8.py to sit next to this script)
# and exit with status 1.
try:
    from fx_mode_yolov8 import YOLOv8Model_float,YOLOv8Model_fx
except ImportError:
    print("❌ 错误：无法导入模型结构。请确保此脚本与 fx_mode_yolov8.py 在同目录。")
    sys.exit(1)

class FloatModelWrapper(nn.Module):
    """Adapter that lets a raw segmentation network be driven by the Ultralytics validator.

    The wrapped model is expected to return a sequence of tensors: nine head
    outputs grouped per scale as (score, box, mask) at indices
    (0, 1, 2), (3, 4, 5), (6, 7, 8), followed by the mask prototypes as the
    last element. ``forward`` decodes these into the ``(preds, proto)`` pair
    built at the end of that method.

    The class also exposes the attributes and no-op methods (``names``,
    ``stride``, ``yaml``, ``fuse``, ``info``) that the surrounding tooling
    looks for on a model object.
    """

    def __init__(self, original_model, nc=9):
        """Store the wrapped model and the static metadata.

        Args:
            original_model: The network to wrap; called as ``original_model(x)``.
            nc: Number of classes recorded in ``self.nc`` and ``self.yaml``.
                Note that ``self.names`` is a fixed 9-entry mapping regardless
                of this value.
        """
        super().__init__()
        self.model = original_model
        self.names = {
            0: 'Barrier', 1: 'Bollard', 2: 'Driving_area',
            3: 'N_Obstacle', 4: 'Obstacle', 5: 'Parking_lock',
            6: 'Road_cone', 7: 'curb', 8: 'mask'
        }
        self.stride = torch.tensor([8., 16., 32.])
        self.nc = nc
        self.nm = 32
        self.reg_max = 16
        self.yaml = {
            'channels': 3,
            'nc': nc,
            'names': self.names,
            'stride': self.stride
        }
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    # =========================================================
    # Methods required so that AutoBackend accepts this object
    # =========================================================
    def fuse(self, verbose=False):
        """Perform no fusion and return ``self``."""
        return self

    def info(self, verbose=False, img_size=None):
        """Placeholder; does nothing and returns ``None``."""
        pass

    def _apply(self, fn, *args, **kwargs):
        """Forward ``.cpu()``, ``.cuda()``, ``.half()`` etc. to the base implementation.

        ``self.model`` and ``self.dfl`` are registered child modules, so
        ``nn.Module._apply`` already visits them; applying ``fn`` to them a
        second time is unnecessary. Extra arguments are passed through so the
        override stays compatible with torch versions whose ``_apply`` takes
        additional parameters.
        """
        return super()._apply(fn, *args, **kwargs)

    @staticmethod
    def _restore_proto_shape(proto, x):
        """Return ``proto`` as a 4-D ``[B, 32, H/4, W/4]`` tensor.

        A 4-D ``proto`` is returned untouched. Otherwise the batch size and
        spatial size are taken from ``x`` (``x.shape`` is unpacked as
        ``B, _, H, W``), a 2-D input gets a leading batch axis, an input whose
        last dimension is 32 is transposed to put the 32 channels second, and
        the result is reshaped.

        Raises:
            RuntimeError: If the element count does not match the target
                shape; diagnostics are printed before re-raising.
        """
        if proto.dim() == 4:
            return proto

        B, _, H_in, W_in = x.shape
        target_h = H_in // 4
        target_w = W_in // 4

        # [32, N]: batch of 1 whose batch axis was squeezed away.
        if proto.dim() == 2:
            proto = proto.unsqueeze(0)  # [1, 32, N]

        # Channels-last [B, N, 32] -> [B, 32, N].
        if proto.shape[-1] == 32:
            proto = proto.permute(0, 2, 1)

        try:
            # reshape, not view: the permute above yields a non-contiguous
            # tensor, on which view() raises even when the sizes match.
            return proto.reshape(B, 32, target_h, target_w)
        except RuntimeError:
            print(f"\n❌ [CRASH] 强制 Reshape 失败！")
            print(f"  输入图像: {x.shape}")
            print(f"  目标形状: [{B}, 32, {target_h}, {target_w}] (总元素 {B*32*target_h*target_w})")
            print(f"  实际 Proto: {proto.shape} (总元素 {proto.numel()})")
            print(f"  差距: {proto.numel() - B*32*target_h*target_w} 个元素")
            raise

    def forward(self, x, *args, **kwargs):
        """Run the wrapped model and decode its outputs.

        Args:
            x: Input image batch; ``x.shape`` is unpacked as ``B, _, H, W``.
                If ``x.max()`` exceeds 5.0 the input is treated as 0-255 data
                and scaled by 1/255.
            *args, **kwargs: Accepted and ignored.

        Returns:
            tuple: ``(preds, proto)``, both contiguous. ``preds`` is the
            concatenation along dim 1 of decoded boxes, sigmoid class scores
            and mask coefficients; ``proto`` is the 4-D prototype tensor.
        """
        # 1. Basic normalisation (heuristic: values above 5 mean 0-255 input).
        if x.max() > 5.0:
            x = x.float() / 255.0

        # 2. Inference; the prototypes are the last output.
        outputs = self.model(x)
        proto = self._restore_proto_shape(outputs[-1], x)

        # 3. Flatten each scale and concatenate across scales.
        #    Head outputs are unpacked as (B, H, W, C), i.e. assumed
        #    channels-last — TODO confirm against the model definition.
        feats_box, feats_score, feats_mask = [], [], []
        layer_indices = [(0, 1, 2), (3, 4, 5), (6, 7, 8)]

        for idx_score, idx_box, idx_mask in layer_indices:
            score, box, mask = outputs[idx_score], outputs[idx_box], outputs[idx_mask]
            B, H, W, _ = score.shape

            feats_score.append(score.view(B, H*W, -1).permute(0, 2, 1))
            feats_box.append(box.view(B, H*W, -1).permute(0, 2, 1))
            feats_mask.append(mask.view(B, H*W, -1).permute(0, 2, 1))

        x_score = torch.cat(feats_score, dim=2)
        x_box   = torch.cat(feats_box, dim=2)
        x_mask  = torch.cat(feats_mask, dim=2)

        # make_anchors is given channels-first views of the three score maps
        # together with the per-scale strides.
        fake_feats = [outputs[0].permute(0,3,1,2), outputs[3].permute(0,3,1,2), outputs[6].permute(0,3,1,2)]
        anchors, strides = make_anchors(fake_feats, self.stride, 0.5)
        anchors, strides = anchors.transpose(0, 1), strides.transpose(0, 1)

        # Decode the box distribution to xywh boxes scaled by stride.
        pred_dist = self.dfl(x_box)
        pred_bboxes = dist2bbox(pred_dist, anchors.unsqueeze(0), xywh=True, dim=1) * strides
        pred_scores = x_score.sigmoid()

        preds = torch.cat((pred_bboxes, pred_scores, x_mask), dim=1)

        return preds.contiguous(), proto.contiguous()
# ================= 5. 浮点评估主函数 =================
# def run():
#     # --- 配置区域 ---
#     ckpt_path = "/data6/liuziyi/yolov8_qat/qat_project/yolov8_multi_task/calib-checkpoint.ckpt"
#     data_yaml = "/data6/liuziyi/yolov8_qat/qat_project/data2.yaml"
#     device = torch.device("cuda:0")
    
#     print(f"🚀 开始评估浮点模型...")
    
#     # 1. 实例化模型结构
#     # raw_model = YOLOv8Model_float(nc=9).to(device)
#     raw_model = YOLOv8Model_fx(nc=9).to(device)
    
#     # 2. 加载权重
#     if os.path.exists(ckpt_path):
#         print(f"📥 正在加载权重: {ckpt_path}")
#         checkpoint = torch.load(ckpt_path, map_location=device)
#         state_dict = checkpoint.get('state_dict', checkpoint)
#         new_state_dict = {}
#         # 去除前缀
#         for k, v in state_dict.items():
#             key = k.replace('model.', '')
#             new_state_dict[key] = v
        
#         # 加载
#         raw_model.load_state_dict(new_state_dict, strict=False)
#         print("✅ 权重加载完成")
#     else:
#         print(f"❌ 错误：找不到权重文件 {ckpt_path}")
#         return

#     # 3. 包装模型
#     raw_model.eval()
#     model = FloatModelWrapper(raw_model, nc=9)
    
#     # 4. 启动 Ultralytics 验证器
#     args = dict(
#         model='yolov8n-seg.pt', # 占位符
#         data=data_yaml,
#         imgsz=640,
#         batch=4,
#         split='val',
#         device='4',
#         plots=False,
#         save_json=False,
#         half=False
#     )
    
#     validator = SegmentationValidator(args=args)
#     validator.model = model
#     validator.device = device
#     validator.nc = 9
#     validator.names = model.names
    
#     # 开始验证
#     validator(model=model)
# ================= 5. qat后评估主函数 =================
def run():
    """Evaluate the quantization-prepared checkpoint with the Ultralytics segmentation validator.

    Steps: set the BPU march, build the float model, convert it with
    ``prepare_qat_fx``, load the checkpoint, wrap the result in
    ``FloatModelWrapper`` and run ``SegmentationValidator`` on it.

    All paths and settings are hard-coded in the configuration section below.
    Returns ``None``; returns early (after printing an error) when the
    checkpoint file does not exist.
    """
    # --- Configuration ---
    ckpt_path = "/data6/liuziyi/yolov8_qat/qat_project/yolov8_multi_task/qat-best-checkpoint.ckpt"
    data_yaml = "/data6/liuziyi/yolov8_qat/qat_project/data2.yaml"
    device = torch.device("cuda:0")  # keep in sync with ``device`` in ``args`` below

    # The BPU march is set before prepare_qat_fx is called.
    set_march(March.BERNOULLI2)

    print(f"🚀 开始评估 Calib 模型...")

    # 1. Instantiate the plain float architecture (only a scaffold).
    float_model = YOLOv8Model_fx(nc=9)

    # 2. Reproduce the graph transformation used when the checkpoint was
    #    produced, so the module structure matches the checkpoint keys.
    # NOTE(review): the checkpoint is named "qat-best" but the calibration
    # qconfig is used here; confirm whether default_qat_8bit_fake_quant_qconfig
    # is the one that matches training.
    print("🔄 正在构建 FX Graph (prepare_qat_fx)...")
    qat_model = prepare_qat_fx(
        float_model,
        {
            "": default_calib_8bit_fake_quant_qconfig,
        },
    ).to(device)

    # 3. Load weights.
    if not os.path.exists(ckpt_path):
        print(f"❌ 错误：找不到权重文件 {ckpt_path}")
        return

    print(f"📥 正在加载权重: {ckpt_path}")
    checkpoint = torch.load(ckpt_path, map_location=device)

    # Unwrap a nested 'state_dict' entry if present.
    state_dict = checkpoint.get('state_dict', checkpoint)

    try:
        qat_model.load_state_dict(state_dict, strict=True)
        print("✅ 权重完美加载 (Strict=True)")
    except RuntimeError as e:
        print("⚠️ 权重加载有不匹配 (尝试 Strict=False + 前缀处理)...")
        # Show why the strict load failed (first line only; the full message
        # can list every key).
        reason = str(e).splitlines()[0] if str(e) else repr(e)
        print(f"   strict load error: {reason}")

        # Retry with 'model.' removed from the keys.
        # NOTE(review): str.replace removes every occurrence of 'model.',
        # not just a leading prefix — confirm this matches the checkpoint.
        new_state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}
        load_result = qat_model.load_state_dict(new_state_dict, strict=False)

        # strict=False never raises on key mismatches, so report them
        # explicitly; otherwise a partially initialised model would be
        # evaluated without any hint.
        missing = list(getattr(load_result, "missing_keys", []))
        unexpected = list(getattr(load_result, "unexpected_keys", []))
        if missing or unexpected:
            print(f"⚠️ non-strict load: {len(missing)} missing key(s), {len(unexpected)} unexpected key(s)")
            if missing:
                print(f"   missing (first 5): {missing[:5]}")
            if unexpected:
                print(f"   unexpected (first 5): {unexpected[:5]}")
        print("✅ 权重加载完成 (Strict=False)")

    # 4. Wrap the model. The wrapped object is the module returned by
    #    prepare_qat_fx, not the original model class.
    qat_model.eval()
    model = FloatModelWrapper(qat_model, nc=9)

    # 5. Launch the Ultralytics validator.
    args = dict(
        model='yolov8n-seg.pt',  # placeholder
        data=data_yaml,
        imgsz=640,
        batch=4,
        split='val',
        device='0',
        plots=False,
        save_json=False,
        half=False
    )

    validator = SegmentationValidator(args=args)
    validator.model = model
    validator.device = device
    validator.nc = 9
    validator.names = model.names

    # Start validation.
    print("⚡ 开始运行验证...")
    validator(model=model)

# Start the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    run()

