from __future__ import absolute_import

import sys
import traceback
import inspect

import six


__all__ = [
    'raiseEx',
    'formatException',
    'saveTraceback',
    'getTraceback',
    'setTraceback',
]

# Attribute name under which saved traceback lines are stored on exception
# instances (read/written by saveTraceback(), getTraceback(), setTraceback()).
# Interned, presumably so attribute lookups by this name hit the fast path.
_tracebackAttr = six.moves.intern('_traceback')


def _makeError(objectType, args, state):
    """
    Unpickling helper returned by SkynetBaseError.__reduce__().

    Creates an instance of ``objectType`` without running its __init__, restores
    its ``args`` tuple and then applies ``state`` (the saved instance attributes)
    through __setstate__().
    """
    restored = objectType.__new__(objectType)
    restored.args = args
    restored.__setstate__(state)
    return restored


def sFail(msg=None):
    """
    Unconditionally fail, exactly like ``sAssert(False, msg)`` issued by the caller.

    The reported location is the code that called sFail(), not sFail() itself.
    """
    # Unwind two frames (sAssert's own and this one) so the report names our caller.
    sAssert(False, msg, unwindFrames=2)


def sAssert(condition, msg=None, unwindFrames=1):
    """
    Raise AssertionError when ``condition`` is falsy.

    Unlike the ``assert`` statement this is a plain ``if`` check, so it is not
    removed under ``python -O``.

    :param condition: value tested for truthiness.
    :param msg: text describing the failed check. When empty, the source line of
        the reporting frame (via inspect.getframeinfo) is used instead.
    :param unwindFrames: how many frames above sAssert itself identify the
        location to report; 1 (default) means the direct caller.
    :raises AssertionError: with 'Assertion `<msg>` in <function> <file>:<line>'
        when the reporting frame is available, otherwise with just ``msg``.
    """
    # TODO: Remove #unwindFrames from sys.exc_info (All members of traceback marked as RO)
    if condition:
        return

    frame = inspect.currentframe()  # may be None on interpreters without frame support
    try:
        for _ in range(unwindFrames):
            if frame is None:
                break
            frame = frame.f_back
        frameInfo = inspect.getframeinfo(frame) if frame is not None else None
    finally:
        # Drop the frame reference deterministically (recommended by the inspect docs).
        del frame

    if not msg:
        context = frameInfo.code_context if frameInfo is not None else None
        msg = ' '.join(context).strip() if context else ''

    if frameInfo is None:
        raise AssertionError(msg)
    raise AssertionError(
        'Assertion `{0}` in {1} {2}:{3}'.format(
            msg, frameInfo.function, frameInfo.filename, frameInfo.lineno
        )
    )


