import torch
import sys
import os
import cv2
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import random
from pathlib import Path
from ultralytics.utils.tal import make_anchors, dist2bbox
from ultralytics.nn.modules.block import DFL

try:
    from fx_mode_yolov8 import YOLOv8Model_fx
except ImportError:
    print("❌ 错误：无法导入模型结构。请确保此脚本与 fx_mode_yolov8.py 在同目录。")
    sys.exit(1)

# ==========================================
# 🛠️ 核心工具函数：Letterbox & 坐标还原
# ==========================================

def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=False, scaleFill=False, scaleup=True):
    """Resize ``im`` to fit ``new_shape`` keeping its aspect ratio, then pad the remainder.

    This is the YOLOv5/YOLOv8-style letterbox preprocessing.

    Args:
        im: Image array whose first two dimensions are (height, width).
        new_shape: Target ``(height, width)``; a single int means a square target.
        color: Border value passed to ``cv2.copyMakeBorder``.
        auto: If True, pad only up to the next multiple of 32 (minimum rectangle).
        scaleFill: If True (and ``auto`` is False), stretch to ``new_shape`` with no padding.
        scaleup: If False, the image is never enlarged (scale capped at 1.0).

    Returns:
        ``(padded_image, (ratio_w, ratio_h), (dw, dh))`` where ``dw``/``dh`` are the
        per-side padding amounts (half of the total, possibly fractional).
    """
    src_h, src_w = im.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    dst_h, dst_w = new_shape[0], new_shape[1]

    # A single scale factor keeps the aspect ratio.
    scale = min(dst_h / src_h, dst_w / src_w)
    if not scaleup:
        scale = min(scale, 1.0)

    ratio = (scale, scale)
    resized_wh = (int(round(src_w * scale)), int(round(src_h * scale)))
    pad_w = dst_w - resized_wh[0]
    pad_h = dst_h - resized_wh[1]
    if auto:
        pad_w, pad_h = np.mod(pad_w, 32), np.mod(pad_h, 32)
    elif scaleFill:
        pad_w, pad_h = 0.0, 0.0
        resized_wh = (dst_w, dst_h)
        ratio = (dst_w / src_w, dst_h / src_h)

    # Split the padding between the two sides.
    pad_w /= 2
    pad_h /= 2

    if (src_w, src_h) != resized_wh:
        im = cv2.resize(im, resized_wh, interpolation=cv2.INTER_LINEAR)

    # With an odd total padding the extra pixel goes to the bottom/right side.
    top = int(round(pad_h - 0.1))
    bottom = int(round(pad_h + 0.1))
    left = int(round(pad_w - 0.1))
    right = int(round(pad_w + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return im, ratio, (pad_w, pad_h)

def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    """Map xyxy boxes from letterboxed-image coordinates back to the original image.

    The first four columns of ``coords`` are modified in place.

    Args:
        img1_shape: ``(h, w)`` of the letterboxed (network input) image.
        coords: ``(N, >=4)`` boxes as xyxy in the first four columns.
        img0_shape: Shape of the original image, ``(h, w, ...)``.
        ratio_pad: Optional ``(ratio, (pad_w, pad_h))``; only ``ratio[0]`` is used as the
            gain. When None, the gain and a centred padding are derived from the shapes.

    Returns:
        ``coords`` (same object), clipped to the original image bounds.
    """
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad_x, pad_y = ratio_pad[1][0], ratio_pad[1][1]
    else:
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
        pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
        pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2

    # Undo padding first, then undo the resize.
    coords[:, [0, 2]] -= pad_x
    coords[:, [1, 3]] -= pad_y
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords

def clip_coords(boxes, shape):
    """Clamp xyxy boxes in place to lie inside an image of ``shape`` (``(h, w, ...)``).

    Accepts a ``torch.Tensor`` or a NumPy array; returns None.
    """
    img_h, img_w = shape[0], shape[1]
    if not isinstance(boxes, torch.Tensor):
        # NumPy: clip the x columns and the y columns in two assignments.
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, img_w)
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, img_h)
        return
    # Torch: clamp each column view in place (fancy indexing would copy).
    for col, upper in ((0, img_w), (1, img_h), (2, img_w), (3, img_h)):
        boxes[:, col].clamp_(0, upper)

