from __future__ import absolute_import

import sys
import uuid
import json
import inspect
import logging
import datetime as dt

from sandbox.common import rest
from sandbox.common import system
from sandbox.common import encoding
from sandbox.common import api as common_api
from sandbox.common import math as common_math


# Public names exported by `from <this module> import *`; `MemoizeStage` and `SingleFileStorage` are not listed.
__all__ = ("MemoizeCreator", "NoTimeout", "ProgressMeter", "set_parameter")


class NoTimeout(object):
    """
    Context manager that sets the current task's ``__no_timeout`` context flag for the duration of its body.

    Original description: if the current task goes to TIMEOUT state, switch it to TEMPORARY instead.
    (The server-side handling of the flag is not part of this module.)

    .. code-block:: python

        with NoTimeout() as no_timeout:
            ...  # lengthy operation
    """
    def __init__(self):
        self.client = rest.Client()

    def set_timeout(self, value):
        """Write `value` under the ``__no_timeout`` key of the current task's context through the REST client."""
        self.client.task.current.context.value.update(key="__no_timeout", value=value)

    def __enter__(self):
        logging.debug("Entering no-timeout context")
        self.set_timeout(True)
        # Return the manager so that `with NoTimeout() as nt:` binds the instance instead of None
        return self

    def __exit__(self, *_):
        # Always clears the flag, also when the body raised; returning None lets any exception propagate
        logging.debug("Exiting no-timeout context")
        self.set_timeout(False)


