import argparse
import os

from ultralytics.data.build import build_dataloader
from ultralytics.data.utils import check_det_dataset

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from pathlib import Path
import random

import math
import copy
import torch
import torch.nn as nn
import torch.quantization
import torchvision.transforms as transforms
from torch import Tensor
from torch.quantization import DeQuantStub
from torchvision.datasets import CIFAR10
from torchvision.models.mobilenetv2 import (
    MobileNet_V2_Weights
)
from torchvision._internally_replaced_utils import load_state_dict_from_url
from torch.utils import data
from horizon_plugin_pytorch.march import March
from horizon_plugin_pytorch.quantization import (
    QuantStub,
    set_fake_quantize,
    FakeQuantState,
    check_model,
    compile_model,
    perf_model,
    visualize_model,
)
from horizon_plugin_pytorch.functional import centered_yuv2rgb
from typing import Optional, Callable, List, Tuple
from horizon_plugin_pytorch.nn.quantized import FloatFunctional
import yaml
from torch.utils.data import Dataset, DataLoader
import cv2
from pathlib import Path
import json
import numpy as np
from multiprocessing.pool import ThreadPool
from itertools import repeat
from tqdm import tqdm
import logging.config
import hashlib
from yolov8_utils.loss_utils import SegLoss
from torch.cuda.amp import GradScaler, autocast