def generate_colors(num_classes, seed=42):
    """Return ``num_classes`` reproducible BGR colours with every channel in [50, 255].

    A private ``random.Random`` instance is used, so the global ``random`` state is not
    reseeded as a side effect. For the default ``seed`` the colours are exactly those
    produced by seeding the global generator with 42.

    Args:
        num_classes: Number of colours to generate.
        seed: Seed for the private generator.

    Returns:
        List of ``(b, g, r)`` int tuples.
    """
    rng = random.Random(seed)
    # Tuple elements are evaluated left to right, so the draw order is b, g, r.
    return [
        (rng.randint(50, 255), rng.randint(50, 255), rng.randint(50, 255))
        for _ in range(num_classes)
    ]

def xywh2xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2) along the last axis.

    Works on a ``torch.Tensor`` or NumPy array. The input is not modified, and any
    columns beyond the first four are copied through unchanged.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    half_w = x[..., 2] / 2
    half_h = x[..., 3] / 2
    out[..., 0] = x[..., 0] - half_w
    out[..., 1] = x[..., 1] - half_h
    out[..., 2] = x[..., 0] + half_w
    out[..., 3] = x[..., 1] + half_h
    return out

# ==========================================
# ⚙️ 核心逻辑：NMS & Mask Processing
# ==========================================

def simple_nms_with_mask(prediction, conf_thres=0.25, iou_thres=0.45, nc=9, nm=32):
    """Class-aware NMS (offset trick) for one image of YOLOv8-seg style predictions.

    Args:
        prediction: Tensor where ``prediction[0]`` is ``(num_anchors, 4 + nc + k)`` with
            columns ``[cx, cy, w, h, class_scores..., mask_coefs...]``. Only the first
            batch element is processed.
        conf_thres: Anchors whose best class score is not strictly greater are dropped.
        iou_thres: IoU threshold passed to ``torchvision.ops.nms``.
        nc: Number of class-score columns.
        nm: Kept for API compatibility. The number ``k`` of mask-coefficient columns
            is taken from ``prediction`` (all columns after the class scores).

    Returns:
        Tensor ``(N, 6 + k)`` laid out as ``[x1, y1, x2, y2, conf, class_id, mask_coefs...]``,
        ordered by decreasing confidence (the order ``torchvision.ops.nms`` returns).
        When no anchor passes ``conf_thres``, an empty ``(0, 6 + k)`` tensor is returned,
        so the result is always a tensor and ``len(result) == 0`` signals "no detections".
    """
    pred = prediction[0]

    # Best class per anchor; keep only anchors above the confidence threshold.
    conf, class_id = pred[:, 4:4 + nc].max(1)
    keep = conf > conf_thres
    candidates = pred[keep]
    conf = conf[keep]
    class_id = class_id[keep]
    mask_coef = candidates[:, 4 + nc:]

    if candidates.shape[0] == 0:
        # Empty but correctly shaped, on the same device/dtype as the input.
        return pred.new_zeros((0, 6 + mask_coef.shape[1]))

    # Convert only the surviving boxes from (cx, cy, w, h) to (x1, y1, x2, y2).
    box_xyxy = xywh2xyxy(candidates[:, :4])

    # Offset trick: shift each class into its own coordinate range so that a single
    # NMS call never suppresses boxes of different classes against each other.
    max_wh = 7680
    boxes_for_nms = box_xyxy + (class_id * max_wh).unsqueeze(1)
    keep_indices = torchvision.ops.nms(boxes_for_nms, conf, iou_thres)

    return torch.cat([
        box_xyxy[keep_indices],                       # [N, 4]
        conf[keep_indices].unsqueeze(1),              # [N, 1]
        class_id[keep_indices].float().unsqueeze(1),  # [N, 1]
        mask_coef[keep_indices],                      # [N, k]
    ], dim=1)

