import torch
import argparse
import os
import sys

# ==============================================================================
# 【步骤 1】导入你的模型定义
# ==============================================================================
# 假设你原来的训练脚本文件名叫 qat_train.py (或者 fx_mode_yolov8.py)
# 请将下面的 'qat_train' 修改为你实际的文件名（注意不要带 .py 后缀）
try:
    # 修正：分别从不同的包导入 QuantStub 和 DeQuantStub
    from horizon_plugin_pytorch.quantization import QuantStub
    from torch.quantization import DeQuantStub
    
    # 导入你的模型类 (假设你的训练脚本叫 fx_mode_yolov8.py)
    # 如果你的文件名不是 fx_mode_yolov8，请修改这里！
    from fx_mode_yolov8 import YOLOv8Model
except ImportError as e:
    print(f"导入失败: {e}")
    print("请打开本脚本(export_float_onnx.py)，将 'from qat_train import YOLOv8Model'")
    print("修改为你存放模型定义的实际文件名。")
    sys.exit(1)

# ==============================================================================
# 【步骤 2】导出逻辑
# ==============================================================================
def export_float(args):
    print("=" * 50)
    print(f"准备导出浮点模型权重: {args.ckpt_path}")

    # 导出时建议使用 CPU，避免不必要的显存问题
    device = torch.device("cpu")

    # 1. 实例化模型
    # nc=9 是根据你提供的代码填写的，如果变了请修改
    print("实例化模型结构...")
    model = YOLOv8Model(nc=9).to(device)

    # 2. 加载权重 (float-checkpoint.ckpt)
    if not os.path.exists(args.ckpt_path):
        print(f"错误: 找不到权重文件 {args.ckpt_path}")
        return

    print("加载权重中...")
    ckpt = torch.load(args.ckpt_path, map_location=device)

    # 3. 权重清洗 (处理 state_dict 结构)
    # 检查权重是否包裹在 'state_dict' 或 'model' 键中
    if isinstance(ckpt, dict):
        if "state_dict" in ckpt:
            state_dict = ckpt["state_dict"]
        elif "model" in ckpt:
            state_dict = ckpt["model"]
        else:
            state_dict = ckpt
    else:
        # 假如直接保存的是 model 对象
        state_dict = ckpt.state_dict()

    # 处理 DDP 训练产生的 'module.' 前缀
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith("module."):
            new_state_dict[k[7:]] = v
        else:
            new_state_dict[k] = v

    # 加载权重到模型
    # strict=False 可以容忍一些非关键参数的不匹配（如 quant scale 初始值等）
    try:
        model.load_state_dict(new_state_dict, strict=True)
        print("权重加载成功 (strict=True)")
    except Exception as e:
        print(f"权重加载有轻微不匹配，尝试 strict=False: {e}")
        model.load_state_dict(new_state_dict, strict=False)
        print("权重加载成功 (strict=False)")

    model.eval()

    # 4. 准备 Dummy Input
    # 假设输入是 640x640，根据你的 args.imgsz
    dummy_input = torch.randn(1, 3, args.imgsz, args.imgsz).to(device)

    # 5. 定义输出路径
    output_path = args.output_path
    if not output_path:
        # 默认保存在 ckpt 同级目录下
        output_path = os.path.join(os.path.dirname(args.ckpt_path), "yolov8_float.onnx")

    print(f"正在导出 ONNX 到: {output_path}")

    # 定义输出节点名称
    # 你的 forward 返回: (list_of_tensors, qtensor_2_dequant, qtensor_3_dequant)
    # 对应: [Small_Head_List...], Medium_Head, Large_Head
    # 具体有多少个输出取决于 list_of_tensors 里有几个 tensor
    
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=["images"],
        # output_names=["output"], # 建议先注释掉，让它自动生成，后续用 Netron 查看
        opset_version=11,
        do_constant_folding=True,
    )

    print("导出完成！")
    print("=" * 50)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    
    # 权重路径
    parser.add_argument("--ckpt_path", type=str, 
                        default="/data6/liuziyi/yolov8_qat/qat_project/yolov8_multi_task/qat-checkpoint.ckpt",
                        help="训练生成的 float-checkpoint.ckpt 路径")
    
    # 输出路径 (可选)
    parser.add_argument("--output_path", type=str, default="/data6/liuziyi/yolov8_qat/qat_project/yolov8_multi_task/qat1202.onnx", help="导出的 onnx 保存路径")
    
    # 图片大小
    parser.add_argument("--imgsz", type=int, default=640)

    args = parser.parse_args()
    
    export_float(args)