class MemoizeStage(object):
    """
    Context manager that limits how many times a named stage of `on_execute` is executed.

    Execution and skip counters are kept in the task's `Context` under the keys
    ``__stage_<name>_exec_count__`` and ``__stage_<name>_skip_count__``.

    Examples:

        .. code-block:: python

            with self.memoize_stage.important_action:
                # <code will be executed only once>

            with self.memoize_stage.important_action(3):
                # <code will be executed a maximum of three times>

        Where `important_action` is the stage name and can be anything.

        By default, the stage is marked as executed on entrance. You may change this behavior:

        .. code-block:: python

            with self.memoize_stage.interrupted_section(commit_on_entrance=False):
                # <if the code raises an exception, the stage will not be marked as executed>

    How skipping works: once the stage has been run `max_runs` times, `__enter__` installs a trace hook on
    the caller's frame. The hook raises `SkipStageException` as soon as the first line of the `with` body
    is about to run, and `__exit__` suppresses that exception.
    """

    # Shared sentinel raised by `trace` to abort a skipped body; `__exit__` recognizes it by identity
    SkipStageException = Exception()

    def __init__(self, task, stage_name):
        self.__task = task
        self.__stage_name = stage_name
        self.__exec_key = "__stage_{}_exec_count__".format(self.__stage_name)
        self.__skip_key = "__stage_{}_skip_count__".format(self.__stage_name)
        self.__max_runs = 1
        self.__executed = None
        self.__commit_on_entrance = True
        self.__commit_on_wait = True
        # Global trace function active before the skip hook was installed; restored in `__exit__`
        self.__prev_trace = None
        self.__trace_installed = False
        self.__logger = logging.getLogger(type(self).__name__)

    @property
    def runs(self):
        """How many times the stage has been committed as executed (0 if never)."""
        return getattr(self.__task.Context, self.__exec_key) or 0

    @property
    def passes(self):
        """How many times the stage has been skipped (0 if never)."""
        return getattr(self.__task.Context, self.__skip_key) or 0

    @property
    def executed(self):
        """True if the last `__enter__` ran the body, False if it skipped it, None before entering."""
        return self.__executed

    def __inc_key(self, key):
        value = (getattr(self.__task.Context, key) or 0) + 1
        setattr(self.__task.Context, key, value)
        # With external auth, the counter is also pushed to the server right away
        # noinspection PyProtectedMember
        if self.__task._sdk_server._external_auth:
            # noinspection PyProtectedMember
            self.__task._sdk_server.task.current.context.value.update(key=key, value=value)

    def __call__(self, max_runs=None, commit_on_entrance=None, commit_on_wait=None):
        """
        Configure the stage; arguments left as None keep their current values.

        :param max_runs: maximal number of executions (default 1)
        :param commit_on_entrance: count the run on entrance (True) or only on successful exit (False)
        :param commit_on_wait: with `commit_on_entrance=False`, also count the run when the body raises `Wait`
        :return: the stage itself, for use in a `with` statement
        """
        if max_runs is not None:
            self.__max_runs = int(max_runs)
        if commit_on_entrance is not None:
            self.__commit_on_entrance = bool(commit_on_entrance)
        if commit_on_wait is not None:
            self.__commit_on_wait = bool(commit_on_wait)
        return self

    def __enter__(self):
        skip = self.runs >= self.__max_runs
        self.__executed = not skip
        if skip:
            self.__logger.info("Skipping stage '%s'", self.__stage_name)
            self.__inc_key(self.__skip_key)
            # Remember whichever tracer (debugger, coverage, ...) was active so `__exit__` can restore it
            self.__prev_trace = sys.gettrace()
            self.__trace_installed = True
            # A global trace function must be set for per-frame tracing to be active; the caller's frame
            # then gets `self.trace`, which aborts the `with` body on its first line
            sys.settrace(lambda *args, **keys: None)
            frame = inspect.currentframe().f_back
            frame.f_trace = self.trace
        else:
            self.__logger.info("Entering stage '%s'", self.__stage_name)
            if self.__commit_on_entrance:
                self.__inc_key(self.__exec_key)

        return self

    def trace(self, frame, event, arg):
        """Trace hook installed on the caller's frame of a skipped stage: aborts the `with` body."""
        raise self.SkipStageException

    # noinspection PyUnusedLocal
    def __exit__(self, exc_type, exc_value, traceback):
        self.__logger.info("Exiting stage '%s'", self.__stage_name)

        if self.__trace_installed:
            # Raising from the trace hook makes Python unset tracing, which would otherwise lose a previously
            # active tracer; if the hook never fired, this also removes the no-op tracer installed above
            self.__trace_installed = False
            sys.settrace(self.__prev_trace)
            self.__prev_trace = None

        if exc_value is self.SkipStageException:
            # The sentinel is a shared instance: re-raising it extends its existing traceback, which would keep
            # every skipped frame (and the task it references) alive; drop it before suppressing
            exc_value.__traceback__ = None
            return True

        if exc_value is None:
            if not self.__commit_on_entrance:
                self.__inc_key(self.__exec_key)
            return None

        # Imported lazily and only here: the `Wait` check is the only place that needs the SDK
        from sandbox import sdk2

        if isinstance(exc_value, sdk2.task.Wait) and not self.__commit_on_entrance and self.__commit_on_wait:
            self.__inc_key(self.__exec_key)
        # Any other exception propagates
        return None


class MemoizeCreator(object):
    """
    Factory of `MemoizeStage` objects bound to a task.

    A stage is obtained by attribute access (``creator.stage_name``) or by item access
    (``creator["stage name"]``); each access creates a fresh `MemoizeStage`.
    """

    def __init__(self, task):
        self.__task = task

    def __getattr__(self, name):
        # Only reached for names not found normally. Dunder names are protocol probes (e.g. `copy.deepcopy`
        # looking up `__deepcopy__` / `__setstate__`), not stage names; treating them as stages breaks copying
        # and recurses on instances created without `__init__`. Such stages remain reachable via `creator[name]`.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return MemoizeStage(self.__task, name)

    def __getitem__(self, name):
        return MemoizeStage(self.__task, name)