def process_mask_native(protos, masks_in, bboxes, shape, pad_info, pred_cls=None,
                        img_size=(640, 640), no_crop_classes=(2, 8)):
    """Build binary instance masks at original-image resolution from letterboxed predictions.

    Args:
        protos: Prototype masks, ``(C, mh, mw)``.
        masks_in: Per-detection mask coefficients, ``(N, C)``.
        bboxes: ``(N, 4)`` xyxy boxes in original-image pixels, used to crop each mask.
        shape: Original image ``(height, width)``.
        pad_info: ``(ratio, (dw, dh))`` as returned by ``letterbox``. Only the per-side
            padding ``(dw, dh)`` is used.
        pred_cls: Optional ``(N,)`` class ids. Without it every mask is cropped.
        img_size: Letterboxed network input ``(height, width)``, or an int for a square
            input. Defaults to the 640x640 input used by this script.
        no_crop_classes: Class ids whose masks are NOT cropped to their box. The default
            ``(2, 8)`` are 'Driving_area' and 'mask' in ``FloatModelWrapper.names``.
            These are region classes whose masks should survive an inaccurate box.

    Returns:
        Bool tensor ``(N, height, width)`` (mask probability > 0.5).
    """
    c, mh, mw = protos.shape  # CHW
    ih, iw = shape
    if isinstance(img_size, int):
        img_size = (img_size, img_size)
    in_h, in_w = img_size

    # 1. Linear combination of the prototypes -> per-detection probabilities (N, mh, mw).
    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)

    # 2. Upsample to the letterboxed input size (padding still included).
    masks = F.interpolate(masks[None], size=(in_h, in_w), mode='bilinear', align_corners=False)[0]

    # 3. Remove the letterbox padding. letterbox() pads round(d - 0.1) pixels on the
    #    top/left and round(d + 0.1) on the bottom/right, so the content ends at
    #    size - round(d + 0.1). Writing the end as round(size - d + 0.1) would keep
    #    one padded row/column whenever the total padding is odd.
    _, (dw, dh) = pad_info
    top = int(round(dh - 0.1))
    bottom = in_h - int(round(dh + 0.1))
    left = int(round(dw - 0.1))
    right = in_w - int(round(dw + 0.1))
    masks = masks[:, top:bottom, left:right]

    # 4. Resize the content region to the original image size.
    masks = F.interpolate(masks[None], size=(ih, iw), mode='bilinear', align_corners=False)[0]

    # 5. Crop each mask to its box, except for the classes in no_crop_classes.
    skip_crop = set(no_crop_classes)
    for i, box in enumerate(bboxes):
        cls_id = int(pred_cls[i]) if pred_cls is not None else -1
        if cls_id in skip_crop:
            continue

        x1, y1, x2, y2 = box.cpu().numpy().astype(int)
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(iw, x2), min(ih, y2)

        crop_mask = torch.zeros_like(masks[i])
        crop_mask[y1:y2, x1:x2] = 1
        masks[i] = masks[i] * crop_mask

    return masks.gt(0.5)

