"""
Our wrappers for gevent.greenlet
"""

from __future__ import absolute_import, print_function

import gevent

from collections import defaultdict

import logging
import types

from ..config import Config
from ..sys.gettime import threadCPUTime as timeFunc
from ..errors import formatException, saveTraceback
from ..functional import wraps

# `Greenlet` section of the 'ya.skynet.util.gevent' configuration.
# This module reads cfg.Debug.Profiling and cfg.Debug.PrintProfileOnExit from it.
cfg = Config(
    'ya.skynet.util.gevent',
    searchPackages=['kernel.util', 'ya.skynet.util']
).Greenlet


class SwitchableContextMixin(object):
    """
    Mixin that makes a greenlet enter its registered contexts when it is
    switched to and exit them when it is switched out.

    A context is any object providing ``__enter__()`` and
    ``__exit__(exc_type, exc_value, traceback)``.
    """
    def __init__(self, *args, **kwargs):
        # Registered contexts, in registration order.
        self._contexts = []
        super(SwitchableContextMixin, self).__init__(*args, **kwargs)

    def registerContext(self, ctx):
        """
        Add `ctx` to the contexts handled on switches.
        Registering the same context twice has no effect.
        """
        if ctx not in self._contexts:
            self._contexts.append(ctx)

    def unregisterContext(self, ctx):
        """
        Remove `ctx` from the contexts handled on switches.
        Unknown contexts are silently ignored.
        """
        if ctx in self._contexts:
            self._contexts.remove(ctx)

    def _superSwitch(self, *args):
        # Separate hook so that instance-level patching can substitute
        # its own way of reaching the original switch().
        return super(SwitchableContextMixin, self).switch(*args)

    def switch(self, *args):
        """
        Enter every registered context (in registration order) and then
        perform the real switch.

        :return: whatever the underlying ``switch`` returns
        """
        for ctx in self._contexts:
            ctx.__enter__()
        return self._superSwitch(*args)

    def _superSwitchOut(self):
        # The next class in the MRO is not required to define switch_out().
        # Only the lookup is guarded: an AttributeError raised *inside* the
        # parent's switch_out() is a real error and must not be swallowed.
        parentSwitchOut = getattr(super(SwitchableContextMixin, self), 'switch_out', None)
        if parentSwitchOut is not None:
            return parentSwitchOut()

    def switch_out(self):
        """
        Exit every registered context (in reverse registration order) and
        then call the parent's ``switch_out`` if there is one.

        The context list is emptied while the contexts are being exited and
        is always put back afterwards, even when an ``__exit__`` raises.

        :return: whatever the parent's ``switch_out`` returns, or None
        """
        contexts = self._contexts
        # Presumably protects against re-processing the same contexts if an
        # __exit__ itself causes a switch -- the list is hidden meanwhile.
        self._contexts = []

        try:
            for ctx in reversed(contexts):
                ctx.__exit__(None, None, None)
        finally:
            # Without this a failing __exit__ would drop all registrations.
            self._contexts = contexts

        return self._superSwitchOut()


def makeGreenletContextAware(let=None):
    """
    Make a greenlet instance or a greenlet class handle registered contexts
    on switches by means of SwitchableContextMixin.

    * A class is replaced, in the namespace of the module it was defined in,
      by a new class of the same name deriving from SwitchableContextMixin
      and the original class.
    * An instance gets its ``__class__`` replaced in the same manner. If that
      is rejected with TypeError, the mixin methods are bound directly on
      the instance instead.

    Classes and instances that are already context-aware (directly or via a
    base class) are left untouched, so repeated calls never drop contexts
    that have been registered in between.

    :param let: greenlet instance or class to patch; the current greenlet
                (``gevent.getcurrent()``) if omitted
    """
    if let is None:
        let = gevent.getcurrent()

    if isinstance(let, type):
        # patching class
        # issubclass() also covers indirect subclasses; re-mixing those
        # would fail with an inconsistent-MRO TypeError.
        if issubclass(let, SwitchableContextMixin):
            return
        import sys
        module = let.__module__
        name = let.__name__
        sys.modules[module].__dict__[name] = type(name, (SwitchableContextMixin, let) + let.__bases__, {})
        return

    # patching object
    # Already aware: either through its class hierarchy, or because the
    # instance-level patch below has been applied before. This check must
    # happen before the try/finally, otherwise the registered contexts of
    # an already patched greenlet would be reset.
    if isinstance(let, SwitchableContextMixin) or '_superSwitch' in getattr(let, '__dict__', {}):
        if not hasattr(let, '_contexts'):
            let._contexts = []
        return

    def _plain(method):
        # Python 2 unbound methods keep the function in ``__func__``;
        # where the attribute is absent the object is used as is.
        return getattr(method, '__func__', method)

    try:
        cls = let.__class__
        let.__class__ = type(cls.__name__, (SwitchableContextMixin, cls) + cls.__bases__, {})
    except TypeError:  # Not a heap type
        let.registerContext = types.MethodType(_plain(SwitchableContextMixin.registerContext), let)
        let.unregisterContext = types.MethodType(_plain(SwitchableContextMixin.unregisterContext), let)

        # The mixin's switch()/switch_out() reach the original behaviour via
        # _superSwitch()/_superSwitchOut(); super() is not usable here, so
        # the class-level originals are captured explicitly.
        def _superSwitch(self, *args):
            return _superSwitch._oldSwitch(self, *args)
        _superSwitch._oldSwitch = let.__class__.switch
        let._superSwitch = types.MethodType(_superSwitch, let)
        let.switch = types.MethodType(_plain(SwitchableContextMixin.switch), let)

        def _superSwitchOut(self, *args):
            return _superSwitchOut._oldSwitchOut(self, *args)
        try:
            _superSwitchOut._oldSwitchOut = let.__class__.switch_out
        except AttributeError:
            # The class has no switch_out(): nothing to chain to.
            _superSwitchOut._oldSwitchOut = lambda self: None
        let._superSwitchOut = types.MethodType(_superSwitchOut, let)
        let.switch_out = types.MethodType(_plain(SwitchableContextMixin.switch_out), let)
    finally:
        # The mixin's __init__ is never run for a patched object.
        let._contexts = []