class ProgressMeter(object):
    """
    Lets your task be more verbose about lengthy actions going inside it.

    The class is supposed to be used as a context manager inside your task
    during client-side execution stages (`on_prepare()`, `on_execute()`, etc):

    .. code-block:: python

        start, end = 1, 11
        with sandbox.sdk2.helpers.ProgressMeter(
            "Downloading movies", minval=start, maxval=end,
            escape=False, formatter=lambda value: "<b>{}</b> file(s)".format(value)
        ) as meter:
            for _ in xrange(end - start):
                meter.add(1)

    The above will display these lines on the task's page,
    with the current value and percentage updating on page refresh:

    .. code-block:: python

        Downloading movies; <b>1</b> file(s) / <b>11</b> file(s) completed (0%)

    What's more exciting is that you can nest progress meters and they will be sorted accordingly on a task's page:

    .. code-block:: python

        import logging

        i, j = 3, 10
        with sandbox.sdk2.helpers.ProgressMeter("Fetching updates", maxval=i) as outer:
            for _ in xrange(i):
                outer.add(1)
                with sandbox.sdk2.helpers.ProgressMeter("Completing fake stages (may take a while)...") as inner:
                    for _ in xrange(j):
                        inner.add(i * j)
                        logging.info(
                            "Downloaded %d parts of a fake package (started at %s)",
                            inner.value, inner.start_time
                        )

    Progress is pushed to AgentR only while the meter is entered, at most once per `UPDATE_PERIOD`,
    and the meter's record is deleted from AgentR on exit.
    """

    # Minimal interval between two consecutive pushes of the progress record to AgentR
    UPDATE_PERIOD = dt.timedelta(seconds=0.1)
    # Attributes that `update()` may change; other keyword arguments are ignored
    UPDATEABLE_FIELDS = ("message", "minval", "maxval")

    def __init__(self, message, minval=0, maxval=None, escape=True, formatter=encoding.force_unicode):
        """
        :param message: a string which will precede progress report
        :param minval: value to start counting progress from; may be negative
        :param maxval: value considered as 100%; omit to skip progress percentage
        :param escape: enable HTML entities escaping
        :param formatter: a callable which accepts a single argument (the value) and returns a string
        """

        from sandbox import sdk2

        # Must be set before `message`: the message setter escapes according to this flag
        self.escape = escape
        self.message = message
        self.minval = minval
        self.maxval = maxval
        self.formatter = formatter

        self.id = uuid.uuid4().hex
        self.start_time = common_api.DateTime.encode(dt.datetime.utcnow())
        self.value = minval

        self.__active = False  # only context manager usage is allowed, so that everything is cleared up on scope exit
        self.__last_update_time = None
        self.__tid = sdk2.Task.current.id
        self.__agentr = sdk2.Task.current.agentr

        logging.debug("%r created", self)

    @property
    def message(self):
        return self.__message

    @message.setter
    def message(self, value):
        # None becomes an empty string; other values are escaped when `escape` is enabled
        self.__message = self.__maybe_escape(value) if value is not None else ""

    @property
    def minval(self):
        return self.__minval

    @minval.setter
    def minval(self, value):
        # None falls back to 0, so `minval` is never None
        self.__minval = value if value is not None else 0

    def __repr__(self):
        return "ProgressMeter(id={}, task={}, message={})".format(self.id, self.__tid, self.message)

    def __maybe_escape(self, message):
        return encoding.escape(message) if self.escape else message

    def __enter__(self):
        logging.debug("[STARTED] %r", self)
        self.__active = True
        self.__maybe_update()
        return self

    def __exit__(self, *_, **__):
        from sandbox.agentr import errors as aerrors

        try:
            self.__agentr.progress_meta.delete(self.id)
        except aerrors.ARException as exc:
            logging.error("Failed to remove %r from AgentR: %s", self, exc)

        logging.debug("[FINISHED] %r", self)
        self.__active = False
        # The record has just been deleted: forget the throttling timestamp so that re-entering the meter
        # pushes it again immediately instead of silently waiting for the next `UPDATE_PERIOD`
        self.__last_update_time = None

    def add(self, value):
        """Increase the current value by `value` and push the progress if the throttling period allows."""
        self.value += value
        self.__maybe_update()

    def update(self, **kwargs):
        """Change any of `UPDATEABLE_FIELDS` (unknown keys are ignored) and push the progress if allowed."""
        for field in self.UPDATEABLE_FIELDS:
            if field in kwargs:
                setattr(self, field, kwargs[field])
        if kwargs:
            self.__maybe_update()

    @property
    def current(self):
        """
        Snapshot of the meter as an AgentR `Action`.

        `total` and `percentage` are None when `maxval` is not set.
        """
        from sandbox.agentr import types as atypes

        current = self.__maybe_escape(self.formatter(self.value))
        if None in (self.minval, self.maxval):
            total, percentage = None, None
        else:
            total = self.__maybe_escape(self.formatter(self.maxval))
            percentage = common_math.progress(self.minval, self.value, self.maxval)

        return atypes.Action(
            message=self.message,
            started=self.start_time,
            current=current,
            total=total,
            percentage=percentage
        )

    def __maybe_update(self):
        """Push `current` to AgentR unless inactive or throttled; AgentR errors are logged, not raised."""
        from sandbox.agentr import errors as aerrors

        utcnow = dt.datetime.utcnow()
        if (
            (self.__last_update_time and utcnow - self.__last_update_time < self.UPDATE_PERIOD) or
            not self.__active
        ):
            return

        try:
            self.__agentr.progress_meta.insert(self.id, self.current)
            self.__last_update_time = utcnow
        except aerrors.ARException as exc:
            logging.error("Failed to update %r: %s", self, exc)