# ==========================================
# FloatModelWrapper (保持不变)
# ==========================================
class FloatModelWrapper(nn.Module):
    """Wrap the FX-mode YOLOv8 model and decode its raw head outputs.

    ``forward`` returns ``(preds, proto)``. ``preds`` stacks decoded xywh boxes,
    sigmoid class scores and raw mask coefficients along dim 1. ``proto`` is the
    wrapped model's last output.
    """

    def __init__(self, original_model, nc=9):
        super().__init__()
        self.model = original_model
        # Class id -> display name, used for drawing labels.
        self.names = {
            0: 'Barrier', 1: 'Bollard', 2: 'Driving_area', 
            3: 'N_Obstacle', 4: 'Obstacle', 5: 'Parking_lock', 
            6: 'Road_cone', 7: 'curb', 8: 'mask'
        }
        # Plain tensor attribute (not a registered buffer); converted explicitly in _apply.
        self.stride = torch.tensor([8., 16., 32.]) 
        self.nc = nc
        self.nm = 32  # mask coefficients per anchor
        self.reg_max = 16  # DFL bins per box side
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def fuse(self, verbose=False):
        """No-op; returns ``self`` (presumably for ultralytics-style callers)."""
        return self

    def info(self, verbose=False, img_size=None):
        """No-op placeholder (presumably for ultralytics-style callers)."""
        pass

    def _apply(self, fn, *args, **kwargs):
        """Apply ``fn`` to parameters/buffers and to the plain ``stride`` tensor.

        ``super()._apply`` already recurses into the registered submodules
        (``self.model`` and ``self.dfl``), so they must not be converted a second time.
        ``stride`` is not a buffer, so it is converted explicitly. Extra positional or
        keyword arguments (e.g. ``recurse`` in newer PyTorch) are forwarded unchanged.
        """
        self = super()._apply(fn, *args, **kwargs)
        self.stride = fn(self.stride)
        return self

    def forward(self, x, *args, **kwargs):
        """Run the wrapped model and decode the per-level outputs into YOLOv8-seg predictions.

        Args:
            x: Input batch, presumably ``(B, 3, H, W)``. If its maximum exceeds 5.0 it is
               treated as 0-255 data and rescaled to [0, 1] (heuristic).
            *args, **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``(preds, proto)``. ``preds`` is ``(B, 4 + score_channels + mask_channels, A)``
            (presumably ``4 + nc + nm``) over all ``A`` anchors of the three levels. Its
            channels are ``[cx, cy, w, h]`` scaled by the per-anchor stride, then sigmoid
            class scores, then raw mask coefficients. ``proto`` is the model's last output.
        """
        if x.max() > 5.0:
            x = x.float() / 255.0
            
        outputs = self.model(x)
        proto = outputs[-1]

        feats_box, feats_score, feats_mask = [], [], []
        # Assumes outputs 0-8 are (score, box, mask) triples, one per level, in the same
        # level order as self.stride, each laid out (B, H, W, C).
        layer_indices = [(0, 1, 2), (3, 4, 5), (6, 7, 8)]

        for idx_score, idx_box, idx_mask in layer_indices:
            score = outputs[idx_score] 
            box   = outputs[idx_box]   
            mask  = outputs[idx_mask]  
            B, H, W, _ = score.shape
            
            # (B, H, W, C) -> (B, C, H*W)
            feats_score.append(score.view(B, H*W, -1).permute(0, 2, 1)) 
            feats_box.append(box.view(B, H*W, -1).permute(0, 2, 1))     
            feats_mask.append(mask.view(B, H*W, -1).permute(0, 2, 1))   

        # Concatenate all levels along the anchor axis.
        x_score = torch.cat(feats_score, dim=2) 
        x_box   = torch.cat(feats_box, dim=2)   
        x_mask  = torch.cat(feats_mask, dim=2)  

        # make_anchors only needs NCHW feature maps for their spatial sizes.
        fake_feats = [outputs[0].permute(0,3,1,2), outputs[3].permute(0,3,1,2), outputs[6].permute(0,3,1,2)]
        anchors, strides = make_anchors(fake_feats, self.stride, 0.5)
        anchors = anchors.transpose(0, 1) 
        strides = strides.transpose(0, 1) 

        # DFL distribution -> side distances -> xywh boxes, scaled by the level stride.
        pred_dist = self.dfl(x_box)
        pred_bboxes = dist2bbox(pred_dist, anchors.unsqueeze(0), xywh=True, dim=1) 
        pred_bboxes = pred_bboxes * strides 
        
        pred_scores = x_score.sigmoid()
        preds = torch.cat((pred_bboxes, pred_scores, x_mask), dim=1) 
        
        return preds.contiguous(), proto.contiguous()

