# 文件名: check_env.py
import torch
import os
import sys

# 尝试导入地平线工具
try:
    from horizon_plugin_pytorch.quantization import check_model
    print("✅ 成功导入 horizon_plugin_pytorch")
except ImportError:
    print("❌ 错误: 无法导入 horizon_plugin_pytorch。请检查环境是否激活 (conda activate qat_env)")
    sys.exit(1)

# 导入您的模型 Wrapper
try:
    from qat_wrapper import YOLOv8_Seg_QAT_Wrapper
except ImportError:
    print("❌ 错误: 找不到 qat_wrapper.py。请确保它在当前目录下。")
    sys.exit(1)

def main():
    # --- 配置 ---
    # 您的权重路径
    weight_path = "/data6/liuziyi/yolov8_qat/qat_project/best.pt"
    
    print(f"\n=== 1. 检查权重文件 ===")
    if os.path.exists(weight_path):
        print(f"✅ 找到权重文件: {weight_path}")
    else:
        print(f"❌ 错误: 找不到文件 {weight_path}")
        print("请确认路径是否正确。")
        return

    # --- 实例化 ---
    print("\n=== 2. 实例化 QAT Wrapper ===")
    try:
        model = YOLOv8_Seg_QAT_Wrapper(weight_path)
        print("✅ 模型实例化成功！")
        # 将模型设为 eval 模式
        model.eval()
    except Exception as e:
        print(f"❌ 模型实例化失败: {e}")
        import traceback
        traceback.print_exc()
        return

    # --- 检查结构 ---
    print("\n=== 3. 检查量化插桩 (Stubs) ===")
    has_quant = hasattr(model, 'quant')
    has_dequant = hasattr(model, 'dequant_detect') and hasattr(model, 'dequant_seg')
    
    if has_quant and has_dequant:
        print("✅ QuantStub (入口) 和 DeQuantStub (出口) 已检测到。")
    else:
        print(f"❌ 缺少 Stub! Input: {has_quant}, Output: {has_dequant}")

    # --- 前向传播测试 ---
    print("\n=== 4. 测试前向传播 (Forward) ===")
    try:
        fake_input = torch.randn(1, 3, 640, 640)
        print("正在运行 forward pass ...")
        
        # 运行模型
        out0, out1 = model(fake_input)
        
        print(f"✅ 前向传播成功！")
        print(f"   输出 1 (Detect) 尺寸: {out0.shape}")
        
        # --- 🔍 调试输出 2 (Segment) ---
        print(f"\n🔍 正在分析输出 2 (out1)...")
        print(f"   类型: {type(out1)}")
        
        if isinstance(out1, tuple) or isinstance(out1, list):
            print(f"   长度: {len(out1)}")
            for i, item in enumerate(out1):
                print(f"   -> 第 {i} 项类型: {type(item)}")
                if hasattr(item, 'shape'):
                    print(f"      Shape: {item.shape}")
                elif isinstance(item, list):
                     print(f"      是一个 List, 长度: {len(item)}")
                     if len(item) > 0 and hasattr(item[0], 'shape'):
                         print(f"      List[0] Shape: {item[0].shape}")
        elif hasattr(out1, 'shape'):
            print(f"   Shape: {out1.shape}")
        # ---------------------------------

    except Exception as e:
        print(f"❌ 前向传播失败: {e}")
        import traceback
        traceback.print_exc()

    print("\n" + "="*40)
    print("总结: 如果以上步骤都是 ✅，则环境配置完美！")
    print("="*40)

if __name__ == "__main__":
    main()