class SingleFileStorage(object):
    """
    Dictionary-like storage persisted as JSON in a single file, with one shared instance per file path.

    Constructing the class again with the same `file_path` returns the already existing object, and only
    the very first construction initializes it (a later `autoload` argument has no effect).
    Missing keys read as None.
    """

    # file path -> the instance bound to that path
    __instances = {}
    # class-level default; set to True on the instance once `__init__` has run
    __initialized = False

    def __new__(cls, file_path, autoload=False):
        registry = cls.__instances
        if file_path not in registry:
            registry[file_path] = super(SingleFileStorage, cls).__new__(cls)
        return registry[file_path]

    def __init__(self, file_path, autoload=False):
        if not self.__initialized:
            self.__initialized = True
            self._values = {}
            self._file_path = file_path
            self._loaded = False
            if autoload:
                self.load()
                self._loaded = True

    @property
    def loaded(self):
        """ returns if the load() function runs successfully at least once """
        return self._loaded

    @system.skip_if_binary
    def load(self):
        """Replace the in-memory values with the JSON content of the file."""
        with open(self._file_path, "r") as source:
            content = json.load(source)
            self._values = content
        self._loaded = True

    def save(self):
        """Dump the in-memory values into the file as JSON, overwriting it."""
        with open(self._file_path, "w+") as target:
            json.dump(self._values, target)

    def __setitem__(self, key, value):
        self._values[key] = value

    def __getitem__(self, key):
        return self._values.get(key, None)


def set_parameter(name, parameter):
    """
    Declare a task parameter dynamically from inside a class body.

    The parameter is bound under `name` in the calling scope's namespace, and `name` is appended to the
    scope's ``__names__`` tuple. On first use, ``__names__`` is seeded with the names referenced by the caller's code.

    Example:

    .. code-block:: python

        from sandbox import sdk2

        class Parameters(sdk2.Parameters):
            for name in ("a", "b", "c"):
                sdk2.helpers.set_parameter(name, sdk2.parameters.String(name.capitalize()))

    :param name: parameter name
    :param parameter: parameter object
    """
    caller = inspect.currentframe().f_back
    namespace = caller.f_locals
    known_names = namespace.setdefault("__names__", caller.f_code.co_names)
    namespace["__names__"] = tuple(known_names) + (name,)
    namespace[name] = parameter