import torch.nn as nn
from ultralytics.models.yolo.segment import SegmentationValidator
from ultralytics.utils.tal import make_anchors, dist2bbox
from ultralytics.nn.modules.block import DFL
# import ultralytics.utils.checks
# ultralytics.utils.checks.check_font = dummy_check_font
# export LD_LIBRARY_PATH=/root/miniconda3/envs/qat_env/lib:$LD_LIBRARY_PATH
class InversePermuteWrapper(nn.Module):
    """Wrap a model and permute its 4-D outputs with ``(0, 3, 1, 2)``.

    Intended to turn channel-last (N, H, W, C) outputs back into
    channel-first (N, C, H, W); this assumes the wrapped model emits NHWC
    tensors (TODO confirm for each wrapped model). Outputs of any other rank
    are passed through unchanged. The result is always a tuple, in the same
    order as the wrapped model's outputs.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        raw_outputs = self.model(x)
        # (0, 3, 1, 2): batch stays first, the last axis moves to position 1.
        return tuple(
            out.permute(0, 3, 1, 2) if out.dim() == 4 else out
            for out in raw_outputs
        )

class FloatModelWrapper(nn.Module):
    """Let the Ultralytics segmentation validator drive a split-head YOLOv8-seg model.

    The wrapper carries the attributes/stubs set below (``names``, ``stride``,
    ``yaml``, ``fuse``, ``info``) so the Ultralytics validation code path can
    treat it like a native model. In ``forward`` it decodes the wrapped model's
    raw outputs into ``(preds, proto)``.

    Expected raw outputs: indices 0-8 are three ``(score, box, mask)`` triples,
    one per stride in ``self.stride``; the last output is the prototype tensor.
    Per-level tensors are unpacked as ``B, H, W, _`` below, i.e. they are
    assumed to be channel-last (TODO confirm against the exported head).

    Args:
        original_model: the model to wrap (e.g. the ``convert_fx`` output).
        nc: number of classes.
        names: optional ``{class_id: name}`` mapping. When omitted, the
            9-class mapping in ``DEFAULT_NAMES`` is used.
    """

    # Class names of this project's data.yaml; override with ``names=``.
    DEFAULT_NAMES = {
        0: 'Barrier', 1: 'Bollard', 2: 'Driving_area',
        3: 'N_Obstacle', 4: 'Obstacle', 5: 'Parking_lock',
        6: 'Road_cone', 7: 'curb', 8: 'mask',
    }

    def __init__(self, original_model, nc=9, names=None):
        super().__init__()
        self.model = original_model
        self.names = dict(names) if names is not None else dict(self.DEFAULT_NAMES)
        # Plain tensor attribute (not a registered buffer); converted by _apply.
        self.stride = torch.tensor([8., 16., 32.])
        self.nc = nc
        self.nm = 32        # number of mask coefficients / proto channels
        self.reg_max = 16   # DFL bins per box side
        self.yaml = {
            'channels': 3,
            'nc': nc,
            'names': self.names,
            'stride': self.stride,
        }
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def fuse(self, verbose=False):
        """No-op: the wrapped model is already in its final form."""
        return self

    def info(self, verbose=False, img_size=None):
        """No-op stub so callers asking for model info do not fail."""

    def _apply(self, fn, *args, **kwargs):
        # nn.Module._apply already recurses into every registered submodule
        # (self.model and self.dfl), so they must not be converted again here.
        # self.stride is a plain tensor attribute that the base class does not
        # see, so convert it by hand (as Ultralytics' BaseModel._apply does).
        # *args/**kwargs forward newer torch arguments such as ``recurse``.
        self = super()._apply(fn, *args, **kwargs)
        self.stride = fn(self.stride)
        return self

    def forward(self, x, *args, **kwargs):
        """Run the wrapped model and decode its raw head outputs.

        Args:
            x: image batch indexed as ``(B, C, H, W)``. If its maximum exceeds
                5.0 it is treated as 0-255 data and divided by 255.
            *args, **kwargs: accepted for call compatibility and ignored.

        Returns:
            ``(preds, proto)``. ``preds`` concatenates along dim 1 the decoded
            xywh boxes (scaled by stride), the sigmoid class scores and the raw
            mask coefficients for all anchors of the three levels. ``proto``
            is the last model output, reshaped when it is not 4-D.
        """
        # 1. Safety net: rescale inputs that look like 0-255 pixel values.
        if x.max() > 5.0:
            x = x.float() / 255.0

        # 2. Inference (typically self.model is the fixed-point convert_fx model).
        outputs = self.model(x)
        proto = outputs[-1]

        # 3. Best-effort repair of a non-4-D proto into
        #    (B, 32, H_in // 4, W_in // 4). If the element count does not fit,
        #    proto is left untouched so the mismatch surfaces downstream.
        if proto.dim() != 4:
            B, _, H_in, W_in = x.shape
            target_h = H_in // 4
            target_w = W_in // 4
            if proto.dim() == 2:
                proto = proto.unsqueeze(0)
            if proto.shape[-1] == 32:
                proto = proto.permute(0, 2, 1)
            try:
                # reshape works for any memory layout (e.g. after permute).
                proto = proto.reshape(B, 32, target_h, target_w)
            except RuntimeError:
                pass  # deliberate: let the caller report the bad shape

        # 4. Flatten each level to (B, C, H*W) and concatenate over anchors.
        feats_box, feats_score, feats_mask = [], [], []
        for idx_score, idx_box, idx_mask in ((0, 1, 2), (3, 4, 5), (6, 7, 8)):
            score, box, mask = outputs[idx_score], outputs[idx_box], outputs[idx_mask]
            B, H, W, _ = score.shape
            feats_score.append(score.reshape(B, H * W, -1).permute(0, 2, 1))
            feats_box.append(box.reshape(B, H * W, -1).permute(0, 2, 1))
            feats_mask.append(mask.reshape(B, H * W, -1).permute(0, 2, 1))

        x_score = torch.cat(feats_score, dim=2)
        x_box = torch.cat(feats_box, dim=2)
        x_mask = torch.cat(feats_mask, dim=2)

        # make_anchors reads the grid size from dims 2: of each feature, so
        # give it channel-first views of one tensor per level.
        fake_feats = [outputs[i].permute(0, 3, 1, 2) for i in (0, 3, 6)]
        anchors, strides = make_anchors(fake_feats, self.stride, 0.5)
        anchors, strides = anchors.transpose(0, 1), strides.transpose(0, 1)

        pred_dist = self.dfl(x_box)
        pred_bboxes = dist2bbox(pred_dist, anchors.unsqueeze(0), xywh=True, dim=1) * strides
        pred_scores = x_score.sigmoid()

        preds = torch.cat((pred_bboxes, pred_scores, x_mask), dim=1)
        return preds.contiguous(), proto.contiguous()

##############################################################################
# Next, we define the model convert pipeline to generate model for each stage.
##############################################################################
TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}"  # bar layout string for tqdm
LOGGING_NAME = "ultralytics"
# Logger looked up by the name "ultralytics" (presumably shared with the
# Ultralytics package's own logger).
LOGGER = logging.getLogger(LOGGING_NAME)
# Integer value of the LOCAL_RANK environment variable (conventionally set by
# distributed launchers); -1 when the variable is not set.
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", -1))

def load_pretrain(model: nn.Module, model_path: str):
    """Load torchvision's ImageNet MobileNetV2 weights into ``model``, minus the classifier.

    The checkpoint at ``MobileNet_V2_Weights.IMAGENET1K_V1.url`` is downloaded
    (or reused) with ``model_path`` as the cache directory. Every key that
    contains ``"classifier"`` is removed before a non-strict load.

    Args:
        model: module receiving the weights; modified in place.
        model_path: download/cache directory for the checkpoint.

    Returns:
        The same ``model`` instance.

    Raises:
        AssertionError: unless the load reports exactly two missing keys and no
            unexpected keys (these checks are skipped under ``python -O``).
    """
    checkpoint_url = MobileNet_V2_Weights.IMAGENET1K_V1.url
    state_dict = load_state_dict_from_url(
        checkpoint_url, model_dir=model_path, progress=True
    )

    # Gather the keys first: popping while iterating the dict would raise.
    classifier_keys = [key for key in state_dict if "classifier" in key]
    for key in classifier_keys:
        state_dict.pop(key)

    load_result = model.load_state_dict(state_dict, strict=False)
    missing_keys, unexpected_keys = load_result
    # Hard-coded expectations; presumably they match the target model's
    # replacement head (TODO confirm).
    assert len(missing_keys) == 2
    assert len(unexpected_keys) == 0

    return model

#-------------------------------------------------------------------------------------------------
def init_yolo_weights(model):
    """Apply Ultralytics-style from-scratch initialization to ``model`` in place.

    Steps:

    * every module whose exact type is ``nn.BatchNorm2d`` gets ``eps=1e-3`` and
      ``momentum=0.03``;
    * activations of the listed types are switched to ``inplace=True``;
    * inside ``model.head.segment_head``, every ``nn.Conv2d`` whose
      ``out_channels`` equals ``segment_head.nc`` has its bias filled with
      ``-log((1 - 0.01) / 0.01)`` (about -4.595, i.e. sigmoid about 0.01). This
      keeps the early classification loss from exploding and swamping small
      objects.

    The bias is filled in place, so the existing ``nn.Parameter`` object is
    kept and an optimizer that already holds it keeps training it. Matching
    convs without a bias are skipped. The head step is best effort: failures
    are printed, not raised.

    Args:
        model: the network to initialize.
    """
    print("🛠️ Executing YOLO-style weight initialization...")

    # 1. Generic per-module settings (exact type matches, as in Ultralytics).
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # keep PyTorch's default conv initialization
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU):
            m.inplace = True

    # 2. Classification-bias prior for the segment head: p = 0.01.
    prior_bias = -math.log((1 - 0.01) / 0.01)

    try:
        if hasattr(model, 'head') and hasattr(model.head, 'segment_head'):
            seg_head = model.head.segment_head
            nc = seg_head.nc

            init_count = 0
            # The head may be deeply nested, so search all submodules.
            for name, m in seg_head.named_modules():
                # Heuristic: in anchor-free YOLOv8 the classification conv
                # outputs exactly nc channels.
                if not (isinstance(m, nn.Conv2d) and m.out_channels == nc):
                    continue
                if m.bias is None:
                    print(f"  ⚠️ Skipping conv without bias: {name}")
                    continue
                # In-place fill keeps the same Parameter object.
                with torch.no_grad():
                    m.bias.fill_(prior_bias)
                m.bias.requires_grad_(True)
                init_count += 1
                print(f"  ✅ Initialized bias for layer: {name}")

            if init_count == 0:
                print("  ⚠️ Warning: No classification layer found to initialize bias!")
            else:
                print(f"  ✅ Successfully initialized {init_count} detection layers.")
        else:
            print("  ⚠️ Warning: model.head.segment_head not found; head bias not initialized.")

    except Exception as e:  # best effort: report and continue
        print(f"  ⚠️ Initialization skipped or failed: {e}")
#----------------------------------------------------------------------------------------------


def calibrate(
    calib_model,
    data_path,
    model_path,
    calib_batch_size,
    eval_batch_size,
    device,
    num_examples=float("inf"),
    march=March.BERNOULLI2,

):
    """Run post-training calibration and save the calibrated state dict.

    ``calib_model`` is put in eval mode with fake-quant set to CALIBRATION.
    Shuffled training images from the dataset described by ``data_path`` are
    fed through it until at least ``num_examples`` images have been seen or
    the loader is exhausted. The state dict is then written to
    ``<model_path>/calib-checkpoint.ckpt``.

    Args:
        calib_model: the prepared (fake-quant) model; modified in place.
        data_path: dataset config passed to ``get_dataset``.
        model_path: checkpoint directory; created if it does not exist.
        calib_batch_size: batch size of the calibration loader.
        eval_batch_size: unused; kept for interface compatibility (the
            post-calibration evaluation is disabled).
        device: device the input images are moved to.
        num_examples: stop once this many images were seen (default: all).
        march: unused; kept for interface compatibility.

    Returns:
        ``calib_model``, still in CALIBRATION fake-quant state. Callers must
        call ``eval()`` and set VALIDATION state before evaluating it.
    """
    # Calibration needs eval mode so BatchNorm uses its running statistics.
    calib_model.eval()
    set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)

    trainset, _, _, _ = get_dataset(data_path)
    train_dataset = YOLODataset(trainset, img_size=640)
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=calib_batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        collate_fn=collate_fn,
    )
    # A validation loader used to be built here too. Nothing consumed it
    # (evaluation is disabled), so it is no longer constructed.

    seen = 0
    with torch.no_grad():
        for batch in train_data_loader:
            image = batch["img"].to(device)
            calib_model(image)
            print(".", end="", flush=True)  # progress dot per batch
            seen += image.size(0)
            if seen >= num_examples:
                break
        print()

    # torch.save does not create missing directories.
    os.makedirs(model_path, exist_ok=True)
    torch.save(
        calib_model.state_dict(),
        os.path.join(model_path, "calib-checkpoint.ckpt"),
    )

    return calib_model


# def int_infer(
#     quantized_model,
#     data_path,
#     model_path,
#     eval_batch_size,
#     device,
#     march=March.BERNOULLI2,
# ):

#     _, eval_data_loader = prepare_data_loaders(
#         data_path, eval_batch_size, eval_batch_size
#     )

#     top1, top5 = evaluate(
#         quantized_model,
#         eval_data_loader,
#         device,
#     )
#     print(
#         "Quantized: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
#             top1.avg, top5.avg
#         )
#     )

#     return quantized_model

# def int_infer(
#     quantized_model,
#     data_path,       # 这里传入的是 yaml 路径
#     model_path,      # 这个参数在验证里暂时没用上，保留占位
#     eval_batch_size,
#     device,
#     march=March.BERNOULLI2,
# ):
#     print(f"\n [Int Infer] 开始定点模型推理验证...")
#     print(f"   Model Structure: Quantized GraphModule (FX)")
#     print(f"   Data Config: {data_path}")

#     quantized_model.eval()
#     model = FloatModelWrapper(quantized_model, nc=9) # 确保 nc 正确

#     device_str = str(device).split(':')[-1] if 'cuda' in str(device) else 'cpu'
    
#     args = dict(
#         model='yolov8n-seg.pt', # 这里的 model 只是个占位符，实际用的是下面赋值的 wrapper
#         data=data_path,         # data.yaml 的路径
#         imgsz=640,
#         batch=eval_batch_size,
#         split='val',
#         device=device_str,
#         plots=False,
#         save_json=False,
#         half=False # 定点模型不需要 fp16
#     )

#     validator = SegmentationValidator(args=args)
#     validator.model = model
#     validator.device = device
#     validator.nc = 9
#     validator.names = model.names

#     print(" 启动 Ultralytics Validator...")
#     stats = validator(model=model)

#     return quantized_model

def int_infer(
    quantized_model,
    data_path,
    model_path,
    eval_batch_size,
    device,  # ignored: CPU is always used (see docstring)
    march=March.BERNOULLI2,
):
    """Validate the fixed-point model with Ultralytics' SegmentationValidator on CPU.

    Simulation is forced onto CPU to avoid the cuBLAS compatibility problem
    noted for the RTX 4090, so ``device`` is ignored and ``quantized_model``
    is moved to CPU in place. The model is wrapped in ``FloatModelWrapper``
    so the validator receives decoded ``(preds, proto)`` outputs.

    Args:
        quantized_model: the converted (fixed-point) model.
        data_path: dataset YAML path handed to the validator.
        model_path: unused; kept for interface compatibility.
        eval_batch_size: validation batch size.
        device: ignored.
        march: unused; kept for interface compatibility.

    Returns:
        ``quantized_model`` (now on CPU and in eval mode). The validator's
        return value is discarded.
    """
    for line in (
        "\n [Int Infer] 开始定点模型推理验证...",
        "   Model Structure: Quantized GraphModule (FX)",
        f"   Data Config: {data_path}",
        " [注意] 强制切换到 CPU 进行定点仿真 (避免 RTX4090 cuBLAS 兼容性问题)",
    ):
        print(line)

    cpu = torch.device("cpu")
    quantized_model = quantized_model.to(cpu)
    quantized_model.eval()

    wrapper = FloatModelWrapper(quantized_model, nc=9)

    validator_args = {
        # Placeholder only: the wrapper is passed explicitly when calling.
        "model": "yolov8n-seg.pt",
        "data": data_path,
        "imgsz": 640,
        "batch": eval_batch_size,
        "split": "val",
        "device": "cpu",  # make Ultralytics load data on CPU as well
        "plots": False,
        "save_json": False,
        "half": False,  # fixed-point model: no fp16
    }
    validator = SegmentationValidator(args=validator_args)
    # Override the validator's internal state so it targets the CPU wrapper.
    validator.model = wrapper
    validator.device = cpu
    validator.nc = 9
    validator.names = wrapper.names

    print(" 启动 Ultralytics Validator (CPU Mode)...")
    validator(model=wrapper)

    return quantized_model


# def compile(
#     quantized_model,
#     data_path,
#     model_path,
#     compile_opt=0,
#     march=March.BAYES,
# ):
#     # It is recommended to do compile on cpu, because associated interfaces
#     # do not fully support cuda.
#     device = torch.device("cpu")
#     quantized_model = quantized_model.to(device)

#     _, eval_data_loader = prepare_data_loaders(data_path, 1, 1)

#     # We can generate random input data (in proper shape) for
#     # tracing and compiling and so on.
#     # Use real data in `perf_model` will get more accurate perf result.
#     example_input = next(iter(eval_data_loader))[0]

#     script_model = torch.jit.trace(quantized_model, example_input)
#     torch.jit.save(script_model, os.path.join(model_path, "int_model.pt"))

#     check_model(script_model.cpu(), [example_input], advice=1)

#     compile_model(
#         script_model,
#         [example_input],
#         hbm=os.path.join(model_path, "model.hbm"),
#         input_source="pyramid",
#         opt=compile_opt,
#     )

#     perf_model(
#         script_model,
#         [example_input],
#         out_dir=os.path.join(model_path, "perf_out"),
#         input_source="pyramid",
#         opt=compile_opt,
#         layer_details=True,
#     )

#     visualize_model(
#         script_model,
#         [example_input],
#         save_path=os.path.join(model_path, "model.svg"),
#         show=False,
#     )

#     return script_model
def compile(
    quantized_model,
    data_path,
    model_path,
    compile_opt=0,  # kept for interface compatibility; ignored (see below)
    march=March.BERNOULLI2,
    save_path=None,
):
    """Trace, check, compile, profile and visualize the fixed-point model.

    Note: this function shadows the builtin ``compile`` inside this module.

    One validation image from ``data_path`` is used as the example input. The
    model is moved to CPU and wrapped in ``InversePermuteWrapper`` before
    tracing. These artifacts are written to ``save_path`` (or ``model_path``
    when ``save_path`` is None): ``int_model.pt``, ``model.hbm``,
    ``perf_out/`` and ``model.svg``.

    Args:
        quantized_model: the converted (fixed-point) model.
        data_path: dataset config passed to ``get_dataset``.
        model_path: default output directory.
        compile_opt: ignored; the optimization level is forced to ``"0"``.
        march: unused; kept for interface compatibility.
        save_path: optional output directory overriding ``model_path``.

    Returns:
        The traced TorchScript module of the wrapped model.
    """
    # Forced toolchain settings, shared by compile_model and perf_model so the
    # perf report matches the compiled .hbm. compile_opt is deliberately
    # ignored: per the original note, higher levels (O3) crashed the J3
    # compiler. input_source must be given as a list.
    forced_opt = "0"
    input_source = ["pyramid"]

    # 1. Output directory.
    real_save_path = model_path if save_path is None else save_path
    os.makedirs(real_save_path, exist_ok=True)
    print(f"Compilation results will be saved to: {real_save_path}")

    # 2. Model on CPU in eval mode.
    device = torch.device("cpu")
    quantized_model = quantized_model.to(device)
    quantized_model.eval()

    # 3. A single real validation image as the example input.
    _, val_path, _, _ = get_dataset(data_path)
    dataset = YOLODataset(val_path, img_size=640, augment=False)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, collate_fn=collate_fn)
    batch = next(iter(data_loader))
    example_input = batch['img'].to(device)
    print(f"Compile Input Shape: {example_input.shape}")

    print("📦 Wrapping model with InversePermuteWrapper (NHWC -> NCHW)...")
    quantized_model = InversePermuteWrapper(quantized_model)

    # 4. Trace and save the TorchScript model.
    script_model = torch.jit.trace(quantized_model, example_input)
    torch.jit.save(script_model, os.path.join(real_save_path, "int_model.pt"))

    # 5. Operator-support check (prints advice).
    check_model(script_model.cpu(), [example_input], advice=1)

    print(f"🔥 FORCE OVERRIDE: Using O{forced_opt} optimization and {input_source} list format")

    # 6. Compile to .hbm.
    compile_model(
        script_model,
        [example_input],
        hbm=os.path.join(real_save_path, "model.hbm"),
        input_source=list(input_source),
        output_layout="NHWC",
        opt=forced_opt,
    )

    # 7. Performance estimate with the same settings.
    perf_model(
        script_model,
        [example_input],
        out_dir=os.path.join(real_save_path, "perf_out"),
        input_source=list(input_source),
        opt=forced_opt,
        layer_details=True,
    )

    # 8. Graph visualization.
    visualize_model(
        script_model,
        [example_input],
        save_path=os.path.join(real_save_path, "model.svg"),
        show=False,
    )

    print("\n" + "="*20 + " Model Output Info " + "="*20)
    for i, output in enumerate(script_model.graph.outputs()):
        print(f"Index {i}: Name={output.debugName()} | Type={output.type()}")
    print("="*60 + "\n")

    print(f"J3 Compilation finished! Saved to {real_save_path}")
    return script_model

# def compile(
#     quantized_model,
#     data_path,
#     model_path,
#     compile_opt=3,
#     march=March.BERNOULLI2,
#     save_path=None,
# ):
#     # 1. 路径处理
#     if save_path is None:
#         real_save_path = model_path
#     else:
#         real_save_path = save_path

#     if not os.path.exists(real_save_path):
#         os.makedirs(real_save_path, exist_ok=True)
#     print(f"Compilation results will be saved to: {real_save_path}")

#     # 2. 准备模型和数据
#     device = torch.device("cpu")
#     quantized_model = quantized_model.to(device)
#     quantized_model.eval()

#     _, val_path, _, _ = get_dataset(data_path)
#     dataset = YOLODataset(val_path, img_size=640, augment=False)
#     data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, collate_fn=collate_fn)

#     batch = next(iter(data_loader))
#     example_input = batch['img'].to(device)
#     print(f"Compile Input Shape: {example_input.shape}")


#     # # =========================================================================
#     # # # >>>>>> 新增核心代码：插入 YUV -> RGB 转换算子 >>>>>>  已停用：centered_yuv2rgb 算子只有 J5 支持
#     # # =========================================================================
#     # print("\n[Auto-Insert] Inserting centered_yuv2rgb for Pyramid compatibility...")
    
#     # # 1. 先把 Quantized 模型转成 Symbolic Traced 格式 (方便改图)
#     # from torch.fx import symbolic_trace
#     # # traced_model = symbolic_trace(quantized_model)
#     # traced_model = quantized_model

#     # # 2. 寻找输入的 QuantStub 节点 (通常是第一个 call_module)
#     # inserted = False
#     # for node in traced_model.graph.nodes:
#     #     # 你的 QuantStub 名字可能叫 'quant'，也可能叫 'quant_stub'，根据你的模型定义调整
#     #     # 这里假设它包含 'quant' 字符串
#     #     if node.op == 'call_module' and 'quant' in node.name: 
#     #         print(f"  Found QuantStub node: {node.name}")
            
#     #         with traced_model.graph.inserting_after(node):
#     #             # 3. 插入转换算子
#     #             # 参数必须和你 RGB 训练时的预处理一致
#     #             # YOLOv8 默认是: 0-255 -> /255 -> 0.0-1.0
#     #             # 所以 mean=0, std=255, scale=1/255
#     #             new_node = traced_model.graph.call_function(
#     #                 centered_yuv2rgb, 
#     #                 args=(node,), 
#     #                 kwargs={
#     #                     "swing": "full", 
#     #                     "mean": [0., 0., 0.],      # Mean = 0
#     #                     "std": [255., 255., 255.], # Std = 255
#     #                     "q_scale": 1.0 / 255.0     # Quant Scale = 1/255
#     #                 }
#     #             )
                
#     #             # 4. 让后面的节点连接到新节点
#     #             node.replace_all_uses_with(new_node)
#     #             # 修正新节点的输入指向原节点 (replace_all_uses_with 会把新节点的输入也改了，需修正)
#     #             new_node.args = (node,) 
                
#     #             # 5. 【重要】强制修改原 QuantStub 的 scale 为 1.0
#     #             # 因为 centered_yuv2rgb 内部已经做了 /255 的量化，
#     #             # 原来的 QuantStub 如果再做 /255 就重复了。
#     #             if hasattr(traced_model, node.target):
#     #                 print(f"  Resetting scale of {node.target} to 1.0")
#     #                 getattr(traced_model, node.target).scale.fill_(1.0)
            
#     #         inserted = True
#     #         break # 只插第一个入口
    
#     # if not inserted:
#     #     print("[WARNING] Could not find QuantStub node! YUV conversion NOT inserted.")
    
#     # traced_model.recompile()
#     # # 用修改后的模型替换原来的模型
#     # script_model = torch.jit.trace(quantized_model, example_input)

#     # 3. Trace 模型
#     script_model = torch.jit.trace(quantized_model, example_input)

#     # =========================================================================
#     # 【新增】: 自定义输出名称 (Renaming Outputs)
#     # =========================================================================
#     # 警告：这里的顺序必须和你模型 forward 返回的 tuple 顺序完全一致！
#     # 通常 YOLOv8-Seg 的顺序是：
#     # Stride 8 (Box, Cls, MaskCoef) -> Stride 16 (...) -> Stride 32 (...) -> Proto
#     # 你可以根据下面的打印日志里的 Shape 来校验这个顺序对不对
    
#     custom_names = [
#         "bbox_8",   "cls_8",   "mask_coef_8",   # 对应 output[0], [1], [2]
#         "bbox_16",  "cls_16",  "mask_coef_16",  # 对应 output[3], [4], [5]
#         "bbox_32",  "cls_32",  "mask_coef_32",  # 对应 output[6], [7], [8]
#         "proto_mask"                            # 对应 output[9]
#     ]

#     print("\n" + "="*20 + " [Start Renaming Outputs] " + "="*20)
#     outputs = list(script_model.graph.outputs())
    
#     if len(outputs) != len(custom_names):
#         print(f"[ERROR] Name count mismatch! Model has {len(outputs)} outputs, provided {len(custom_names)} names.")
#     else:
#         for i, node in enumerate(outputs):
#             old_name = node.debugName()
#             new_name = custom_names[i]
            
#             # 修改名称
#             node.setDebugName(new_name)
            
#             # 打印改名结果，方便确认
#             print(f"Output [{i}]: {old_name} -> {new_name}")
#             # print(f"  Type info: {node.type()}") # 如果需要看详细 shape 可以把这行解开
#     print("="*60 + "\n")
#     # =========================================================================

#     # 4. 保存修改名字后的 PT 文件
#     torch.jit.save(script_model, os.path.join(real_save_path, "int_model.pt"))

#     # 5. Check Model
#     check_model(script_model.cpu(), [example_input], advice=1)

#     # 6. Compile Model
#     # 注意：为了配合你 C++ 里的 NV12 逻辑，这里建议用 "pyramid"
#     # 如果你必须用 DDR 输入，请改回 "ddr"
#     compile_model(
#         script_model,
#         [example_input],
#         hbm=os.path.join(real_save_path, "model.hbm"),
#         input_source="Pyramid",  # <--- 建议：改为 pyramid 适配 NV12 (Type 2)
#         # input_source=["resizer"], 
#         # 建议显式指定 layout 为 NHWC，这是硬件最高效的格式
#         # input_layout="NHWC",
#         # output_layout="NHWC",
#         opt=compile_opt,
#     )

#     # 7. Perf Model
#     perf_model(
#         script_model,
#         [example_input],
#         out_dir=os.path.join(real_save_path, "perf_out"),
#         input_source="Pyramid", # <--- 保持一致
#         opt=compile_opt,
#         layer_details=True,
#     )

#     # 8. Visualize Model (生成的 SVG 里会显示你改的新名字)
#     visualize_model(
#         script_model,
#         [example_input],
#         save_path=os.path.join(real_save_path, "model.svg"),
#         show=False,
#     )

#     # 9. 打印最终信息
#     print("\n" + "="*20 + " Final Model Output Info " + "="*20)
#     for i, output in enumerate(script_model.graph.outputs()):
#         # 这里打印出来的应该是你刚才修改后的名字
#         print(f"Index {i}: Name={output.debugName()}")
#     print("="*60 + "\n")

#     print(f"J3 Compilation finished! Saved to {real_save_path}")
#     return script_model


def prepare_data_loaders(
    data_path: str, train_batch_size: int, eval_batch_size: int
) -> Tuple[data.DataLoader, data.DataLoader]:
    """Build CIFAR-10 train/eval loaders, downloading the data into ``data_path`` if needed.

    Both splits are converted to tensors and normalized per channel with
    mean 0.5 and std 0.5. The training split also gets random horizontal
    flips. The training loader samples randomly and the eval loader
    sequentially; both use 8 workers and pinned memory.

    Args:
        data_path: CIFAR-10 root directory.
        train_batch_size: batch size of the training loader.
        eval_batch_size: batch size of the evaluation loader.

    Returns:
        ``(train_data_loader, eval_data_loader)``.
    """
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize]
    )
    eval_transform = transforms.Compose([transforms.ToTensor(), normalize])

    train_dataset = CIFAR10(
        root=data_path, train=True, transform=train_transform, download=True
    )
    eval_dataset = CIFAR10(
        root=data_path, train=False, transform=eval_transform, download=True
    )

    def _make_loader(dataset, batch_size, sampler):
        # Shared loader settings for both splits.
        return data.DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=8,
            pin_memory=True,
        )

    return (
        _make_loader(train_dataset, train_batch_size, data.RandomSampler(train_dataset)),
        _make_loader(eval_dataset, eval_batch_size, data.SequentialSampler(eval_dataset)),
    )




class AverageMeter:
    """Keep the latest value and the running weighted average of a metric.

    ``str(meter)`` renders ``"<name> <val> (<avg>)"``. ``fmt`` is used as the
    format spec for both numbers (e.g. ``":6.3f"``).
    """

    def __init__(self, name: str, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, average, sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and recompute the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        spec = self.fmt
        template = f"{{name}} {{val{spec}}} ({{avg{spec}}})"
        return template.format(**vars(self))


def accuracy(output: Tensor, target: Tensor, topk=(1,)) -> List[Tensor]:
    """Return top-k accuracy percentages, one 0-d tensor per entry of ``topk``.

    ``output`` is ranked along dim 1. A sample counts as correct for ``k``
    when its ``target`` appears among its ``k`` highest-scoring indices.
    """
    with torch.no_grad():
        num_samples = target.size(0)
        _, top_idx = output.topk(max(topk), 1, True, True)
        ranked = top_idx.t()  # row r holds every sample's rank-r prediction
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))
        return [hits[:k].float().sum().mul_(100.0 / num_samples) for k in topk]


# def train_one_epoch(
#     model: nn.Module,
#     optimizer: torch.optim.Optimizer,
#     scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
#     data_loader: data.DataLoader,
#     device: torch.device,
# ) -> None:
#     # top1 = AverageMeter("Acc@1", ":6.3f")
#     # top5 = AverageMeter("Acc@5", ":6.3f")
#     loss_names = ["Box Loss", "Seg Loss", "Cls Loss", "DFL Loss"]
#     loss_meters = {name: AverageMeter(name, ":1.5f") for name in loss_names}
#     avgloss = AverageMeter("Total Loss", ":1.5f")
#     scaler = GradScaler()
#     PRINT_FREQ = 10 # 打印频率

#     for i, batch in enumerate(data_loader):
#         image= batch["img"].to(device)
#         batch['labels']["targets_out"] = batch['labels']["targets_out"].to(device)
#         batch['labels']["masks_out"] = batch['labels']["masks_out"].to(device)

#         with autocast():
#             preds = model(image)
#             loss, loss_items= criterion(preds, batch)
#         optimizer.zero_grad()
#         scaler.scale(loss).backward()

#         # ================== [新增/修改 START] ==================
#         # 1. 先把梯度从 scaler 里解包出来 (必须做！)
#         scaler.unscale_(optimizer)
        
#         # 2. 强制裁剪梯度，最大范数设为 10.0 (YOLO常用值)
#         # 这行代码能把那些试图飞到天上去的梯度拽回来
#         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        
#         # 3. 更新参数
#         scaler.step(optimizer)
#         scaler.update()
#         # ================== [新增/修改 END] ==================

#         scaler.step(optimizer)
#         scaler.update() 
#         if scheduler is not None:
#             scheduler.step()

#         loss_items = loss_items.detach().cpu().view(-1)
#         for k, name in enumerate(loss_names):
#             if k < len(loss_items):
#                 loss_meters[name].update(loss_items[k].item(), image.size(0))

#         avgloss.update(loss.item(), image.size(0))
#         if i % PRINT_FREQ == 0 or i == len(data_loader) - 1:
#             log_str = f"Epoch Batch [{i}/{len(data_loader)-1}] | Total Loss: {avgloss.avg:.4f} | "
#             # 拼接各项分项 Loss 的平均值
#             log_parts = [f"{name}: {meter.avg:.4f}" for name, meter in loss_meters.items()]
#             print("\n" + log_str + " | ".join(log_parts))
        
#         print(".", end="", flush=True)

#     print(f"\n--- Epoch Summary: Total Avg Loss: {avgloss.avg:.4f} ---")

# Class names (list index == class id) used only by the first-batch supervision debug dump.
_DEBUG_CLASS_NAMES = [
    "Barrier",       # 0
    "Bollard",       # 1
    "Driving_area",  # 2
    "N_Obstacle",    # 3
    "Obstacle",      # 4
    "Parking_lock",  # 5
    "Road_cone",     # 6
    "curb",          # 7
    "mask",          # 8
]

# High-contrast BGR colors, one per entry of _DEBUG_CLASS_NAMES.
_DEBUG_CLASS_COLORS = [
    [255, 0, 0],      # 0: Barrier -> blue
    [0, 255, 0],      # 1: Bollard -> green
    [0, 255, 255],    # 2: Driving_area -> yellow (stands out)
    [0, 0, 255],      # 3: N_Obstacle -> red (warning)
    [0, 140, 255],    # 4: Obstacle -> orange
    [255, 0, 255],    # 5: Parking_lock -> purple
    [0, 215, 255],    # 6: Road_cone -> gold
    [180, 105, 255],  # 7: curb -> pink-purple
    [128, 128, 128],  # 8: mask -> gray
]


def _save_supervision_debug(batch) -> None:
    """Overlay the first image's ground-truth masks and class names and save it to disk.

    Writes ``debug_supervision_color.jpg`` to the current working directory and
    prints a short diagnostic. Reads ``batch["img"]``, ``batch["batch_idx"]``,
    ``batch["masks"]`` and ``batch["cls"]``; the batch is not modified.
    """
    # Assumes batch["img"] is (B, 3, H, W) RGB in [0, 1] (it is permuted to HWC and
    # scaled by 255 below) -- TODO confirm against collate_fn.
    debug_img = batch["img"][0].cpu().permute(1, 2, 0).numpy()
    # Clip before the uint8 cast so values slightly outside [0, 1] saturate instead of wrapping.
    debug_img = np.ascontiguousarray(np.clip(debug_img * 255, 0, 255), dtype=np.uint8)
    debug_img = cv2.cvtColor(debug_img, cv2.COLOR_RGB2BGR)
    height, width = debug_img.shape[:2]

    # Select the targets that belong to image 0; reshape(-1) keeps a 1-D boolean
    # index even when there is exactly one target.
    idx = (batch["batch_idx"] == 0).reshape(-1)
    current_masks = batch["masks"][idx]
    current_cls = batch["cls"][idx]

    print(f"DEBUG: Batch[0] 包含 {len(current_masks)} 个 Mask")

    if len(current_masks) == 0:
        print("⚠️ 本图没有 Mask 目标")
        return

    mask_layer = np.zeros_like(debug_img)
    for k, m in enumerate(current_masks):
        class_id = int(current_cls[k].item())
        if class_id < len(_DEBUG_CLASS_COLORS):
            color = _DEBUG_CLASS_COLORS[class_id]
            name = _DEBUG_CLASS_NAMES[class_id]
        else:
            color = [255, 255, 255]  # unknown class -> white
            name = f"ID:{class_id}"

        # Upscale the (possibly downsampled) mask to image size and paint it.
        m_np = m.cpu().numpy().astype(np.uint8)
        m_resized = cv2.resize(m_np, (width, height), interpolation=cv2.INTER_NEAREST)
        mask_layer[m_resized > 0] = color

        # Write the class name at the mask centroid: black outline, then white text.
        moments = cv2.moments(m_resized)
        if moments["m00"] != 0:
            center = (
                int(moments["m10"] / moments["m00"]),
                int(moments["m01"] / moments["m00"]),
            )
            cv2.putText(debug_img, name, center, cv2.FONT_HERSHEY_SIMPLEX,
                        0.6, (0, 0, 0), 3)
            cv2.putText(debug_img, name, center, cv2.FONT_HERSHEY_SIMPLEX,
                        0.6, (255, 255, 255), 1)

    debug_vis = cv2.addWeighted(debug_img, 0.6, mask_layer, 0.4, 0)
    cv2.imwrite("debug_supervision_color.jpg", debug_vis)
    print("✅ 已保存 debug_supervision_color.jpg，请查看！")

    if mask_layer.sum() == 0:
        print("❌❌❌ 严重错误：Mask 层全是黑的！监督数据有问题！❌❌❌")


def train_one_epoch(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    data_loader: data.DataLoader,
    device: torch.device,
) -> None:
    """Run one training epoch in plain float32 (no AMP) with gradient-norm clipping.

    On the first batch a supervision debug image is written (see
    ``_save_supervision_debug``). A batch whose loss or pre-clip gradient norm is
    inf/NaN is skipped (no optimizer step), so one bad batch cannot write
    non-finite values into the weights. Running averages of the total loss and
    of each loss component are printed every ``PRINT_FREQ`` batches.

    Args:
        model: network to train; switched to train mode here.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: optional LR scheduler. It is stepped ONCE per call, after the
            last batch, because the only scheduler built in this file is an
            epoch-indexed ``LambdaLR`` (warmup/cosine measured in epochs).
        data_loader: yields dict batches containing at least an ``"img"`` tensor.
        device: device the model lives on; every tensor in the batch is moved there.
    """
    loss_names = ["Box Loss", "Seg Loss", "Cls Loss", "DFL Loss"]
    loss_meters = {name: AverageMeter(name, ":1.5f") for name in loss_names}
    avgloss = AverageMeter("Total Loss", ":1.5f")
    PRINT_FREQ = 10
    last_batch = len(data_loader) - 1
    skipped_steps = 0

    # NOTE(review): train() already calls model.train() before set_fake_quantize;
    # this call is kept for standalone use -- confirm it does not reset QAT state.
    model.train()

    for i, batch in enumerate(data_loader):
        if i == 0:  # only visualise the first batch of the epoch
            _save_supervision_debug(batch)

        # Move every tensor in the batch (including "img") to the device exactly once.
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(device)
        image = batch["img"]

        # `criterion` is not a parameter: it is a module-level loss object expected
        # to be defined elsewhere in this file.
        preds = model(image)
        loss, loss_items = criterion(preds, batch)
        loss_value = loss.item()

        optimizer.zero_grad()
        if math.isfinite(loss_value):
            loss.backward()
            # Clip the global gradient norm to 10.0. The returned (pre-clip) norm
            # reveals inf/NaN gradients, which clipping cannot repair.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
            if torch.isfinite(grad_norm):
                optimizer.step()
            else:
                skipped_steps += 1
                print(f"\n[WARN] batch {i}: non-finite gradient norm ({grad_norm.item()}), optimizer step skipped")
        else:
            skipped_steps += 1
            print(f"\n[WARN] batch {i}: non-finite loss ({loss_value}), optimizer step skipped")

        # Per-component statistics; zip() stops at the shorter of names/items,
        # and non-finite values are kept out of the running averages.
        batch_size = image.size(0)
        for name, item in zip(loss_names, loss_items.detach().cpu().view(-1)):
            value = item.item()
            if math.isfinite(value):
                loss_meters[name].update(value, batch_size)

        if math.isfinite(loss_value):
            avgloss.update(loss_value, batch_size)

        if i % PRINT_FREQ == 0 or i == last_batch:
            log_str = f"Epoch Batch [{i}/{last_batch}] | Total Loss: {avgloss.avg:.4f} | "
            log_parts = [f"{name}: {meter.avg:.4f}" for name, meter in loss_meters.items()]
            print("\n" + log_str + " | ".join(log_parts))

        print(".", end="", flush=True)

    # Epoch-level LR schedule: advance once per epoch, not once per batch.
    if scheduler is not None:
        scheduler.step()

    if skipped_steps:
        print(f"\n[WARN] {skipped_steps} optimizer step(s) skipped this epoch due to non-finite loss/gradients")
    print(f"\n--- Epoch Summary: Total Avg Loss: {avgloss.avg:.4f} ---")


def evaluate(
    model: nn.Module, data_loader: data.DataLoader, device: torch.device
) -> Tuple[AverageMeter, AverageMeter]:
    """Measure top-1 and top-5 classification accuracy of ``model`` on ``data_loader``.

    Gradients are disabled, but the model's train/eval mode (and any fake-quant
    state) is left to the caller. One progress dot is printed per batch.

    Returns:
        ``(top1, top5)`` AverageMeter objects weighted by batch size.
    """
    acc_top1 = AverageMeter("Acc@1", ":6.2f")
    acc_top5 = AverageMeter("Acc@5", ":6.2f")

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            hit1, hit5 = accuracy(logits, labels, topk=(1, 5))
            n_samples = inputs.size(0)
            acc_top1.update(hit1, n_samples)
            acc_top5.update(hit5, n_samples)
            print(".", end="", flush=True)
        print()

    return acc_top1, acc_top5


# Float and qat share the same training procedure.
def train(
    args,
    model,
    data_path: str,
    model_path: str,
    train_batch_size: int,
    eval_batch_size: int,
    epoch_num: int,
    device: torch.device,
    optim_config: Callable,
    stage: str,
    march=March.BERNOULLI2,
):
    """Train ``model`` for ``epoch_num`` epochs, keeping last and best checkpoints.

    Each epoch:

    1. Train one epoch, enabling QAT fake-quant first when ``stage == "qat"``.
    2. Compute the validation loss.
    3. Save ``<stage>-last-checkpoint.ckpt``.
    4. Save ``<stage>-best-checkpoint.ckpt`` whenever the validation loss improves.

    After the loop, the best checkpoint is loaded back into ``model``. That
    happens unless training ran without ever saving one (e.g. every validation
    loss was NaN); in that case the file on disk is from an earlier run and is
    ignored.

    Args:
        args: CLI namespace, forwarded to ``optim_config``.
        model: model to train; modified in place and returned.
        data_path: dataset YAML path (read by ``get_dataset``).
        model_path: directory the checkpoints are written to.
        train_batch_size: DataLoader batch size for training.
        eval_batch_size: DataLoader batch size for validation.
        epoch_num: number of epochs; 0 skips training.
        device: device used for training and validation.
        optim_config: ``(model, args) -> (optimizer, scheduler_or_None)``.
        stage: ``"float"`` or ``"qat"``; also the checkpoint file prefix.
        march: unused here; kept for interface compatibility.

    Returns:
        ``model``.
    """
    trainset, valset, _testset, _names = get_dataset(data_path)

    train_dataset = YOLODataset(trainset, img_size=640, augment=True)
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    # No shuffling for validation (assumes validate_loss aggregates over the whole
    # split -- confirm). The test split is not used during training, so no loader
    # is built for it (that used to cost a full label-cache pass).
    eval_dataset = YOLODataset(valset, img_size=640, augment=False)
    eval_data_loader = DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    optimizer, scheduler = optim_config(model, args)

    last_ckpt_path = os.path.join(model_path, f"{stage}-last-checkpoint.ckpt")
    best_ckpt_path = os.path.join(model_path, f"{stage}-best-checkpoint.ckpt")
    best_loss = float("inf")
    best_epoch = -1  # stays -1 until this run writes a best checkpoint

    for nepoch in range(epoch_num):
        # Train/eval mode must be set before set_fake_quantize.
        model.train()
        if stage == "qat":
            set_fake_quantize(model, FakeQuantState.QAT)

        print(f"\nEpoch {nepoch}/{epoch_num-1} Training:")
        train_one_epoch(
            model,
            optimizer,
            scheduler,
            train_data_loader,
            device,
        )

        print(f"Epoch {nepoch}/{epoch_num-1} Validating:")
        current_val_loss = validate_loss(model, eval_data_loader, device)

        torch.save(model.state_dict(), last_ckpt_path)

        # A NaN validation loss never compares smaller, so it can never become "best".
        if current_val_loss < best_loss:
            print(f"🔥 Found Better Model! Loss improved from {best_loss:.5f} to {current_val_loss:.5f}")
            best_loss = current_val_loss
            best_epoch = nepoch
            torch.save(model.state_dict(), best_ckpt_path)
        else:
            print(f"  (Current Loss {current_val_loss:.5f} did not improve best {best_loss:.5f} at epoch {best_epoch})")

    print(f"\nTraining Finished. Best Validation Loss: {best_loss:.5f} at Epoch {best_epoch}")
    if epoch_num > 0 and best_epoch < 0:
        # Training ran but never produced an improving loss: any best checkpoint
        # on disk belongs to an earlier run and must not replace these weights.
        print(f"No {stage}-best-checkpoint.ckpt was written in this run; returning the last weights.")
    elif os.path.exists(best_ckpt_path):
        print(f"Loading best weights from {stage}-best-checkpoint.ckpt for return...")
        model.load_state_dict(torch.load(best_ckpt_path, map_location=device))

    return model


    # best_acc = 0

    # for nepoch in range(epoch_num):
    #     # Training/Eval state must be setted correctly
    #     # before `set_fake_quantize`
    #     model.train()
    #     if stage == "qat":
    #         set_fake_quantize(model, FakeQuantState.QAT)

    #     train_one_epoch(
    #         model,
    #         optimizer,
    #         scheduler,
    #         train_data_loader,
    #         device,
    #     )

    #     torch.save(
    #         model.state_dict(),
    #         os.path.join(model_path, "{}-checkpoint.ckpt".format(stage)),
    #     )

    # return model

#---------------------------------------------------------------------------------------------

# def train(
#     args,
#     model,
#     data_path: str,
#     model_path: str,
#     train_batch_size: int,
#     eval_batch_size: int,
#     epoch_num: int,
#     device: torch.device,
#     optim_config: Callable,
#     stage: str,
#     march=March.BERNOULLI2,
# ):
#     # dataset init
#     trainset, valset, testset, names = get_dataset(data_path)
#     # 注意：这里 YOLODataset 会使用你之前改好的默认 hyp (mosaic=1.0)
#     train_dataset = YOLODataset(trainset, img_size=640) 
#     train_data_loader = DataLoader(train_dataset, 
#                                     batch_size=train_batch_size, 
#                                     shuffle=True, 
#                                     num_workers=8, 
#                                     pin_memory=True, 
#                                     collate_fn=collate_fn)

#     # ... (验证集 loader 代码保持不变) ...
#     test_dataset = YOLODataset(testset, img_size=640)
#     test_data_loader = DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=True, num_workers=8, pin_memory=True, collate_fn=collate_fn)
#     eval_dataset = YOLODataset(valset, img_size=640)
#     eval_data_loader = DataLoader(eval_dataset, batch_size=eval_batch_size, shuffle=True, num_workers=8, pin_memory=True, collate_fn=collate_fn)

#     optimizer, scheduler = optim_config(model, args)

#     best_acc = 0
    
#     # 定义最后多少轮关闭 Mosaic
#     close_mosaic_epochs = 10 

#     for nepoch in range(epoch_num):
#         model.train()
        
#         # ============================================================
#         # 🔥 [新增] Close Mosaic 策略: 最后 10 轮关闭马赛克
#         # ============================================================
#         # 判断：如果剩余轮数 <= 10 且 Mosaic 当前还是开启状态
#         if (epoch_num - nepoch) <= close_mosaic_epochs and train_dataset.hyp.get('mosaic', 0) > 0:
#             print(f"\n🚀 Epoch {nepoch}: Closing Mosaic augmentation for final fine-tuning!")
#             # 强制将 Dataset 中的 mosaic 概率设为 0
#             train_dataset.hyp['mosaic'] = 0.0
#             train_dataset.hyp['mixup'] = 0.0 # 如果有 mixup 也顺便关掉
#         # ============================================================

#         if stage == "qat":
#             set_fake_quantize(model, FakeQuantState.QAT)

#         train_one_epoch(
#             model,
#             optimizer,
#             scheduler,
#             train_data_loader,
#             device,
#         )

#         # ... (后续保存模型代码保持不变) ...
#         torch.save(
#             model.state_dict(),
#             os.path.join(model_path, "{}-checkpoint.ckpt".format(stage)),
#         )

#     return model



def main(
    args,
    model,
    stage: str,
    data_path: str,
    model_path: str,
    train_batch_size: int,
    eval_batch_size: int,
    epoch_num: int,
    device_id: int = 4,
    march: str = March.BERNOULLI2,
    compile_opt: int = 3,
    save_path: str = None
):
    """Run one stage of the float -> calib -> qat -> int_infer -> compile pipeline.

    Args:
        args: CLI namespace, forwarded to the optimizer factories.
        model: model to operate on (modified in place by the training stages).
        stage: one of "float", "calib", "qat", "int_infer", "compile".
        data_path: dataset YAML path.
        model_path: output directory for checkpoints/results (created if missing).
        train_batch_size: batch size for training/calibration.
        eval_batch_size: batch size for evaluation.
        epoch_num: training epochs; None selects 30 for "float" and 10 for "qat".
        device_id: CUDA device index; a negative value selects the CPU.
        march: target architecture for quantization/compilation.
        compile_opt: optimization level forwarded to ``compile``.
        save_path: output path forwarded to ``compile``.
    """
    assert stage in ("float", "calib", "qat", "int_infer", "compile")

    device = torch.device(
        "cuda:{}".format(device_id) if device_id >= 0 else "cpu"
    )

    os.makedirs(model_path, exist_ok=True)

    # Resolve the epoch count up front so the LR schedule below and the training
    # loop in train() agree on it.
    default_epoch_num = {
        "float": 30,
        "qat": 10,
    }
    if stage in ("float", "qat") and epoch_num is None:
        epoch_num = default_epoch_num[stage]

    def float_optim_config(model: nn.Module, args):
        """Build SGD (Nesterov) plus an epoch-indexed LambdaLR schedule.

        The LR multiplier warms up linearly over ``warmup_epochs``, then follows a
        cosine decay from 1.0 (lr0) down to ``lrf`` (lr0 * lrf). ``lr_lambda``
        counts epochs, so the scheduler is meant to be stepped once per epoch.
        The hyper-parameters are not tuned: this is an example of the QAT tool,
        not an attempt at optimal float accuracy.
        """
        lr0 = 0.01             # initial LR
        lrf = 0.01             # final LR = lr0 * lrf
        momentum = 0.937
        weight_decay = 0.0005
        warmup_epochs = 3.0
        # Use the resolved count train() loops over; args.epoch_num is None when
        # --epoch_num is omitted, which used to crash once warmup ended.
        epochs = epoch_num
        # Guard against division by zero when epochs <= warmup_epochs.
        cosine_span = max(epochs - warmup_epochs, 1.0)

        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=lr0,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=True
        )

        def lr_lambda(epoch):
            # 1. Linear warmup; offset by one so the very first step does not use LR 0.
            if epoch < warmup_epochs:
                return (epoch + 1) / (warmup_epochs + 1)
            # 2. Cosine decay of the multiplier from 1.0 down to lrf.
            progress = (epoch - warmup_epochs) / cosine_span
            return ((1 + math.cos(progress * math.pi)) / 2) * (1 - lrf) + lrf

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        return optimizer, scheduler

    def qat_optim_config(model: nn.Module, args):
        # QAT only fine-tunes the params to match the numerical quantization,
        # so the learning rate must stay small; no scheduler is used.
        optimizer = torch.optim.SGD(
            model.parameters(), lr=0.0001, weight_decay=2e-4
        )

        return optimizer, None

    if stage in ("float", "qat"):
        train(
            args,
            model,
            data_path,
            model_path,
            train_batch_size,
            eval_batch_size,
            epoch_num,
            device,
            float_optim_config if stage == "float" else qat_optim_config,
            stage,
            march,
        )

    elif stage == "calib":

        calibrate(
            model,
            data_path,
            model_path,
            train_batch_size,
            eval_batch_size,
            device,
            march=march,
        )

    elif stage == "int_infer":

        int_infer(model, data_path, model_path, eval_batch_size,
            device,
            march=march,
        )

    else:

        compile(
            model,
            data_path,
            model_path,
            compile_opt,
            march=march,
            save_path=save_path,
        )

def get_args():
    """Parse the command-line options of the segmentation QAT pipeline.

    Returns:
        argparse.Namespace with ``stage``, ``data_path``, ``model_path``,
        ``train_batch_size``, ``eval_batch_size``, ``epoch_num``, ``device_id``,
        ``opt`` and ``march``.
    """
    parser = argparse.ArgumentParser(description="Run YOLOv8 segmentation QAT example.")
    parser.add_argument(
        "--stage",
        type=str,
        choices=("float", "calib", "qat", "int_infer", "compile"),
        help=(
            "Pipeline stage, must be executed in following order: "
            "float -> calib(optional) -> qat(optional) -> int_infer -> compile"
        ),
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="data",
        help="Path to the dataset YAML config (train/val/test image dirs and class names)",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="model/mobilenetv2",
        help="Where to save the model and other results",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=256,
        help="Batch size for training",
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=256,
        help="Batch size for evaluation",
    )
    parser.add_argument(
        "--epoch_num",
        type=int,
        default=None,
        help=(
            "Override the default training epoch number (float: 30, qat: 10); "
            "pass 0 to skip training (in stage 'float' or 'qat')"
        ),
    )
    parser.add_argument(
        "--device_id",
        type=int,
        default=0,
        help="Specify which device to use, pass a negative value to use cpu",
    )
    # NOTE(review): the default is the int 0, while values given on the command line
    # arrive as strings from `choices`. It is kept as-is because the compile step's
    # expected type is not visible here.
    parser.add_argument(
        "--opt",
        type=str,
        choices=["0", "1", "2", "3", "ddr", "fast", "balance"],
        default=0,
        help="Specify optimization level for compilation",
    )
    parser.add_argument(
        "--march",
        type=str,
        default="BERNOULLI2",
        help="Specify march for quantization",
    )
    return parser.parse_args()


def get_dataset(data_path):
    """Read the dataset YAML at ``data_path`` and return its split entries and class names.

    Returns:
        Tuple ``(train, val, test, names)`` taken verbatim from the YAML's
        ``train``, ``val``, ``test`` and ``names`` keys.

    Raises:
        ValueError: if the top level of the YAML is not a mapping (e.g. an empty file).
        KeyError: if any of ``train``/``val``/``test``/``names`` is missing.
    """
    # `cfg` rather than `data` to avoid shadowing the module-level torch.utils.data import.
    with open(data_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    if not isinstance(cfg, dict):
        raise ValueError(
            f"Dataset config {data_path!r} must be a YAML mapping, got {type(cfg).__name__}"
        )

    required = ("train", "val", "test", "names")
    missing = [key for key in required if key not in cfg]
    if missing:
        raise KeyError(f"Dataset config {data_path!r} is missing required key(s): {missing}")

    return tuple(cfg[key] for key in required)

# class YOLODataset(Dataset):
#     def __init__(self, img_dir, img_size=640, augment=False):
#         """
#         img_dir:
#             ├── train
#             │    ├── images
#             │    └── labels
#         """
#         self.img_dir = Path(img_dir)
#         self.lbl_dir = Path(img_dir).parent.parent / "labels" / Path(img_dir).parts[-1]
#         self.files = list(self.img_dir.glob("*.png"))
#         self.img_size = img_size
#         self.augment = augment

