import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
import json
import cv2
from horizon_tc_ui import HB_ONNXRuntime


def softmax_2D(X):
    """
    针对二维numpy矩阵每一行进行softmax操作
    X: np.array. Probably should be floats.
    return: 二维矩阵
    """
    # looping through rows of X
    #   循环遍历X的行
    ps = np.empty(X.shape)
    for i in range(X.shape[0]):
        ps[i,:]  = np.exp(X[i,:])
        ps[i,:] /= np.sum(ps[i,:])
    return ps


def check_onnx(onnx_model, img, json_path, input_shape):
    ## --------------------------------------------#
    ##  opencv实现预处理方式
    ## --------------------------------------------#
    img = cv2.resize(np.array(img), input_shape, interpolation=cv2.INTER_CUBIC).astype(np.float32)
    # img /= 255.0        # 要在减均值，除方差之前

    # 网络训练输入一般是RGB的图片，故在此也转一下
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img -= [124.16, 116.28, 103.53]
    img *= [0.0171248, 0.0175070, 0.0174292]

    # 原本的img是HWC格式，quantized_onnx_model输入格式NHWC
    #   optimized_float_onnx_model输入格式NCHW
    
    if "optimized" in onnx_model:
        img = img.transpose(2, 0, 1)    # 从HWC，变为CHW
    # 添加batch维度
    img = np.expand_dims(img, 0)


    # -----------------------------------#
    #   class_indict用于可视化类别
    # -----------------------------------#
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # 加载onnx模型
    ort_session = HB_ONNXRuntime(model_file=onnx_model)
    # 这两行有啥用？我注释掉也没什么影响
    #ort_session_2.set_dim_param(0, 0, '?')     
    #ort_session_2.set_providers(['CPUExecutionProvider'])
    
    # -----------------------------------#
    # onnx模型推理
    # 初始化数据，注意此时img是numpy格式
    # -----------------------------------#
    input_name = ort_session.input_names[0]
    ort_outs = ort_session.run(None, {input_name: img})        # 推理得到输出
    # print(ort_outs)     # [array([[-4.290639  , -2.267056  ,  7.666328  , -1.4162455 ,  0.57391334]], dtype=float32)]
    
    # -----------------------------------#
    # 经过softmax转化为概率
    # softmax_2D按行转化，一行一个样本
    #   测试时，softmax可不要！
    # -----------------------------------#
    predict_probability = softmax_2D(ort_outs[0])        
    # print(predict_probability)  # array([[0.1],[0.2],[0.3],[0.3],[0.1]])           
    
    # -----------------------------------#
    # argmax得到最大概率索引，也就是类别对应索引
    # -----------------------------------#
    predict_cla = np.argmax(predict_probability, axis=-1)
    # print(predict_cla)        # array([2])

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla[0])],
                                                 predict_probability[0][predict_cla[0]])
    print(print_res)
    plt.title(print_res)
    for i in range(len(predict_probability[0])):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict_probability[0][i]))
    plt.savefig("./result.jpg")
    plt.show()

if __name__ == '__main__':
    # 注意，quantized_onnx_model的输入格式为：NHWC
    #   optimized_float_onnx_model的输入格式为：NCHW
    #   这儿的命名一定要以quantized和optimized进行区分，后面的代码中有用到
    quantized_onnx_model = './model_output/resnet34_224x224_rgb_quantized_model.onnx'
    optimized_float_onnx_model = './model_output/resnet34_224x224_rgb_optimized_float_model.onnx'

    input_shape = (224, 224)
    img_path = "./data/rose111.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    # 下面两行，用它显示图片而已
    img = Image.open(img_path)
    plt.imshow(img)     

    img = cv2.imread(img_path)
    # read class_indict
    json_path = './class_indices.json'

    check_onnx(quantized_onnx_model, img, json_path, input_shape)
    print("onnx model check finsh.")