class Greenlet(SwitchableContextMixin, gevent.greenlet.Greenlet):
    """
    Class which improves gevent.Greenlet with the following:
    1. Logging of exceptions
    2. Exceptions preprocessing
    3. Optional profiling of the time spent in each greenlet
       (enabled by cfg.Debug.Profiling)
    """

    log = logging.getLogger('ya.skynet.util.gevent')
    """
    Logger for exceptions
    Subclasses could set own logger here
    """

    NoLogExceptions = (
        gevent.GreenletExit,
    )
    """
    Exceptions which will not be logged
    Subclasses could add such exceptions here
    """

    # Reference point for the 'APP' line of printProfileResults()
    _classInitStamp = timeFunc()

    # str(greenlet) -> [accumulated time, number of accounted switch-outs]
    __profileResults = defaultdict(lambda: [0, 0])

    @classmethod
    def printProfileResults(cls):
        """
        Print the collected profiling table: one line per greenlet (sorted by
        accumulated time, ascending), then the total over all greenlets and
        the time elapsed since this class was defined.
        """
        formatString = '{0:80} {1:.4f} {2}'

        results = sorted((v[0], v[1], k) for (k, v) in cls.__profileResults.items())
        total = 0

        for spent, calls, name in results:
            print(formatString.format(name, spent, calls))
            total += spent

        print(formatString.format('TOTAL', total, ''))
        print(formatString.format('APP', timeFunc() - cls._classInitStamp, ''))

    def switch(self, *arg):
        """
        Switch into the greenlet; with profiling enabled the moment of the
        switch is remembered first.
        """
        if cfg.Debug.Profiling:
            self._switchInStamp = timeFunc()

        return super(Greenlet, self).switch(*arg)

    def switch_out(self):
        """
        Account the time spent since the last switch-in (profiling only) and
        run the parent's switch_out.
        """
        if cfg.Debug.Profiling:
            # The stamp is None (or missing) when nothing has set it since
            # the previous switch_out -- e.g. the greenlet was resumed
            # without switch(), or profiling was turned on after __init__.
            # There is nothing to account then, and failing here would mask
            # the exception being propagated through _safe_run()'s finally.
            startedAt = getattr(self, '_switchInStamp', None)
            if startedAt is not None:
                entry = Greenlet.__profileResults[str(self)]
                entry[0] += timeFunc() - startedAt
                entry[1] += 1
                self._switchInStamp = None

        try:
            super(Greenlet, self).switch_out()
        except AttributeError:
            # NOTE(review): kept from the original; this also hides
            # AttributeErrors raised while contexts are being exited.
            pass

    def __init__(self, *args, **kwargs):
        super(Greenlet, self).__init__(*args, **kwargs)
        if cfg.Debug.Profiling:
            self._switchInStamp = None
        # Wrap the body of the greenlet so its exceptions get logged.
        if getattr(self, '_run', None) is not None:
            self._run = self._safe_run(self._run)

    def processException(self, err):
        """
        Called on each exception which terminates the greenlet.
        If this function raises, the exception will not be logged.

        Must be called while `err` is being handled: the bare ``raise``
        re-raises the exception that is currently active.

        :param BaseException err:
        """
        if isinstance(err, self.NoLogExceptions):
            raise

    def _safe_run(self, func):
        """
        Wrap `func` so that exceptions escaping it are passed to
        processException(), logged and re-raised.

        :param func: callable executed as the body of the greenlet
        :return: the wrapping callable
        """
        @wraps(func)
        def _wrapper(*args, **kwargs):
            try:
                if cfg.Debug.Profiling:
                    self._switchInStamp = timeFunc()
                return func(*args, **kwargs)
            except BaseException as err:
                # Dont log some types of errors
                self.processException(err)

                # `_logged` marks exceptions that were already reported in
                # full; those get only a short message.
                if not getattr(err, '_logged', False):
                    self.log.error('Unhandled exception in {0}: {1}'.format(self, formatException()))
                    err._logged = True
                else:
                    self.log.error('Unhandled exception in {0} : {1} (already reported)'.format(self, err))

                raise
            finally:
                # Account the final time slice of the greenlet.
                if cfg.Debug.Profiling:
                    self.switch_out()

        return _wrapper

    def _report_error(self, exc_info):
        """
        Store the exception the greenlet has died with.

        GreenletExit is reported as the result of the greenlet rather than
        as an error.

        :param exc_info: (type, value, traceback) triple of the exception
        """
        exception = exc_info[1]
        if isinstance(exception, gevent.GreenletExit):
            self._report_result(exception)
            return

        # presumably keeps the traceback together with the exception object
        saveTraceback(exception, blocktitle='GreenLet._report_error()', excInfo=exc_info)
        self._exception = exception

        # Schedule notification of the links unless one is already pending.
        if self._links and self._notifier is None:
            # noinspection PyUnresolvedReferences
            self._notifier = gevent.core.active_event(self._notify_links)

# Optionally register the profiling report with the project's atexit helper
# (presumably run at interpreter exit -- confirm against kernel.util.sys.atexit).
if cfg.Debug.PrintProfileOnExit:
    # Imported lazily: only needed when the option is enabled.
    from kernel.util.sys import atexit
    atexit.register(Greenlet.printProfileResults)

# Module-level shortcut for Greenlet.spawn
spawn = Greenlet.spawn
