# import argparse
import os

import copy
import torch
import torch.nn as nn
import torch.quantization
from torch import Tensor
from torch.quantization import DeQuantStub
from torchvision.models.mobilenetv2 import (
    InvertedResidual,
    MobileNetV2,
)

from torchvision.ops.misc import ConvNormActivation
from horizon_plugin_pytorch.march import March, set_march
from horizon_plugin_pytorch.quantization import (
    QuantStub,
    convert,
    fuse_known_modules,
    prepare_qat,
)
from horizon_plugin_pytorch.quantization.qconfig import (
    default_calib_8bit_fake_quant_qconfig,
    default_qat_8bit_fake_quant_qconfig,
    default_qat_8bit_weight_32bit_out_fake_quant_qconfig,
    default_calib_8bit_weight_32bit_out_fake_quant_qconfig,
)
from typing import Optional, Callable, List
from horizon_plugin_pytorch.nn.quantized import FloatFunctional
from common import *

# Specify random seed for repeatable results.
# NOTE: this is a module-level statement, so it runs (and reseeds torch's
# random number generator) whenever this module is imported or executed.
torch.manual_seed(191009)


##############################################################################
# First, we make the necessary modifications to the MobileNetV2 model from
# torchvision:
# 1. Insert a QuantStub before the first layer and a DeQuantStub after the
#    last layer.
# 2. Replace unsupported torch functional ops with their nn.Module
#    counterparts. In this case, we
#    a. replace the plus sign with `FloatFunctional().add`.
#    b. replace `nn.functional.adaptive_avg_pool2d`
#       with `nn.AdaptiveAvgPool2d`.
# 3. Manually define the ops to be fused (by Module name).
#    All available fuse patterns can be accessed by
#    `horizon_plugin_pytorch.quantization.fuse_modules.get_op_list_to_fuser_mapping()`.
#    Note: Users should fuse as many ops as possible, or model accuracy and
#          execution speed will be affected.
#
##############################################################################