#         print(f"Loaded {len(self.files)} images from {img_dir}")

#     def __len__(self):
#         return len(self.files)

#     def __getitem__(self, index):

#         while True:
#             img_path = self.files[index]
#             label_path = self.lbl_dir / (img_path.stem + ".txt")

#             # 1. load image
#             img = cv2.imread(str(img_path))
#             if img is None:
#                 index = (index + 1) % len(self.files)
#                 continue
#             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

#             # 2. load label
#             segments = []
#             classes = []
#             bboxes = []
#             MAX_POINTS = 1000
#             if label_path.exists():
#                 for line in open(label_path):
#                     items= list(map(float, line.split()))
#                     cls = items[0]
#                     segment = np.array(items[1:], dtype=np.float32).reshape(-1, 2)
#                     bbox = self.segment2bbox(segment, self.img_size, self.img_size)
#                     if len(segment) > MAX_POINTS:
#                         indices = np.linspace(0, len(segment) - 1, MAX_POINTS).astype(int)
#                         segment = segment[indices]
                        
#                     elif len(segment) < MAX_POINTS:
#                         padding = np.zeros((MAX_POINTS - len(segment), 2), dtype=np.float32)
#                         segment = np.concatenate([segment, padding], axis=0)
#                     classes.append(cls)
#                     bboxes.append(bbox)
#                     segments.append(segment)

