import json
import logging
import traceback

from google.protobuf import json_format

from sandbox import sdk2
from sandbox.common import errors as sb_errors
from sandbox.common import patterns as sb_patterns
from sandbox.common.types import client as ctc
from sandbox.common.types import misc as ctm

from tasklet.api import tasklet_pb2
from tasklet.domain import oauth
from tasklet.runtime import context as rt_context
from tasklet.runtime import dispatch
from tasklet.runtime import utils
from tasklet.runtime.python import base as py_base
from tasklet.runtime.utils import convert


class BaseTasklet(sdk2.Task):
    """ Base task class for classes generated from tasklets.

    Generated code binds a concrete tasklet by overriding ``__holder_cls__``.
    Lifecycle handled here:

    * ``on_create`` fills in the tasklet id and run id when they are unset;
    * ``on_save`` builds a ``JobStatement``, passes it to
      ``dispatch.get_init_description`` and applies the returned input and
      requirements to this task;
    * ``on_execute`` builds a ``JobInstance``, runs it with ``dispatch.dispatch``
      and publishes the ``JobResult`` / unpacked output into output parameters,
      raising ``TaskFailure`` if the job was not successful.
    """

    @sb_patterns.classproperty
    def __holder_cls__(cls):
        """
        Tasklet holder class. Parameter is overridden in tasklet's generated code.
        :rtype: tasklet.runtime.python.base.TaskletHolder
        """
        return None

    class Parameters(sdk2.Parameters):
        # Generated in on_create() when left unset; not copied when the task is cloned.
        __tasklet_id__ = sdk2.parameters.String("Tasklet id", ui=None, do_not_copy=True)
        __run_id__ = sdk2.parameters.String("Run id", ui=None, do_not_copy=True)
        # When set, overrides the implementation lookup in _tasklet_impl_name.
        __tasklet_name__ = sdk2.parameters.String(
            "Tasklet name",
            description="Parameter to run not default tasklet implementation",
            ui=None,
        )

        with sdk2.parameters.Group("Tasklet runtime parameters"):
            __tasklet_secret__ = sdk2.parameters.YavSecret(
                "Secret with token (default key: {}) to get access to tasklet services.".format(
                    oauth.YAV_SECRET_DEFAULT_KEY_NAME
                ),
                description='It could be received via <a href="{}">oauth</a>.'.format(oauth.url_to_get_token()),
            )

        # Besides the raw JSON input, any other parameter placed in this group
        # (e.g. by generated subclasses) is treated as a per-field input
        # parameter by _get_tasklet_input / _set_input.
        with sdk2.parameters.Group("Tasklet input parameters") as TaskletInputParameters:
            __tasklet_input__ = sdk2.parameters.JSON("Tasklet input", ui=None)

        with sdk2.parameters.Output(reset_on_restart=True):
            tasklet_result = sdk2.parameters.JSON("Tasklet result", ui=None)  # tasklet result including error messages
            __tasklet_output__ = sdk2.parameters.JSON("Tasklet output", ui=None)  # unpacked output

    class Requirements(sdk2.Requirements):
        # Defaults; on_save() may override them from the tasklet's declared requirements.

        cores = 1
        ram = 4096

        client_tags = ctc.Tag.LXC | ctc.Tag.PORTOD

        class Caches(sdk2.Requirements.Caches):
            pass

    class Context(sdk2.Context):
        # Value of __tasklet_input__ before _set_input() first overwrote it.
        initial_input_parameters = None
        # Messages collected by _debug() while running on the server side.
        # NOTE(review): mutable class-level default; confirm sdk2.Context copies
        # defaults per task instance rather than sharing this list.
        debug_messages = []

    # Name-mangled flag (stored as _BaseTasklet__server_side); on_save() sets it
    # so that _debug() records into Context instead of the logging module.
    __server_side = False

    def _debug(self, message):
        """Record a debug message in Context (server side) or via logging.debug."""
        if self.__server_side:
            self.Context.debug_messages.append(message)
        else:
            logging.debug(message)

    @classmethod
    def _parameters_generating(cls):
        """Return True when per-field input parameters are used alongside __tasklet_input__."""
        # Required __tasklet_input__ means absence of separate parameters for each field of proto message
        return not cls.Parameters.__tasklet_input__.required

    def _get_tasklet_input(self):
        """Build the tasklet Input message from the task parameters.

        The raw ``__tasklet_input__`` JSON is parsed first (unknown fields are
        ignored) when it differs from its default. If per-field parameters are
        generated, every non-output parameter of the input group whose value
        differs from its default is then merged on top via
        ``convert.flat_dict_to_proto``.

        :return: instance of ``self.__holder_cls__.Input``
        """
        input_message = self.__holder_cls__.Input()
        cls = type(self)

        if self.Parameters.__tasklet_input__ != cls.Parameters.__tasklet_input__.default:
            self._debug("Initialize parameters using '__tasklet_input__'")
            json_format.ParseDict(self.Parameters.__tasklet_input__, input_message, ignore_unknown_fields=True)

        if not self._parameters_generating():
            return input_message

        params_dict = {}
        for name in cls.Parameters.TaskletInputParameters.names:
            if name == "__tasklet_input__":
                continue
            pcls = getattr(cls.Parameters, name)
            value = getattr(self.Parameters, name)
            # Only explicitly changed, non-output parameters override the JSON input.
            if not pcls.__output__ and pcls.default != value:
                params_dict[name] = value
        self._debug("Parameters update dict: {!r}".format(params_dict))
        fields_not_found = convert.flat_dict_to_proto(params_dict, input_message)

        if fields_not_found:
            self._debug("Fields not found in tasklet input: {}".format(fields_not_found))
        return input_message

    def _set_input(self, any_input):
        """Write an Any-packed tasklet Input back into the task parameters.

        The original ``__tasklet_input__`` is kept in
        ``Context.initial_input_parameters`` (only if nothing is stored there
        yet), then ``__tasklet_input__`` is replaced with the unpacked message as
        a dict, and per-field parameters are updated when they are generated.

        :param any_input: ``google.protobuf.Any`` holding a ``__holder_cls__.Input``
        """
        tasklet_input = self.__holder_cls__.Input()
        any_input.Unpack(tasklet_input)

        if not self.Context.initial_input_parameters:
            self.Context.initial_input_parameters = self.Parameters.__tasklet_input__
        self.Parameters.__tasklet_input__ = json_format.MessageToDict(tasklet_input, preserving_proto_field_name=True)

        if self._parameters_generating():
            params = convert.proto_to_flat_dict(tasklet_input)
            for name, value in params.items():
                setattr(self.Parameters, name, value)

    @property
    def _tasklet_impl_name(self):
        """Name of the tasklet implementation to run.

        Uses ``__tasklet_name__`` when set; otherwise looks up the
        implementations registered for ``__holder_cls__.name``.

        :raises ValueError: if no implementation, or more than one, is registered.
        """
        if self.Parameters.__tasklet_name__:
            return self.Parameters.__tasklet_name__

        dispatch.initialize_tasklet_registry()
        tasklet_name = self.__holder_cls__.name
        impl_names = dispatch.name_to_impl(tasklet_name)
        if not impl_names:
            raise ValueError("Unknown tasklet type {}".format(tasklet_name))
        if len(impl_names) != 1:
            # Format the message explicitly: passing extra positional args to
            # ValueError would leave the "{}" placeholders unfilled.
            raise ValueError(
                "Uncertainty, found more than one implementation of tasklet {}: {}".format(
                    tasklet_name, impl_names,
                )
            )
        logging.debug("Use specific implementation '%s' for tasklet '%s'", impl_names[0], tasklet_name)
        return impl_names[0]

    def on_create(self):
        """Assign generated run/tasklet ids when they were not provided."""
        if self.Parameters.__run_id__ is None:
            self.Parameters.__run_id__ = utils.generate_run_id()
        if self.Parameters.__tasklet_id__ is None:
            self.Parameters.__tasklet_id__ = utils.generate_tasklet_id()

    def on_save(self):
        """Ask the tasklet for its init description and apply input and requirements."""
        self.__server_side = True
        job_spec = tasklet_pb2.JobStatement()
        job_spec.name = self._tasklet_impl_name

        tasklet_input = self._get_tasklet_input()
        job_spec.input.Pack(tasklet_input)

        request_data = job_spec.SerializeToString()

        # Execute tasklet code.
        job_spec.MergeFromString(dispatch.get_init_description(request_data))

        # The init description may have rewritten the input; store it back.
        self._set_input(job_spec.input)

        requirements = job_spec.requirements
        if requirements.disk:
            self.Requirements.disk_space = requirements.disk
        if requirements.cpu:
            self.Requirements.cores = requirements.cpu
        if requirements.ram:
            self.Requirements.ram = requirements.ram
        if requirements.tmpfs:
            self.Requirements.ramdrive = ctm.RamDrive(ctm.RamDriveType.TMPFS, requirements.tmpfs, None)

        # Copy every set (truthy) sandbox-specific requirement onto the
        # Requirements attribute of the same name.
        for field in requirements.sandbox.DESCRIPTOR.fields:
            requirement_value = getattr(requirements.sandbox, field.name)
            if requirement_value:
                setattr(self.Requirements, field.name, requirement_value)

    def on_execute(self):
        """Run the tasklet job and publish its result and output.

        :raises sb_errors.TaskFailure: if the job result is not successful
            (raised after the result/output parameters have been filled in).
        """
        job = tasklet_pb2.JobInstance()
        job.statement.name = self._tasklet_impl_name
        job.id = self.Parameters.__tasklet_id__
        job.run_id = self.Parameters.__run_id__
        # job.statement.requirements is not required here since it has already used in task requirements.

        holder = self.__holder_cls__
        job.statement.ctx.CopyFrom(rt_context.setup(tasklet_pb2.Domain.SANDBOX, holder.Context.DESCRIPTOR))

        tasklet_input = self._get_tasklet_input()
        job.statement.input.Pack(tasklet_input)

        request_data = job.SerializeToString()

        logging.debug("Tasklet job request:\n%s", json_format.MessageToJson(job))

        # Execute tasklet code.
        result_data = dispatch.dispatch(request_data)

        result = tasklet_pb2.JobResult()
        result.ParseFromString(result_data)
        logging.debug("Tasklet job result:\n%s", json_format.MessageToJson(result))

        if not result.success:
            self.set_info("Tasklet failed with error: {}".format(result.error))
            if result.is_python_error:
                self.set_info(
                    "\n".join(traceback.format_exception(*py_base.exception_tuple(result)))
                )

        # NOTE(review): unlike __tasklet_output__ below, this stores the JSON
        # *string* from MessageToJson in a JSON parameter; confirm consumers
        # expect a string before switching to MessageToDict.
        self.Parameters.tasklet_result = json_format.MessageToJson(result, preserving_proto_field_name=True)
        output = self.__holder_cls__.Output()
        result.output.Unpack(output)
        self.Parameters.__tasklet_output__ = json.loads(
            json_format.MessageToJson(output, preserving_proto_field_name=True)
        )

        cls = type(self)
        # if json parameter is required we shouldn't expose output to other parameters
        # it could be even impossible to do that because output is too complex message
        if not cls.Parameters.__tasklet_output__.required:
            params = convert.proto_to_flat_dict(output)
            for field_name, value in params.items():
                logging.debug("%s = %s", field_name, value)
                pcls = getattr(cls.Parameters, field_name, None)
                if pcls is not None:
                    if pcls.__output__:
                        setattr(self.Parameters, field_name, value)
                    else:
                        logging.error("Trying to set non-output parameter '%s'", field_name)

        # Fail only after the result/output parameters are published.
        if not result.success:
            raise sb_errors.TaskFailure(
                "Tasklet job has failed {}".format(
                    "due to python tasklet exception (see above)"
                    if result.is_python_error else
                    "with error: {}".format(result.error)
                )
            )