class EagerQATReadyInvertedResidual(InvertedResidual):
    """InvertedResidual block adapted for eager-mode QAT.

    Differences from the parent class, as implemented here:

    * when the block uses a residual connection, the addition is performed
      by a ``FloatFunctional`` submodule (``skip_add``) instead of ``+``;
    * ``fuse_model`` fuses every plain ``nn.Conv2d`` found in ``self.conv``
      with the module that follows it (and with ``skip_add`` when the
      residual connection is used).
    """

    def __init__(
        self,
        inp: int,
        oup: int,
        stride: int,
        expand_ratio: int,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__(inp, oup, stride, expand_ratio, norm_layer)

        # The FloatFunctional has to be registered as a submodule,
        # otherwise the quantization state will not be handled correctly.
        if self.use_res_connect:
            self.skip_add = FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Run the conv branch and add the input back when residual is on."""
        branch_out = self.conv(x)
        if not self.use_res_connect:
            return branch_out
        return self.skip_add.add(branch_out, x)

    def fuse_model(self):
        """Fuse each bare ``nn.Conv2d`` in ``self.conv`` with its successor.

        Only modules whose type is exactly ``nn.Conv2d`` are considered;
        fusion is done in place through ``fuse_known_modules``.
        """
        for position, layer in enumerate(self.conv):
            if type(layer) is not nn.Conv2d:
                continue

            conv_name = str(position)
            next_name = str(position + 1)
            if self.use_res_connect:
                # Fuse conv+bn+add; names are relative to this block.
                root = self
                names = ["conv." + conv_name, "conv." + next_name, "skip_add"]
            else:
                # Fuse conv+bn; names are relative to `self.conv`.
                root = self.conv
                names = [conv_name, next_name]

            torch.quantization.fuse_modules(
                root,
                names,
                inplace=True,
                fuser_func=fuse_known_modules,
            )


class EagerQATReadyMobileNetV2(MobileNetV2):
    """MobileNetV2 adapted for eager-mode QAT.

    The network is built from ``EagerQATReadyInvertedResidual`` blocks, the
    input goes through a ``QuantStub`` and the output through a
    ``DeQuantStub``, and global pooling is done by an ``nn.AdaptiveAvgPool2d``
    module.
    """

    def __init__(
        self,
        num_classes: int = 10,
        width_mult: float = 0.5,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
    ):
        super().__init__(
            num_classes,
            width_mult,
            inverted_residual_setting,
            round_nearest,
            block=EagerQATReadyInvertedResidual,
        )
        # Horizon QuantStub supports a user-specified scale.
        # The `input_source` param of `compile_model` can be set to "pyramid"
        # only if the input scale is equal to 1/128.
        input_scale = 1 / 128
        self.quant = QuantStub(scale=input_scale)
        self.dequant = DeQuantStub()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x: Tensor) -> Tensor:
        """Quantize, extract features, pool, classify and dequantize."""
        features = self.features(self.quant(x))
        # torch.flatten is supported, so it can be called directly.
        pooled = torch.flatten(self.avg_pool(features), 1)
        logits = self.classifier(pooled)
        return self.dequant(logits)

    def fuse_model(self):
        """Fuse ops in place across all submodules.

        ``ConvNormActivation`` modules get their children "0", "1" and "2"
        fused; ``EagerQATReadyInvertedResidual`` blocks fuse themselves via
        their own ``fuse_model``.
        """
        fuse_options = dict(inplace=True, fuser_func=fuse_known_modules)
        for module in self.modules():
            if isinstance(module, ConvNormActivation):
                # Fuse conv+bn+relu
                torch.quantization.fuse_modules(
                    module, ["0", "1", "2"], **fuse_options
                )
            if type(module) is EagerQATReadyInvertedResidual:
                module.fuse_model()


def _fuse_and_prepare(
    model: nn.Module,
    qconfig,
    out_qconfig,
    device: torch.device,
) -> nn.Module:
    """Fuse ops, attach qconfigs and return the fake-quantized model.

    Args:
        model: Float model exposing `fuse_model()` and a `classifier`
            submodule. It is modified in place before being prepared.
        qconfig: Qconfig assigned to the whole model.
        out_qconfig: Qconfig assigned to `model.classifier` only.
        device: Device the prepared model is moved to.
    """
    # Manually do op fusion.
    model.fuse_model()
    model.qconfig = qconfig
    # If the last layer is Linear* or Conv*, we can config its output
    # to int32 to get better accuracy.
    model.classifier.qconfig = out_qconfig
    # Make sure the output model is on the target device.
    # CAUTION: prepare_qat* and convert* do not guarantee the
    # output model is on the same device as the input model.
    return prepare_qat(model).to(device)


def get_model_eager(
    stage: str,
    model_path: str,
    device: torch.device,
    march=March.BAYES,
) -> nn.Module:
    """Build the model that corresponds to one stage of the QAT workflow.

    Args:
        stage: One of "float", "calib", "qat", "int_infer" or "compile".
        model_path: Directory holding the checkpoints
            ("float-checkpoint.ckpt", "calib-checkpoint.ckpt",
            "qat-checkpoint.ckpt"). In the "float" stage it is passed
            to `load_pretrain` as is.
        device: Device the returned model is placed on.
        march: Target hardware version handed to `set_march`.

    Returns:
        * "float": a plain `MobileNetV2` after `load_pretrain`.
        * "calib": the calibration model, initialised from the float
          checkpoint.
        * "qat": the QAT model, initialised from the calib checkpoint when
          it exists, otherwise from the float checkpoint.
        * "int_infer" / "compile": the converted (quantized) model, built
          from the qat checkpoint when it exists, otherwise from the calib
          checkpoint.

    Raises:
        ValueError: If `stage` is not one of the supported stages.
        FileNotFoundError: If the float checkpoint is missing (all stages
            except "float"), or if neither a qat nor a calib checkpoint
            exists for "int_infer" / "compile".
    """
    valid_stages = ("float", "calib", "qat", "int_infer", "compile")
    # Raise explicitly instead of asserting: asserts are stripped under
    # `python -O`, which would let an unknown stage fall through to the
    # int-infer path.
    if stage not in valid_stages:
        raise ValueError(
            "Unknown stage {!r}, expected one of {}".format(
                stage, valid_stages
            )
        )

    model_kwargs = dict(num_classes=10, width_mult=1.0)

    # NOTE(review): this instance is discarded in the "float" stage. It is
    # still built first so that object construction order (and therefore
    # any seeded random initialisation) stays as it was; confirm before
    # reordering.
    float_model = EagerQATReadyMobileNetV2(**model_kwargs).to(device)

    if stage == "float":
        # We could also use the original MobileNetV2 model for float
        # training, because the modified QAT-ready model can load its
        # params seamlessly. (Swapping the model here is optional.)
        float_model = MobileNetV2(**model_kwargs).to(device)

        # Load pretrained model (on ImageNet) to speed up float training.
        load_pretrain(float_model, model_path)

        return float_model

    float_ckpt_path = os.path.join(model_path, "float-checkpoint.ckpt")
    if not os.path.exists(float_ckpt_path):
        raise FileNotFoundError(
            "Float checkpoint not found: {}".format(float_ckpt_path)
        )
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    float_state_dict = torch.load(float_ckpt_path, map_location=device)

    # A global march indicating the target hardware version must be set
    # before prepare qat.
    set_march(march)

    # Preserve a clean float_model for calibration and qat training.
    ori_float_model = float_model
    float_model = copy.deepcopy(ori_float_model)

    float_model.load_state_dict(float_state_dict)
    calib_model = _fuse_and_prepare(
        float_model,
        default_calib_8bit_fake_quant_qconfig,
        default_calib_8bit_weight_32bit_out_fake_quant_qconfig,
        device,
    )

    if stage == "calib":
        return calib_model

    calib_ckpt_path = os.path.join(model_path, "calib-checkpoint.ckpt")
    if os.path.exists(calib_ckpt_path):
        calib_state_dict = torch.load(calib_ckpt_path, map_location=device)
        # The calib checkpoint takes precedence over the float one.
        float_state_dict = None
    else:
        calib_state_dict = None

    float_model = copy.deepcopy(ori_float_model)

    # QAT model can be generated either from float_model
    # or from calib_model.
    # Pay attention to the calling order of `load_state_dict`
    # and `prepare_qat`: float weights are loaded before preparation,
    # calibration weights after it.
    if float_state_dict is not None:
        float_model.load_state_dict(float_state_dict)
    qat_model = _fuse_and_prepare(
        float_model,
        default_qat_8bit_fake_quant_qconfig,
        default_qat_8bit_weight_32bit_out_fake_quant_qconfig,
        device,
    )
    if calib_state_dict is not None:
        qat_model.load_state_dict(calib_state_dict)

    if stage == "qat":
        return qat_model

    # int_infer and compile
    qat_ckpt_path = os.path.join(model_path, "qat-checkpoint.ckpt")
    if os.path.exists(qat_ckpt_path):
        qat_model.load_state_dict(
            torch.load(qat_ckpt_path, map_location=device)
        )
    elif os.path.exists(calib_ckpt_path):
        calib_model.load_state_dict(
            torch.load(calib_ckpt_path, map_location=device)
        )
        # The qat_model and calib_model both can be converted to
        # quantized model directly.
        # So we can do either calibration or qat training or both.
        qat_model = calib_model
    else:
        raise FileNotFoundError(
            "Do not find saved calib_model or qat_model ckpt "
            "to do int inference."
        )

    quantized_model = convert(qat_model).to(device)

    return quantized_model



if __name__ == "__main__":
    # `get_args` and `main` are not defined in this file; presumably they
    # come from the star-import of `common`.
    args = get_args()
    # A negative device id selects the CPU, otherwise the given CUDA device.
    device = torch.device(
        "cuda:{}".format(args.device_id) if args.device_id >= 0 else "cpu"
    )
    # Forward the requested march so the model is prepared for the same
    # target that `main` receives, instead of always using the
    # `March.BAYES` default of `get_model_eager`.
    model = get_model_eager(
        args.stage, args.model_path, device, march=args.march
    )
    main(
        model,
        args.stage,
        args.data_path,
        args.model_path,
        args.train_batch_size,
        args.eval_batch_size,
        args.epoch_num,
        args.device_id,
        march=args.march,
        compile_opt=args.opt,
    )