#             segments = np.stack(segments, axis=0)
#             classes = np.stack(classes, axis=0)
#             bboxes = np.stack(bboxes, axis=0)
#             # 3. letterbox resize
#             img, r, (dw, dh) = self.letterbox(img, self.img_size)
#             masks, sorted_idx = self.polygons2masks_overlap((self.img_size, self.img_size), segments, downsample_ratio=4)
#             masks = masks[None]
#             segments = segments[sorted_idx]
#             classes = classes[sorted_idx]
#             bboxes = bboxes[sorted_idx]
#             # 4. convert xywh → absolute px (训练时通常还是 normalized，不转换也行)
#             # if len(labels):
#             #     labels[:, 1] = labels[:, 1] * img.shape[1]   # x
#             #     labels[:, 2] = labels[:, 2] * img.shape[0]   # y
#             #     labels[:, 3] = labels[:, 3] * img.shape[1]   # w
#             #     labels[:, 4] = labels[:, 4] * img.shape[0]   # h

#             # 5. to tensor
#             img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0

#             return {
#                 "img": img,
#                 "labels": {
#                     'masks': torch.from_numpy(masks),
#                     'segments': torch.from_numpy(segments),
#                     'classes': torch.from_numpy(classes),
#                     'bboxes': torch.from_numpy(bboxes)
#                 }
#             }
        