# ==========================================
# 🚀 执行批量预测
# ==========================================
def _load_float_model(ckpt_path, device):
    """Load the float checkpoint into YOLOv8Model_fx and wrap it with FloatModelWrapper.

    Returns the wrapper in eval mode on ``device``, or None if ``ckpt_path`` is missing.
    """
    if not os.path.exists(ckpt_path):
        print(f"❌ 权重不存在: {ckpt_path}")
        return None

    raw_model = YOLOv8Model_fx(nc=9).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=device)
    state_dict = checkpoint.get('state_dict', checkpoint)
    # NOTE(review): str.replace strips every 'model.' substring, not only a leading
    # prefix — confirm this matches the key names YOLOv8Model_fx expects.
    new_state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}
    # strict=False tolerates mismatched keys; report them so a bad checkpoint is noticed.
    incompatible = raw_model.load_state_dict(new_state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"⚠️ 权重键不匹配: 缺失 {len(incompatible.missing_keys)} 个, "
              f"多余 {len(incompatible.unexpected_keys)} 个")
    raw_model.eval()

    model = FloatModelWrapper(raw_model, nc=9).to(device)
    model.eval()
    return model


def _preprocess(img_raw, device):
    """Letterbox a BGR image to 640x640 and return ``(tensor, ratio, (dw, dh))``.

    ``tensor`` is ``(1, C, 640, 640)`` RGB float data in [0, 1] on ``device``.
    ``ratio`` and ``(dw, dh)`` are exactly what ``letterbox`` returned.
    """
    img_in, ratio, (dw, dh) = letterbox(img_raw, new_shape=(640, 640), auto=False)
    chw = np.ascontiguousarray(img_in[:, :, ::-1].transpose(2, 0, 1))  # BGR -> RGB, HWC -> CHW
    img_tensor = torch.from_numpy(chw).to(device).float() / 255.0
    return img_tensor.unsqueeze(0), ratio, (dw, dh)


def _draw_detections(vis_img, boxes, confs, classes, masks, colors, names):
    """Blend masks (alpha 0.4) into ``vis_img`` in place, then draw boxes and labels."""
    overlay = vis_img.copy()
    for cls_val, mask in zip(classes, masks):
        overlay[mask.cpu().numpy().astype(bool)] = colors[int(cls_val)]
    cv2.addWeighted(overlay, 0.4, vis_img, 0.6, 0, vis_img)

    for box, conf_val, cls_val in zip(boxes, confs, classes):
        x1, y1, x2, y2 = map(int, box.cpu().numpy())
        cls_id = int(cls_val)
        color = colors[cls_id]
        cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)

        label_text = f"{names.get(cls_id, str(cls_id))} {float(conf_val):.2f}"
        text_w, text_h = cv2.getTextSize(label_text, 0, fontScale=0.6, thickness=1)[0]
        # Put the label above the box when it fits, otherwise just inside the top edge,
        # so boxes touching the top of the image still get a visible label.
        if y1 - text_h - 5 >= 0:
            bg_y2, text_y = y1 - text_h - 5, y1 - 2
        else:
            bg_y2, text_y = y1 + text_h + 5, y1 + text_h + 2
        cv2.rectangle(vis_img, (x1, y1), (x1 + text_w, bg_y2), color, -1)
        cv2.putText(vis_img, label_text, (x1, text_y), 0, 0.6, (255, 255, 255), 1, cv2.LINE_AA)


