import argparse
import os

import copy
import torch
import torch.nn as nn
import torch.quantization
import torchvision.transforms as transforms
from torch import Tensor
from torch.quantization import DeQuantStub
from torchvision.models.mobilenetv2 import (
    MobileNetV2,
)
from horizon_plugin_pytorch.march import March, set_march
from horizon_plugin_pytorch.quantization import (
    QuantStub,
    convert_fx,
    prepare_qat_fx,
)
from horizon_plugin_pytorch.quantization.qconfig import (
    default_calib_8bit_fake_quant_qconfig,
    default_qat_8bit_fake_quant_qconfig,
    default_qat_8bit_weight_32bit_out_fake_quant_qconfig,
    default_calib_8bit_weight_32bit_out_fake_quant_qconfig,
)
from typing import Optional, Callable, List, Tuple
from horizon_plugin_pytorch.nn.quantized import FloatFunctional
from common import *

# Specify random seed for repeatable results
torch.manual_seed(191009)


##############################################################################
# First, we make the necessary modifications to the MobileNetV2 model from
# torchvision:
# 1. Insert a QuantStub before the first layer and a DeQuantStub after the
#    last layer.
# Operation replacement and fusion will be carried out automatically (^_^).
##############################################################################


class FxQATReadyMobileNetV2(MobileNetV2):
    def __init__(
        self,
        num_classes: int = 10,
        width_mult: float = 0.5,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
    ):
        super().__init__(
            num_classes, width_mult, inverted_residual_setting, round_nearest
        )
        self.quant = QuantStub(scale=1 / 128)
        self.dequant = DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        x = super().forward(x)
        x = self.dequant(x)

        return x

def get_model_fx(
    stage: str,
    model_path: str,
    device: torch.device,
    march=March.BAYES,
) -> nn.Module:
    assert stage in ("float", "calib", "qat", "int_infer", "compile")
    model_kwargs = dict(num_classes=10, width_mult=1.0)


    float_model = FxQATReadyMobileNetV2(**model_kwargs).to(device)

    if stage == "float":
        # We also could use the origin MobileNetV2 model for float training,
        # because modified QAT ready model can load its params seamlessly.
        float_model = MobileNetV2(**model_kwargs).to(
            device
        )  # these lines are optional

        # Load pretrained model (on ImageNet) to speed up float training.
        load_pretrain(float_model, model_path)

        return float_model

    float_ckpt_path = os.path.join(model_path, "float-checkpoint.ckpt")
    assert os.path.exists(float_ckpt_path)
    float_state_dict = torch.load(float_ckpt_path, map_location=device)

    # A global march indicating the target hardware version must be setted
    # before prepare qat.
    set_march(march)

    # Preserve a clean float_model for calibration and qat training.
    ori_float_model = float_model
    float_model = copy.deepcopy(ori_float_model)

    float_model.load_state_dict(float_state_dict)
    # The op fusion is included in `prepare_qat_fx`.
    # We can eigher pass qconfig as dict in prepare_qat_fx, or set
    # module's qconfig attr in the same manner as eager mode.
    calib_model = prepare_qat_fx(
        float_model,
        {
            "": default_calib_8bit_fake_quant_qconfig,
            "module_name": {
                "classifier": default_calib_8bit_weight_32bit_out_fake_quant_qconfig,
            },
        },
    ).to(device)

    if stage == "calib":
        return calib_model

    calib_ckpt_path = os.path.join(model_path, "calib-checkpoint.ckpt")
    if os.path.exists(calib_ckpt_path):
        calib_state_dict = torch.load(calib_ckpt_path, map_location=device)
        float_state_dict = None
    else:
        calib_state_dict = None

    float_model = copy.deepcopy(ori_float_model)

    if float_state_dict is not None:
        float_model.load_state_dict(float_state_dict)
    qat_model = prepare_qat_fx(
        float_model,
        {
            "": default_qat_8bit_fake_quant_qconfig,
            "module_name": {
                "classifier": default_qat_8bit_weight_32bit_out_fake_quant_qconfig,
            },
        },
    ).to(device)
    if calib_state_dict is not None:
        qat_model.load_state_dict(calib_state_dict)

    if stage == "qat":
        return qat_model

    qat_ckpt_path = os.path.join(model_path, "qat-checkpoint.ckpt")
    if os.path.exists(qat_ckpt_path):
        qat_model.load_state_dict(
            torch.load(qat_ckpt_path, map_location=device)
        )
    elif os.path.exists(calib_ckpt_path):
        calib_model.load_state_dict(
            torch.load(calib_ckpt_path, map_location=device)
        )
        # The qat_model and calib_model both can be converted to
        # quantized model directly.
        # So we can do either calibration or qat training or both.
        qat_model = calib_model
    else:
        raise FileNotFoundError(
            "Do not find saved calib_model or qat_model ckpt "
            "to do int inference."
        )

    quantized_model = convert_fx(qat_model).to(device)

    return quantized_model



if __name__ == "__main__":
    # Script entry point: parse options, build the stage-specific model and
    # hand everything over to `main`.
    args = get_args()

    # A negative device id selects the CPU; otherwise use that CUDA device.
    if args.device_id >= 0:
        device = torch.device("cuda:{}".format(args.device_id))
    else:
        device = torch.device("cpu")

    # NOTE(review): `march` is not forwarded here, so get_model_fx uses its
    # default (March.BAYES) while `main` below receives args.march — confirm
    # the two are meant to agree.
    model = get_model_fx(args.stage, args.model_path, device)

    main(
        model,
        args.stage,
        args.data_path,
        args.model_path,
        args.train_batch_size,
        args.eval_batch_size,
        args.epoch_num,
        args.device_id,
        march=args.march,
        compile_opt=args.opt,
    )