#     def letterbox(self, im, new_shape=640, color=(114, 114, 114)):
#         """Resize and pad image while meeting stride-multiple constraints"""
#         shape = im.shape[:2]  # h, w
#         r = new_shape / max(shape)  # resize ratio
#         new_unpad = int(shape[1] * r), int(shape[0] * r)
#         im_resized = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
        
#         dw = new_shape - new_unpad[0]
#         dh = new_shape - new_unpad[1]
#         dw //= 2
#         dh //= 2
        
#         im_out = cv2.copyMakeBorder(im_resized, dh, dh, dw, dw, cv2.BORDER_CONSTANT, value=color)
#         return im_out, r, (dw, dh)

#     def polygons2masks_overlap(self, imgsz, segments, downsample_ratio=1):
#         """Return a (640, 640) overlap mask."""
#         segments[:, :, 0] *= imgsz[0]
#         segments[:, :, 1] *= imgsz[1]
#         masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
#                         dtype=np.int32 if len(segments) > 255 else np.uint8)
#         areas = []
#         ms = []
#         for si in range(len(segments)):
#             mask = self.polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1)
#             ms.append(mask)
#             areas.append(mask.sum())
#         areas = np.asarray(areas)
#         index = np.argsort(-areas)
#         ms = np.array(ms)[index]
#         for i in range(len(segments)):
#             mask = ms[i] * (i + 1)
#             masks = masks + mask
#             masks = np.clip(masks, a_min=0, a_max=i + 1)
#         return masks, index