def batch_predict_and_visualize(
    ckpt_path="/data6/liuziyi/yolov8_qat/qat_project/yolov8_multi_task/float-checkpoint.ckpt",
    img_dir="/data6/liuziyi/yolov8_qat/qat_project/cali_4_camera",
    output_dir="/data6/liuziyi/yolov8_qat/qat_project/yolov8_multi_task/visualize_pic/1",
    conf_thres=0.15,
    iou_thres=0.45,
    device_str="cuda:0",
):
    """Run the float model on every image in ``img_dir`` and save visualizations.

    Steps for each image:
    1. Letterbox it to 640x640 and run ``FloatModelWrapper``.
    2. Apply class-aware NMS.
    3. Draw masks, boxes and labels on the original image.
    4. Write the result to ``output_dir`` as ``<stem>_vis.jpg``.

    Every argument has a default, so the script can still run with no arguments.
    """
    os.makedirs(output_dir, exist_ok=True)
    device = torch.device(device_str)
    print(f"🚀 启动批量可视化 (Letterbox + SmartCrop)...")

    # 1. Load the model.
    print("📥 加载模型中...")
    model = _load_float_model(ckpt_path, device)
    if model is None:
        return
    colors = generate_colors(model.nc)

    # 2. Collect images (sorted so the processing order is deterministic).
    if not os.path.isdir(img_dir):
        print(f"❌ 图片目录不存在: {img_dir}")
        return
    valid_exts = ('.jpg', '.jpeg', '.png', '.bmp')
    img_files = sorted(f for f in os.listdir(img_dir) if f.lower().endswith(valid_exts))
    print(f"📸 找到 {len(img_files)} 张图片，开始处理...")

    for idx, img_file in enumerate(img_files, start=1):
        img_raw = cv2.imread(os.path.join(img_dir, img_file))
        if img_raw is None:  # unreadable file
            continue
        h_orig, w_orig = img_raw.shape[:2]

        # --- Preprocess (letterbox) + inference ---
        img_tensor, ratio, (dw, dh) = _preprocess(img_raw, device)
        with torch.no_grad():
            preds, proto = model(img_tensor)

        # --- NMS on per-anchor rows: (B, C, A) -> (B, A, C) ---
        det_res = simple_nms_with_mask(preds.permute(0, 2, 1), conf_thres, iou_thres, nc=model.nc, nm=model.nm)

        # Draw on the original image, not the letterboxed one.
        vis_img = img_raw.copy()
        if len(det_res) > 0:
            pred_boxes = det_res[:, :4]
            pred_conf = det_res[:, 4]
            pred_cls = det_res[:, 5]
            pred_masks_coef = det_res[:, 6:]

            # Map boxes from 640x640 letterbox space back to the original image using the
            # integer top/left padding letterbox() actually applied (int(round(d - 0.1))),
            # not the fractional (dw, dh).
            real_pad = (int(round(dw - 0.1)), int(round(dh - 0.1)))
            pred_boxes = scale_coords(
                img_tensor.shape[2:],
                pred_boxes,
                img_raw.shape,
                ratio_pad=(ratio, real_pad),
            ).round()

            # Optional manual vertical correction in pixels. Image y points down, so a
            # positive value moves boxes down. 0 disables the correction.
            manual_y_offset = 0
            if manual_y_offset:
                pred_boxes[:, [1, 3]] += manual_y_offset
                clip_coords(pred_boxes, img_raw.shape)

            # Masks: strip letterbox padding, resize to original size, crop to boxes.
            masks = process_mask_native(
                proto[0], pred_masks_coef, pred_boxes, (h_orig, w_orig), (ratio, (dw, dh)), pred_cls=pred_cls
            )

            _draw_detections(vis_img, pred_boxes, pred_conf, pred_cls, masks, colors, model.names)

        # Path.stem drops any extension, so every input yields exactly "<stem>_vis.jpg".
        save_file = os.path.join(output_dir, f"{Path(img_file).stem}_vis.jpg")
        if not cv2.imwrite(save_file, vis_img):
            print(f"\n⚠️ 保存失败: {save_file}")

        sys.stdout.write(f"\r✅ 处理进度: [{idx}/{len(img_files)}] - 保存至 {save_file}")
        sys.stdout.flush()

    print(f"\n🎉 全部完成！请查看: {output_dir}")

if __name__ == "__main__":
    # Script entry point: run batch inference + visualization with the default configuration.
    batch_predict_and_visualize()

