import mxnet as mx
import numpy as np
from typing import Any, Tuple, Union
try:
    from .utilities import change_image_type, define_network, get_mxnet_context, load_parameters
except ImportError:
    from utilities import change_image_type, define_network, get_mxnet_context, load_parameters

class Model:
    """Gluon classifier wrapper that maps network outputs to labels.

    The architecture comes from ``define_network`` and the trained weights are
    loaded from a Gluon parameter file. ``predict`` returns the best-scoring
    label together with its softmax probability.
    """

    # Default context for models created from now on. It is evaluated once, at
    # class-definition time, via ``get_mxnet_context``; ``use_cpu`` overrides it.
    ctx = get_mxnet_context()

    def __init__(self, parameters: Union[dict, str], model_file_path: str):
        """Build the network and load its trained weights.

        :param parameters: either the parameter dict itself, or a path that
            ``load_parameters`` turns into one. It must contain the keys
            ``'is_convolutional'`` and ``'outputs'``; ``'outputs'`` must be a
            mapping whose values are the output labels.
        :param model_file_path: path of the saved Gluon parameters, passed to
            ``net.load_parameters``.
        """
        if isinstance(parameters, str):
            parameters = load_parameters(parameters)
        self.__is_convolutional = parameters['is_convolutional']
        # The network's i-th output is reported as the i-th label in sorted
        # order (see ``predict``), so the sort order here defines the mapping.
        self.__outputs = sorted(parameters['outputs'].values())
        # Pin the context the weights are loaded on, so a later
        # ``Model.use_cpu()`` cannot make ``predict`` send input data to a
        # different device than the parameters.
        self.__ctx = Model.ctx
        net = define_network(self.__is_convolutional, len(self.__outputs))
        net.load_parameters(model_file_path, ctx=self.__ctx)
        self.__net = net

    def __str__(self) -> str:
        """Return the string representation of the underlying network."""
        return str(self.__net)

    def predict(self, image: np.ndarray) -> Tuple[Any, float]:
        """Classify a single image.

        :param image: one image. It is passed through ``change_image_type``
            and, for convolutional networks, its last axis is moved to the
            front (presumably HWC -> CHW; confirm against ``define_network``).
        :return: ``(label, probability)``, i.e. the label with the highest
            network score and its softmax probability as a Python float.
        """
        image = change_image_type(image)
        if self.__is_convolutional:
            image = np.moveaxis(image, -1, 0)
        # Wrapping in a list adds a batch dimension of size 1.
        data = mx.ndarray.array([image]).as_in_context(self.__ctx)
        output = self.__net(data)[0]
        best_option = int(np.argmax(output.asnumpy()))
        # Index the host-side probability vector directly so the result is a
        # true scalar; calling float() on a 1-element array is deprecated in
        # NumPy >= 1.25.
        probabilities = mx.nd.softmax(output).asnumpy()
        return self.__outputs[best_option], float(probabilities[best_option])

    @staticmethod
    def use_cpu() -> None:
        """Make models created after this call load their weights on the CPU.

        Existing instances keep the context they were constructed with.
        """
        Model.ctx = mx.cpu()

if __name__ == '__main__':
    # Usage: <script> PARAMETERS MODEL_FILE IMAGE
    # Prints the predicted label followed by its probability.
    import cv2
    import sys
    if len(sys.argv) < 4:
        sys.exit(f'usage: {sys.argv[0]} PARAMETERS MODEL_FILE IMAGE')
    model = Model(sys.argv[1], sys.argv[2])
    image = cv2.imread(sys.argv[3])
    # cv2.imread reports failure by returning None instead of raising.
    if image is None:
        sys.exit(f'could not read image: {sys.argv[3]}')
    result = model.predict(image)
    print(*result)