#     def polygon2mask(self, imgsz, polygons, color=1, downsample_ratio=1):
#         """
#         Args:
#             imgsz (tuple): The image size.
#             polygons (np.ndarray): [N, M], N is the number of polygons, M is the number of points(Be divided by 2).
#             color (int): color
#             downsample_ratio (int): downsample ratio
#         """
#         mask = np.zeros(imgsz, dtype=np.uint8)
#         polygons = np.asarray(polygons)
#         polygons = polygons.astype(np.int32)
#         shape = polygons.shape
#         polygons = polygons.reshape(shape[0], -1, 2)
#         cv2.fillPoly(mask, polygons, color=color)
#         nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
#         # NOTE: fillPoly firstly then resize is trying the keep the same way
#         # of loss calculation when mask-ratio=1.
#         mask = cv2.resize(mask, (nw, nh))
#         return mask

#     def segment2bbox(self, segments, img_w, img_h):

#         min_xy = segments.min(axis=0)  # Shape: (N, 2) -> [xmin, ymin]
#         max_xy = segments.max(axis=0)  # Shape: (N, 2) -> [xmax, ymax]
        
#         # 2. 拆分坐标
#         xmin, ymin = min_xy[0], min_xy[1]
#         xmax, ymax = max_xy[0], max_xy[1]
        
#         # 3. 计算 xywh (绝对坐标)
#         width = xmax - xmin
#         height = ymax - ymin
#         x_center = xmin + width / 2
#         y_center = ymin + height / 2
        
#         bboxes = np.stack([
#             x_center,
#             y_center,
#             width,
#             height
#         ], axis=0)
        
#         return bboxes


# class YOLODataset(Dataset):
#     def __init__(self, img_dir, img_size=640, augment=True, hyp=None):
#         """
#         Args:
#             augment: 是否开启增强 (训练时True, 验证时False)
#             hyp: 超参数字典 (hsv_h, hsv_s, hsv_v, fliplr 等)
#         """
#         self.img_dir = Path(img_dir)
#         # 假设目录结构是 images/.. 和 labels/..
#         self.lbl_dir = self.img_dir.parent.parent / "labels" / self.img_dir.parts[-1]
        
#         # 1. 找到所有图片
#         self.img_files = sorted(list(self.img_dir.glob("*.jpg")) + list(self.img_dir.glob("*.png")))
#         self.img_size = img_size
#         self.augment = augment
        
#         # 默认超参数 (参考官方 default.yaml)
#         self.hyp = hyp if hyp else {
#             'hsv_h': 0.015, 'hsv_s': 0.7, 'hsv_v': 0.4, # HSV 增强幅度
#             'degrees': 0.0, 'translate': 0.1, 'scale': 0.5, # 几何增强 (这里暂不实现复杂的Affine)
#             'fliplr': 0.5,  # 左右翻转概率
#             'mask_ratio': 4 # Mask 下采样倍率
#         }

#         # 2. 预加载所有 Label 到内存 (模拟官方 Cache，极大提升速度)
#         self.labels = self.cache_labels()
#         print(f"Loaded {len(self.img_files)} images and labels.")

#     def cache_labels(self):
#         """一次性读取所有txt，避免getitem时反复IO"""
#         cache = []
#         from tqdm import tqdm
#         print("正在预加载所有标签，请耐心等待...")
#         for img_path in tqdm(self.img_files):
#             label_path = self.lbl_dir / (img_path.stem + ".txt")
#             lb = {'cls': [], 'bboxes': [], 'segments': []}
            
#             if label_path.exists():
#                 with open(label_path, 'r') as f:
#                     for line in f:
#                         items = list(map(float, line.split()))
#                         cls = items[0]
#                         # 读取分割点 (normalize 0-1)
#                         segment = np.array(items[1:], dtype=np.float32).reshape(-1, 2)
                        
#                         # segment 转 bbox (xywh normalized)
#                         bbox = self.segment2bbox(segment)
                        
#                         lb['cls'].append(cls)
#                         lb['bboxes'].append(bbox)
#                         lb['segments'].append(segment)
            
#             # 转为 numpy
#             lb['cls'] = np.array(lb['cls'], dtype=np.float32).reshape(-1, 1)
#             lb['bboxes'] = np.array(lb['bboxes'], dtype=np.float32).reshape(-1, 4)
#             # segments 长度不一，保持 list 结构
            
#             cache.append(lb)
#         return cache

#     def __len__(self):
#         return len(self.img_files)

#     def __getitem__(self, index):
#         # 1. Load Image
#         img_path = self.img_files[index]
#         label = self.labels[index].copy() # 浅拷贝，防止修改 cache
        
#         img = cv2.imread(str(img_path))
#         if img is None:
#             # 容错处理
#             return self.__getitem__((index + 1) % len(self))
        
#         h0, w0 = img.shape[:2]
        
#         # ================= 数据增强核心部分 =================
#         if self.augment:
#             # A. HSV 颜色增强
#             self.augment_hsv(img)
            
#             # B. 左右翻转 (Fliplr)
#             if random.random() < self.hyp['fliplr']:
#                 img = np.fliplr(img)
#                 # 翻转 label: x_center = 1 - x_center
#                 if len(label['bboxes']):
#                     label['bboxes'][:, 0] = 1 - label['bboxes'][:, 0]
#                     for i in range(len(label['segments'])):
#                         label['segments'][i][:, 0] = 1 - label['segments'][i][:, 0]

#             #  Mosaic (马赛克)
#             # 官方在这里会进行 Mosaic，这对小目标和防过拟合至关重要。
#             # 手写 Mosaic 比较复杂，如果指标还不够，建议直接继承官方类。

#         # ==================================================

#         # 3. Resize (Letterbox)
#         img, ratio, (dw, dh) = self.letterbox(img, self.img_size)
        
#         # 4. Generate Masks (从 segment 实时生成 mask)
#         # 注意：这里不需要手动 padded segments，因为我们用 list 传给 collate_fn
#         segments = label['segments']
#         masks = np.zeros((self.img_size // self.hyp['mask_ratio'], 
#                           self.img_size // self.hyp['mask_ratio']), dtype=np.float32)