class SkynetBaseError(Exception):
    """
    Base class for all skynet errors!

    The purpose of this class is:
     - it IS picklable/unpicklable even if extra attributes are stored on it.
       That means that even though the message lives in a separate .message
       attribute, it *will* be reconstructed properly after unpickling. The
       _makeError() helper and __reduce__() are responsible for this.
     - it has predefined, standardized __str__ and __repr__ methods :)
     - it makes explicit that passing kwargs to the Exception class can lead to
       errors: keyword arguments given here become instance attributes instead.
       You can still freely set class variables and use keyword arguments in
       descendant classes if you wish.

    So, good practices are:
     - inherit from this class and use args, kwargs as you wish
     - simple args can be passed to the base __init__, so they *will* be
       included in str() and repr() results
     - there is one "magic" kwarg -- "message". It is always used by str() and
       shown by repr()
     - if you want to attach additional data to your exception but do not want
       it to be printed by the usual tools (e.g. formatException(), str() or
       repr()), pass it to the base class as a keyword argument, *not* as a
       positional one. See the example below
     - **NEVER** use super() in descendants
     - str() formats the message first with '%' and .args; if that fails, with
       .format(*args, **instance attributes); if that fails too (or changes
       nothing), "; args: <repr of args>" is appended to the plain message
       (just "args: <repr of args>" when there is no message). Without args
       the message is returned as-is.

    >>> class SuperError(SkynetBaseError):
    ...     def __init__(self, message, user, aLotOfData):
    ...         # Only the user name ends up in .args, so only it is displayed
    ...         # when the exception is printed; aLotOfData is reachable via
    ...         # err.aLotOfData only.
    ...         SkynetBaseError.__init__(self, user, message=message, aLotOfData=aLotOfData)
    >>> err = SuperError('good man!', 'mocksoul', 'x' * 1024)
    >>> assert err.args == ('mocksoul', )
    >>> assert err.message == 'good man!'
    >>> assert err.aLotOfData == 'x' * 1024
    >>> print(err)
    good man!; args: ('mocksoul',)
    >>> print(repr(err))
    SuperError('mocksoul', message='good man!')
    """

    def __init__(self, *args, **kwargs):
        """
        :param args: positional args, stored as ``.args`` (shown by str() and repr()).
        :param kwargs: ``message`` (default '') becomes ``.message``; every other
            keyword becomes an instance attribute.
        :raises AssertionError: if ``args`` is passed as a keyword argument.
        """
        # 'args' is reserved for BaseException.args. Checked explicitly (not with
        # a bare ``assert``) so the check also runs under ``python -O``.
        if 'args' in kwargs:
            raise AssertionError("'args' cannot be passed as a keyword argument")

        message = kwargs.pop('message', '')

        Exception.__init__(self, *args)
        self.__dict__.update(kwargs)
        self.message = message

    def __str__(self):
        try:
            message = self.message if self.message else ''
            if not self.args:
                return message

            # 1) printf-style formatting with the positional args.
            try:
                return message % tuple(self.args)
            except Exception:
                pass

            # 2) str.format() with positional args and instance attributes; a
            #    result identical to the template means nothing was substituted.
            try:
                formatted = message.format(*self.args, **self.__dict__)
            except Exception:
                formatted = message
            if formatted != message:
                return formatted

            # 3) Fall back to the plain message followed by the raw args.
            prefix = (message + '; ') if message else ''
            return prefix + 'args: %r' % (self.args, )
        except Exception:
            return 'CANT FORMAT EXCEPTION (%s: message:%s, args:%s)!! Use .args to examine!' % (
                self.__class__.__name__, self.message, self.args
            )

    def __repr__(self):
        """Return ``ClassName(<repr of each arg>, message=<repr of message>)``."""
        lst = ['%r' % arg for arg in self.args]
        lst.append('message=%r' % self.message)

        return '%s(%s)' % (self.__class__.__name__, ', '.join(lst))

    def __reduce__(self, *args, **kwargs):
        # Pickle as _makeError(type, args, instance attributes) so .message and
        # any other keyword-given attributes survive unpickling.
        return _makeError, (type(self), self.args, self.__dict__)


class SkynetError(SkynetBaseError):
    """
    Common skynet errors base class.

    Unlike SkynetBaseError, the message is the first positional argument; the
    remaining positional arguments become ``.args`` and keyword arguments become
    instance attributes (handled by SkynetBaseError.__init__).
    """

    def __init__(self, message, *args, **kwargs):
        kwargs['message'] = message
        # __init__ must not return a value, so the base call's result is not returned.
        SkynetBaseError.__init__(self, *args, **kwargs)


def _filterUtilErrors(trace):
    """
    Return the leading part of ``trace`` that precedes this module's own frames.

    Formatted traceback lines are copied until the first one mentioning this
    module's source file; that line and everything after it are dropped, so
    helper frames from this module (e.g. the ``raise`` in raiseEx()) do not
    clutter the output.

    :param trace: iterable of formatted traceback lines.
    :returns: a new list with the kept lines.
    """
    # Tracebacks normally show the .py source path even when the module was
    # loaded from compiled .pyc/.pyo, so map the compiled name back. The old
    # rstrip('c') only handled .pyc, leaving the filter a no-op under .pyo.
    sourceFile = __file__
    if sourceFile.endswith(('.pyc', '.pyo')):
        sourceFile = sourceFile[:-1]

    goodTrace = []
    for traceLine in trace:
        if sourceFile in traceLine:
            break
        goodTrace.append(traceLine)
    return goodTrace