#         # 处理 labels (转回 absolute 像素坐标用于 mask 生成，或者直接用 normalized 生成)
#         # 简单起见，这里演示如何生成重叠 Mask
#         if len(segments) > 0:
#              # 将 normalized segments 映射到 letterbox 后的图片尺寸
#             valid_segments = []
#             for s in segments:
#                 # 0-1 -> 原图尺寸
#                 s_px = s * np.array([w0, h0]) 
#                 # Resize + Pad
#                 s_px = s_px * ratio + np.array([dw, dh])
#                 valid_segments.append(s_px)
            
#             # 生成 Mask (注意官方是用 downsample 后的尺寸)
#             masks, sorted_idx = self.polygons2masks_overlap(
#                 (self.img_size, self.img_size), 
#                 valid_segments, 
#                 downsample_ratio=self.hyp['mask_ratio']
#             )
            
#             # 根据 mask 的排序重新排序 box 和 cls
#             label['cls'] = label['cls'][sorted_idx]
#             label['bboxes'] = label['bboxes'][sorted_idx]
#         else:
#             masks = masks[None] # (1, h, w) 空 mask

#         # 5. Image To Tensor
#         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR -> RGB
#         img = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0

#         return {
#             "img": img,
#             "cls": torch.from_numpy(label['cls']),
#             "bboxes": torch.from_numpy(label['bboxes']),
#             "masks": torch.from_numpy(masks),
#             "batch_idx": torch.tensor([index]) # 占位，collate_fn 会重写
#         }

#     def augment_hsv(self, img):
#         # 简单的 HSV 增强实现
#         h = self.hyp['hsv_h']
#         s = self.hyp['hsv_s']
#         v = self.hyp['hsv_v']
        
#         r = np.random.uniform(-1, 1, 3) * [h, s, v] + 1
#         hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
#         dtype = img.dtype  # uint8

#         x = np.arange(0, 256, dtype=r.dtype)
#         lut_hue = ((x * r[0]) % 180).astype(dtype)
#         lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
#         lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

#         im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
#         cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img)

#     def letterbox(self, im, new_shape=640, color=(114, 114, 114)):
#         shape = im.shape[:2]  # current shape [height, width]
#         if isinstance(new_shape, int):
#             new_shape = (new_shape, new_shape)

#         r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
#         new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
#         dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
#         dw, dh = dw / 2, dh / 2  # divide padding into 2 sides

#         if shape[::-1] != new_unpad:  # resize
#             im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
        
#         top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
#         left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
#         im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
#         return im, r, (left, top)

#     def polygons2masks_overlap(self, imgsz, segments, downsample_ratio=1):
#         """参考官方逻辑，生成重叠 Mask"""
#         nh, nw = imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio
#         masks = np.zeros((nh, nw), dtype=np.uint8)
        
#         # 按照面积从大到小排序，确保小物体覆盖大物体（可选）
#         areas = [cv2.contourArea(s.astype(np.float32)) for s in segments]
#         sorted_idx = np.argsort(areas)[::-1] # 降序还是升序看具体需求，官方通常处理重叠是直接叠加
        
#         # 官方逻辑里，mask 值对应物体的 index (1, 2, 3...)
#         # 但 v8Loss 实际上需要 Binary Masks 或者特定格式
#         # 这里我们生成 (N, H, W) 的 Binary Masks 堆叠
        
#         masks_list = []
#         for i in sorted_idx:
#             seg = segments[i]
#             # 缩放到 mask 尺寸
#             seg = (seg / downsample_ratio).astype(np.int32)
#             m = np.zeros((nh, nw), dtype=np.uint8)
#             cv2.fillPoly(m, [seg], 1)
#             masks_list.append(m)
            
#         if len(masks_list) > 0:
#             return np.stack(masks_list, axis=0), sorted_idx
#         else:
#             return np.zeros((0, nh, nw)), []

#     def segment2bbox(self, segment):
#         # 简单的 xywh 计算
#         x = segment[:, 0]
#         y = segment[:, 1]
#         return np.array([(x.min() + x.max()) / 2, (y.min() + y.max()) / 2,
#                          x.max() - x.min(), y.max() - y.min()], dtype=np.float32)

#     @staticmethod
#     def collate_fn(batch):
#         """
#         对齐官方格式的关键！！
#         官方 Loss 需要: batch_idx, cls, bboxes, masks
#         """
#         img = torch.stack([b['img'] for b in batch], 0)
        
#         # 拼接列表
#         batch_idx = []
#         cls_list = []
#         bboxes_list = []
#         masks_list = []
        
#         for i, b in enumerate(batch):
#             n = len(b['cls'])
#             if n > 0:
#                 # 生成 batch_index: [0, 0, 0, 1, 1, 2...]
#                 batch_idx.append(torch.full((n, 1), i, dtype=torch.float32))
#                 cls_list.append(b['cls'])
#                 bboxes_list.append(b['bboxes'])
#                 masks_list.append(b['masks'])
        
#         if len(batch_idx) > 0:
#             return {
#                 'img': img,
#                 'batch_idx': torch.cat(batch_idx, 0),
#                 'cls': torch.cat(cls_list, 0),
#                 'bboxes': torch.cat(bboxes_list, 0),
#                 'masks': torch.cat(masks_list, 0)
#             }
#         else:
#             # 防止空 Batch 报错
#             return {
#                 'img': img,
#                 'batch_idx': torch.zeros((0, 1)),
#                 'cls': torch.zeros((0, 1)),
#                 'bboxes': torch.zeros((0, 4)),
#                 'masks': torch.zeros((0, 160, 160)) # 假设 640/4
#             }

class YOLODataset(Dataset):
    """YOLO-format instance-segmentation dataset with polygon labels.

    Images are read from ``img_dir`` (``*.jpg`` and ``*.png``). Labels are looked up
    at ``<img_dir>/../../labels/<img_dir name>/<image stem>.txt``. Each label line is
    ``cls x1 y1 x2 y2 ...`` with polygon coordinates normalized to [0, 1].

    Each item is a dict:
        img:       (3, S, S) float32 RGB tensor in [0, 1], letterboxed to S = img_size.
        cls:       (N, 1) float32 class ids.
        bboxes:    (N, 4) float32 normalized xywh relative to the letterboxed image.
        masks:     (N, S // r, S // r) uint8 per-instance masks (r = hyp['mask_ratio']).
                   When N == 0, it is instead a single all-zero (1, S // r, S // r)
                   float32 placeholder.
        batch_idx: 1-element tensor holding the dataset index of the returned sample.
    """

    # Default hyper-parameters (after the official default.yaml). A user-supplied
    # ``hyp`` overrides these key by key, so partial dicts are accepted.
    DEFAULT_HYP = {
        'hsv_h': 0.015, 'hsv_s': 0.7, 'hsv_v': 0.4,  # HSV augmentation gains
        'degrees': 0.0, 'translate': 0.1, 'scale': 0.5,  # geometric augmentation
        'fliplr': 0.5,  # probability of a left-right flip
        'mask_ratio': 4,  # mask downsampling factor
    }

    def __init__(self, img_dir, img_size=640, augment=True, hyp=None):
        """
        Args:
            img_dir: directory containing the images of one split (``.../images/<split>``).
            img_size: side length S of the square letterboxed output (int).
            augment: enable augmentation (True for training, False for validation).
            hyp: optional dict of hyper-parameters; missing keys fall back to DEFAULT_HYP.
        """
        self.img_dir = Path(img_dir)
        # images/<split> -> labels/<split>, two directory levels above the image folder.
        self.lbl_dir = self.img_dir.parent.parent / "labels" / self.img_dir.parts[-1]

        self.img_files = sorted(list(self.img_dir.glob("*.jpg")) + list(self.img_dir.glob("*.png")))
        self.img_size = img_size
        self.augment = augment
        self.hyp = {**self.DEFAULT_HYP, **(hyp or {})}

        # Parse every label file once up front so __getitem__ does no label IO.
        self.labels = self.cache_labels()
        print(f"Loaded {len(self.img_files)} images and labels.")

    # ------------------------------------------------------------------ labels
    @staticmethod
    def _parse_label_line(line):
        """Parse one ``cls x1 y1 x2 y2 ...`` label line.

        Returns ``(cls, segment)``, where ``segment`` is a (K, 2) float32 array of
        normalized points. Returns ``None`` for blank lines and for lines without at
        least one full (x, y) pair. An odd number of coordinates still raises
        ValueError from the reshape.
        """
        items = line.split()
        if len(items) < 3:
            return None
        cls = float(items[0])
        segment = np.array([float(v) for v in items[1:]], dtype=np.float32).reshape(-1, 2)
        return cls, segment

    def cache_labels(self):
        """Read every label file once and return one label dict per image.

        Each dict has ``cls`` (N, 1) float32, ``bboxes`` (N, 4) float32 normalized
        xywh, and ``segments`` (a list of N arrays of shape (K_i, 2), kept as a list
        because the point counts differ). Images without a label file get N = 0.
        """
        from tqdm import tqdm
        cache = []
        print("正在预加载所有标签，请耐心等待...")
        for img_path in tqdm(self.img_files):
            label_path = self.lbl_dir / (img_path.stem + ".txt")
            classes, bboxes, segments = [], [], []

            if label_path.exists():
                with open(label_path, 'r') as f:
                    for line in f:
                        parsed = self._parse_label_line(line)
                        if parsed is None:  # blank / malformed line
                            continue
                        cls, segment = parsed
                        classes.append(cls)
                        bboxes.append(self.segment2bbox(segment))
                        segments.append(segment)

            cache.append({
                'cls': np.array(classes, dtype=np.float32).reshape(-1, 1),
                'bboxes': np.array(bboxes, dtype=np.float32).reshape(-1, 4),
                'segments': segments,
            })
        return cache

    @staticmethod
    def _copy_label(label):
        """Return an independent copy of a cached label dict.

        The arrays are copied too, so in-place augmentation cannot modify the cache.
        """
        return {
            'cls': label['cls'].copy(),
            'bboxes': label['bboxes'].copy(),
            'segments': [seg.copy() for seg in label['segments']],
        }

    @staticmethod
    def _flip_label_lr(label):
        """Mirror normalized x coordinates (x -> 1 - x) of bboxes and segments, in place."""
        if len(label['bboxes']):
            label['bboxes'][:, 0] = 1 - label['bboxes'][:, 0]
            for seg in label['segments']:
                seg[:, 0] = 1 - seg[:, 0]

    # ------------------------------------------------------------------ dataset API
    def __len__(self):
        return len(self.img_files)

    def _read_image(self, index):
        """Return ``(img, index)`` for the first readable image at or after ``index``.

        The search wraps around the end of the dataset. Raises FileNotFoundError if
        no image in the dataset can be read, instead of recursing forever.
        """
        n = len(self)
        for _ in range(n):
            img = cv2.imread(str(self.img_files[index]))
            if img is not None:
                return img, index
            index = (index + 1) % n
        raise FileNotFoundError(f"No readable image found in {self.img_dir}")

    def __getitem__(self, index):
        """Load, augment, letterbox and encode one sample (see the class docstring)."""
        img, index = self._read_image(index)
        # Private copy: the flip and rescale below modify arrays in place and must
        # not leak into the cached labels reused every epoch.
        label = self._copy_label(self.labels[index])
        h0, w0 = img.shape[:2]

        # ---- augmentation ---------------------------------------------------
        if self.augment:
            # HSV jitter is implemented in self.augment_hsv(img) but currently disabled.
            if random.random() < self.hyp['fliplr']:
                # np.fliplr returns a negative-stride view; OpenCV needs contiguous memory.
                img = np.ascontiguousarray(np.fliplr(img))
                self._flip_label_lr(label)
            # Mosaic is not implemented.

        # ---- letterbox and box transform ------------------------------------
        img, ratio, (pad_x, pad_y) = self.letterbox(img, self.img_size)
        rx, ry = ratio if isinstance(ratio, tuple) else (ratio, ratio)

        if len(label['bboxes']) > 0:
            bboxes = label['bboxes']  # already a private copy
            # normalized -> original pixels -> resized pixels
            bboxes[:, [0, 2]] *= w0
            bboxes[:, [1, 3]] *= h0
            bboxes[:, [0, 2]] *= rx
            bboxes[:, [1, 3]] *= ry
            # Only the centre is shifted by the padding; width/height are not.
            bboxes[:, 0] += pad_x
            bboxes[:, 1] += pad_y
            # Back to [0, 1] relative to the square letterboxed image.
            bboxes /= self.img_size
            np.clip(bboxes, 0, 1, out=bboxes)

        # ---- masks from polygons --------------------------------------------
        mask_ratio = self.hyp['mask_ratio']
        segments = label['segments']
        if len(segments) > 0:
            size_xy = np.array([w0, h0])
            scale_xy = np.array([rx, ry])
            pad_xy = np.array([pad_x, pad_y])
            # normalized -> original pixels -> letterboxed pixels
            pixel_segments = [seg * size_xy * scale_xy + pad_xy for seg in segments]

            masks, sorted_idx = self.polygons2masks_overlap(
                (self.img_size, self.img_size),
                pixel_segments,
                downsample_ratio=mask_ratio,
            )
            # Keep cls / bboxes aligned with the order of the masks.
            label['cls'] = label['cls'][sorted_idx]
            label['bboxes'] = label['bboxes'][sorted_idx]
        else:
            # No objects: a single all-zero placeholder mask.
            mask_hw = self.img_size // mask_ratio
            masks = np.zeros((1, mask_hw, mask_hw), dtype=np.float32)

        # ---- image to tensor -------------------------------------------------
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0

        return {
            "img": img,
            "cls": torch.from_numpy(label['cls']),
            "bboxes": torch.from_numpy(label['bboxes']),
            "masks": torch.from_numpy(masks),
            "batch_idx": torch.tensor([index])
        }

    # ------------------------------------------------------------------ helpers
    def augment_hsv(self, img):
        """Randomly scale hue/saturation/value of a BGR uint8 image, writing the result back into ``img``."""
        h = self.hyp['hsv_h']
        s = self.hyp['hsv_s']
        v = self.hyp['hsv_v']
        r = np.random.uniform(-1, 1, 3) * [h, s, v] + 1
        hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
        dtype = img.dtype
        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
        im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
        cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img)

    def letterbox(self, im, new_shape=640, color=(114, 114, 114)):
        """Resize keeping the aspect ratio, then pad with ``color`` up to ``new_shape``.

        Returns:
            (padded image, scale factor r, (left, top)). ``left`` and ``top`` are the
            integer border sizes actually added, i.e. the pixel offset of the resized
            content inside the padded image.
        """
        shape = im.shape[:2]  # (h, w)
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # (w, h)
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
        dw, dh = dw / 2, dh / 2  # split padding between both sides
        if shape[::-1] != new_unpad:
            im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
        return im, r, (left, top)

    def polygons2masks_overlap(self, imgsz, segments, downsample_ratio=1):
        """Rasterize pixel-space polygons into per-instance binary masks.

        Masks are produced at ``imgsz // downsample_ratio`` resolution. Despite the
        name, the masks are NOT merged into one overlap map.

        Returns:
            ``(masks, sorted_idx)``. ``masks`` is a (N, nh, nw) uint8 stack ordered by
            descending polygon area; ``sorted_idx`` is the index array giving that
            order. For no segments, returns a (0, nh, nw) zeros array and ``[]``.
        """
        nh, nw = imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio
        areas = [cv2.contourArea(s.astype(np.float32)) for s in segments]
        sorted_idx = np.argsort(areas)[::-1]  # largest polygon first

        masks_list = []
        for i in sorted_idx:
            # Scale the polygon down to mask resolution (truncating to int pixels).
            seg = (segments[i] / downsample_ratio).astype(np.int32)
            m = np.zeros((nh, nw), dtype=np.uint8)
            cv2.fillPoly(m, [seg], 1)
            masks_list.append(m)

        if len(masks_list) > 0:
            return np.stack(masks_list, axis=0), sorted_idx
        else:
            return np.zeros((0, nh, nw)), []

    def segment2bbox(self, segment):
        """Return the axis-aligned xywh box (float32) enclosing a (K, 2) polygon."""
        x = segment[:, 0]
        y = segment[:, 1]
        return np.array([(x.min() + x.max()) / 2, (y.min() + y.max()) / 2,
                         x.max() - x.min(), y.max() - y.min()], dtype=np.float32)

    @staticmethod
    def collate_fn(batch):
        """Concatenate per-object targets of a batch and tag them with their image index.

        Samples with zero objects contribute only their image. When the whole batch
        is empty, correctly-shaped empty targets are returned, with masks at 1/4 of
        the image size.
        """
        img = torch.stack([b['img'] for b in batch], 0)
        batch_idx = []
        cls_list = []
        bboxes_list = []
        masks_list = []
        for i, b in enumerate(batch):
            n = len(b['cls'])
            if n > 0:
                batch_idx.append(torch.full((n, 1), i, dtype=torch.float32))
                cls_list.append(b['cls'])
                bboxes_list.append(b['bboxes'])
                masks_list.append(b['masks'])
        if len(batch_idx) > 0:
            return {
                'img': img,
                'batch_idx': torch.cat(batch_idx, 0),
                'cls': torch.cat(cls_list, 0),
                'bboxes': torch.cat(bboxes_list, 0),
                'masks': torch.cat(masks_list, 0)
            }
        else:
            h, w = img.shape[2] // 4, img.shape[3] // 4
            return {
                'img': img,
                'batch_idx': torch.zeros((0, 1)),
                'cls': torch.zeros((0, 1)),
                'bboxes': torch.zeros((0, 4)),
                'masks': torch.zeros((0, h, w))
            }
        