def _getTraceback(err, blocktitle=None, excInfo=None):
    """
    Build the traceback lines to save on ``err``.

    If ``excInfo`` (defaulting to sys.exc_info()) describes ``err`` itself, its
    formatted frames (minus this module's own frames) are emitted under a
    separator line, titled with ``blocktitle`` when given. Any lines already
    saved on ``err`` are appended after them.
    """
    if blocktitle:
        separator = '  %s\n' % (' %s ' % blocktitle).center(68, '-')
    else:
        separator = '  ' + '-' * 68 + '\n'

    if not excInfo:
        excInfo = sys.exc_info()

    current = []
    if all(excInfo) and excInfo[1] is err:
        formatted = traceback.format_exception(*excInfo)
        # Drop the "Traceback (most recent call last):" header and the final
        # exception line; only the frame lines are kept.
        current = [separator] + _filterUtilErrors(formatted[1:-1])

    return current + getattr(err, _tracebackAttr, [])


def raiseEx(err, chainedErr, blocktitle=None):
    """
    Extended exception raising, can "chain" to another exception.

    If ``chainedErr`` is the exception currently being handled, its traceback is
    saved on it first (see saveTraceback()). Both exceptions get an empty saved
    traceback if they have none, then ``err``'s saved traceback becomes
    ``chainedErr``'s lines followed by its own. Always raises ``err``.
    """
    if sys.exc_info()[1] is chainedErr:
        saveTraceback(chainedErr, blocktitle)

    for target in (err, chainedErr):
        if not hasattr(target, _tracebackAttr):
            setattr(target, _tracebackAttr, [])

    combined = getattr(chainedErr, _tracebackAttr) + getattr(err, _tracebackAttr)
    setattr(err, _tracebackAttr, combined)

    raise err


def formatException(err=None, asList=False, excInfo=None):
    """
    Format an exception (the current one by default) and return it.

    In general, this works the same as traceback.format_exc() does, with a few
    differences:

    1) if the error was chained with another one (with saveTraceback() or
       raiseEx()), the nested tracebacks are displayed as well;

    2) you can pass the error to this function, so the traceback is printed from
       the one saved in the error (generally with saveTraceback());

    3) it can return the traceback as a list of lines with asList=True.

    :param err: exception to format. When None, the exception described by
        ``excInfo`` (or sys.exc_info()) is formatted.
    :param asList: return a list of lines instead of a single string.
    :param excInfo: (type, value, traceback) triple used when ``err`` is None.
    :raises AssertionError: when ``err`` is None and no exception info is available.
    """
    if err is None:
        if not excInfo:
            excInfo = sys.exc_info()
        # Explicit check rather than ``assert`` so it also holds under ``python -O``.
        if not all(excInfo):
            raise AssertionError('Exception info not available')

        trace = traceback.format_exception(*excInfo)
        # Keep the header and the final exception line; drop this module's frames.
        trace = trace[:1] + _filterUtilErrors(trace[1:-1]) + trace[-1:]
        err = excInfo[1]
    else:
        trace = [
            'Traceback (most recent call last):\n',
            '%s: %s\n' % (type(err).__name__, err),
        ]

    chainedTrace = getattr(err, _tracebackAttr, None)

    if chainedTrace:
        # With no frame lines of our own (header + exception line only), skip the
        # first saved line -- normally the separator written by _getTraceback.
        start = 0 if len(trace) > 2 else 1
        trace = trace[:-1] + chainedTrace[start:] + trace[-1:]

    if asList:
        return trace

    return ''.join(trace)


def saveTraceback(err, blocktitle=None, excInfo=None):
    """
    Save the current traceback on the given exception.

    Intended to be called *only* while handling an exception in an except: block.
    The formatted traceback (titled with ``blocktitle`` if given, followed by any
    lines already saved on ``err``) is stored directly on the error, so it can be
    printed later with formatException() from another place.
    """
    savedLines = _getTraceback(err, blocktitle, excInfo)
    setattr(err, _tracebackAttr, savedLines)


def getTraceback(err):
    """
    Return the traceback lines saved on ``err``, or None if nothing was saved.
    """
    try:
        return getattr(err, _tracebackAttr)
    except AttributeError:
        return None


def setTraceback(err, traceback):
    """
    Replace (or remove) the traceback lines saved on ``err``.

    :param traceback: lines to store; None removes any saved traceback. This
        parameter shadows the ``traceback`` module inside this function -- the
        name is kept because callers may pass it by keyword.
    """
    if traceback is not None:
        setattr(err, _tracebackAttr, traceback)
    elif hasattr(err, _tracebackAttr):
        delattr(err, _tracebackAttr)