# def collate_fn(batch):
#     imgs = torch.stack([b["img"] for b in batch])

#     targets_list = []
#     masks_list = []
#     for i, b in enumerate(batch):
#         label_data = b['labels']
#         if label_data and len(label_data) == 4:
#             classes = label_data['classes']
#             masks = label_data['masks']
#             bboxes = label_data['bboxes']

#             n_objects = len(classes)
#             if n_objects > 0 :
#                 idx_col = torch.full((n_objects, 1), i, dtype=torch.float32)
#                 cls_col = classes.view(-1, 1).float()
#                 targets = torch.cat([idx_col, cls_col, bboxes], dim=1)
#                 targets_list.append(targets)
#                 masks_list.append(masks)

#     if targets_list:
#         targets_out = torch.cat(targets_list, dim=0)
#         masks_out = torch.cat(masks_list, dim=0)
#     else:
#         targets_out = torch.zeros((0, 6), dtype=torch.float32)
#         masks_out = torch.zeros((0, 160, 160), dtype=torch.float32)

#     return {
#         "img": imgs,
#         "labels": {
#             "targets_out": targets_out,
#             "masks_out": masks_out,
#         }
#     }
def collate_fn(batch):
    """Merge a list of per-image sample dicts into one flat batch dict.

    The per-object tensors (``cls``, ``bboxes``, ``masks``) of every sample that
    has at least one object are concatenated along dim 0. A float32 ``batch_idx``
    column of shape (N, 1) records which image in the batch each object came from.
    Samples without objects contribute only their image. If no sample has objects,
    empty targets are returned, with masks sized at 1/4 of the image height/width.
    """
    images = torch.stack([sample['img'] for sample in batch], 0)

    # (position in batch, sample) for every sample that actually holds objects.
    populated = [(pos, sample) for pos, sample in enumerate(batch) if len(sample['cls']) > 0]

    if not populated:
        # Guard against an all-empty batch: correctly shaped empty targets.
        mask_h = images.shape[2] // 4
        mask_w = images.shape[3] // 4
        return {
            'img': images,
            'batch_idx': torch.zeros((0, 1)),
            'cls': torch.zeros((0, 1)),
            'bboxes': torch.zeros((0, 4)),
            'masks': torch.zeros((0, mask_h, mask_w)),
        }

    owner_columns = [
        torch.full((len(sample['cls']), 1), pos, dtype=torch.float32)
        for pos, sample in populated
    ]
    return {
        'img': images,
        'batch_idx': torch.cat(owner_columns, 0),
        'cls': torch.cat([sample['cls'] for _, sample in populated], 0),
        'bboxes': torch.cat([sample['bboxes'] for _, sample in populated], 0),
        'masks': torch.cat([sample['masks'] for _, sample in populated], 0),
    }

# def criterion(preds, batch):
#     compute_loss = SegLoss(overlap=True)
#     return compute_loss(preds, batch)
def criterion(preds, batch):
    """Compute the segmentation loss of ``preds`` against ``batch``.

    Two batch layouts are accepted:

    * Flat layout, detected by the presence of a ``batch_idx`` key (as produced by
      ``collate_fn``). It is repacked into the nested layout:
      ``{'img', 'labels': {'targets_out', 'masks_out'}}``. ``targets_out`` is
      ``[batch_idx | cls | bboxes]`` concatenated column-wise, i.e. (N, 6) when
      cls is (N, 1) and bboxes is (N, 4) as collate_fn produces them.
    * Any other layout is passed to the loss unchanged.

    Returns whatever ``SegLoss`` returns; callers unpack it as ``loss, loss_items``.
    """
    loss_fn = SegLoss(overlap=True)

    if 'batch_idx' not in batch:
        # Legacy nested layout: forward untouched.
        return loss_fn(preds, batch)

    targets = torch.cat((batch['batch_idx'], batch['cls'], batch['bboxes']), dim=1)
    nested_batch = {
        'img': batch['img'],
        'labels': {
            'targets_out': targets,
            'masks_out': batch['masks'],
        },
    }
    return loss_fn(preds, nested_batch)



def validate_loss(model, data_loader, device):
    """Run ``model`` over ``data_loader`` without gradients and return the average loss.

    Only the loss is computed (via ``criterion``); no parameters are updated. Batches
    whose loss is inf or NaN are skipped. The model is switched to eval mode and left
    in it, so callers must call ``model.train()`` again before training.

    Args:
        model: network whose output ``criterion`` accepts.
        data_loader: iterable of batch dicts containing at least an ``"img"`` tensor.
            Every tensor value is moved to ``device``, replacing it in the dict.
        device: target device for the batch tensors.

    Returns:
        ``AverageMeter.avg`` over the accepted batches. Each batch is updated with
        weight ``image.size(0)``, presumably giving a per-sample mean; confirm
        against AverageMeter.
    """
    model.eval()
    avgloss = AverageMeter("Val Loss", ":1.5f")

    with torch.no_grad():  # no autograd graph: saves memory, no gradients recorded
        for batch in data_loader:
            # Move every tensor exactly once. "img" is included here, so the image
            # batch is not transferred to the device a second time.
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    batch[key] = value.to(device)
            image = batch["img"]

            preds = model(image)
            loss, _ = criterion(preds, batch)

            # .item() synchronizes with the device; read the value only once.
            loss_value = loss.item()
            if math.isfinite(loss_value):
                avgloss.update(loss_value, image.size(0))

    print(f"  >>> Validation Loss: {avgloss.avg:.5f}")
    return avgloss.avg