# -*- coding: utf-8 -*-

import collections
import copy
import datetime
import logging
import sys

import gevent
import greenlet
from passport.backend.social.common.chrono import now


logger = logging.getLogger(__name__)

# A single recorded source line: POSIX timestamp (format_trace feeds it to
# datetime.fromtimestamp), the code object's file name and the line number.
GreenTraceLine = collections.namedtuple(
    'GreenTraceLine',
    ['timestamp', 'filename', 'lineno'],
)


def format_trace(trace):
    """Render ``trace`` as text, one ``<datetime> <filename> <lineno>`` per line.

    ``trace`` is an iterable of GreenTraceLine-like records. Each timestamp
    is converted with ``datetime.datetime.fromtimestamp`` (local time).
    An empty trace renders as an empty string.
    """
    return '\n'.join(
        ' '.join((
            str(datetime.datetime.fromtimestamp(entry.timestamp)),
            entry.filename,
            str(entry.lineno),
        ))
        for entry in trace
    )


class GreenTracerError(Exception):
    """Raised by the tracer dispatchers on an invalid greenlet request.

    Examples: adding a greenlet that is already traced, or removing/reading
    the trace of a greenlet that is not traced.
    """


class GreenTracerDispatcher(object):
    """Record, per greenlet, which source lines were executed.

    Two hooks cooperate:

    * a greenlet tracer (``greenlet.settrace``) that, on every switch or
      throw, turns the Python tracer on when control moves into a traced
      greenlet and off when it moves into an untraced one;
    * a Python tracer (``sys.settrace``) that appends a ``GreenTraceLine`` to
      the current greenlet's trace for every executed line whose filename
      contains one of the registered module filters.
    """

    # When True, tracer state changes are logged at DEBUG level.
    debug = False

    def __init__(self):
        # greenlet -> list of GreenTraceLine recorded while it was running
        self._greenlet_to_trace = dict()
        self._greenlet_tracer_enabled = False
        self._python_tracer_enabled = False
        # Filename substrings; only lines from matching files are recorded.
        self._filters = list()

    def add_greenlet(self, greenlet):
        """Start tracing ``greenlet``.

        Raises GreenTracerError if ``greenlet`` is already traced.
        """
        if self._is_traced_greenlet(greenlet):
            raise GreenTracerError()

        self._create_greenlet_trace(greenlet)

        # The greenlet tracer only reacts to switches. If tracing is active
        # and we are already running inside the new greenlet, the Python
        # tracer has to be switched on right now.
        if (
            self._is_greenlet_tracer_enabled() and
            self._get_current_greenlet() == greenlet and
            not self._is_python_tracer_enabled()
        ):
            self._enable_python_tracer()

    def remove_greenlet(self, greenlet):
        """Stop tracing ``greenlet`` and drop its recorded trace.

        Raises GreenTracerError if ``greenlet`` is not traced.
        """
        if not self._is_traced_greenlet(greenlet):
            raise GreenTracerError()

        # Stop recording before the trace list disappears, so the Python
        # tracer never looks up a greenlet that is no longer registered.
        if (
            self._is_python_tracer_enabled() and
            self._get_current_greenlet() == greenlet
        ):
            self._disable_python_tracer()

        self._remove_greenlet_trace(greenlet)

    def get_greenlet_trace(self, greenlet):
        """Return the (live, mutable) list of lines recorded for ``greenlet``.

        Raises GreenTracerError if ``greenlet`` is not traced.
        """
        if not self._is_traced_greenlet(greenlet):
            raise GreenTracerError()

        return self._get_greenlet_trace(greenlet)

    def enable(self):
        """Install the greenlet tracer, plus the Python tracer if the current
        greenlet is traced."""
        if not self._is_greenlet_tracer_enabled():
            self._enable_greenlet_tracer()

        current = self._get_current_greenlet()
        if self._is_traced_greenlet(current):
            if not self._is_python_tracer_enabled():
                self._enable_python_tracer()

    def disable(self):
        """Uninstall both tracers (if installed)."""
        if self._is_python_tracer_enabled():
            self._disable_python_tracer()

        if self._is_greenlet_tracer_enabled():
            self._disable_greenlet_tracer()

    def add_module_filter(self, substr):
        """Record lines from files whose name contains ``substr``."""
        self._filters.append(substr)

    def is_traced_greenlet(self, greenlet):
        """Return True if ``greenlet`` is currently traced."""
        return self._is_traced_greenlet(greenlet)

    def _is_greenlet_tracer_enabled(self):
        return self._greenlet_tracer_enabled

    def _enable_greenlet_tracer(self):
        if self.debug:
            logger.debug('Enable greenlet tracer')
        self._greenlet_tracer_enabled = True
        # NOTE(review): greenlet.settrace returns the previously installed
        # tracer; it is discarded here and replaced by None on disable.
        # Confirm no other greenlet tracer needs to coexist.
        greenlet.settrace(self._trace_greenlets)

    def _disable_greenlet_tracer(self):
        if self.debug:
            logger.debug('Disable greenlet tracer')
        greenlet.settrace(None)
        self._greenlet_tracer_enabled = False

    def _is_traced_greenlet(self, greenlet):
        return greenlet in self._greenlet_to_trace

    def _create_greenlet_trace(self, greenlet):
        if self.debug:
            logger.debug('New traced greenlet %r', greenlet)
        self._greenlet_to_trace[greenlet] = list()

    def _remove_greenlet_trace(self, greenlet):
        del self._greenlet_to_trace[greenlet]

    def _get_greenlet_trace(self, greenlet):
        return self._greenlet_to_trace[greenlet]

    def _append_line_to_trace(self, line):
        trace = self._greenlet_to_trace.get(self._get_current_greenlet())
        # This runs inside the sys.settrace callback. An exception raised
        # there unsets the tracer and surfaces in the traced code, so lines
        # from a greenlet that is not (or no longer) traced are dropped
        # instead of raising KeyError.
        if trace is not None:
            trace.append(line)

    def _get_current_greenlet(self):
        return gevent.getcurrent()

    def _is_python_tracer_enabled(self):
        return self._python_tracer_enabled

    def _enable_python_tracer(self):
        self._python_tracer_enabled = True
        sys.settrace(self._trace_python)

    def _disable_python_tracer(self):
        sys.settrace(None)
        self._python_tracer_enabled = False

    def _trace_greenlets(self, event, args):
        # greenlet reports both plain switches and throws as (origin, target).
        # In either case control moves into ``target``, so the Python tracer
        # must follow it.
        if event not in ('switch', 'throw'):
            return
        _origin, target = args
        if self._is_traced_greenlet(target):
            if not self._is_python_tracer_enabled():
                self._enable_python_tracer()
        elif self._is_python_tracer_enabled():
            self._disable_python_tracer()

    def _in_filters(self, filename):
        return any(substr in filename for substr in self._filters)

    def _trace_python(self, frame, event, arg):
        if event == 'line':
            filename = frame.f_code.co_filename
            if self._in_filters(filename):
                line = GreenTraceLine(
                    timestamp=now.f(),
                    filename=filename,
                    lineno=frame.f_lineno,
                )
                self._append_line_to_trace(line)
        # Returning the tracer keeps line events flowing for this frame.
        return self._trace_python


class TreeGreenTracerDispatcher(GreenTracerDispatcher):
    """Tracer dispatcher that organises traced greenlets into a tree.

    A greenlet may be registered as a child of an already traced greenlet.
    Removing or reading a greenlet then also covers all of its descendants.
    """

    def __init__(self):
        super(TreeGreenTracerDispatcher, self).__init__()
        self._tree = _Tree()

    def add_greenlet(self, greenlet, parent=None):
        """Start tracing ``greenlet``, optionally as a child of ``parent``.

        Raises GreenTracerError if ``greenlet`` is already registered, or if
        ``parent`` is given but is not a traced greenlet in the tree.
        """
        if greenlet in self._tree:
            raise GreenTracerError()

        # Compare with None explicitly: a greenlet is falsy when it has not
        # started yet or has finished. A plain truthiness test would skip
        # validation for such a parent, and _tree.add would then fail only
        # after the base class had already registered the greenlet.
        if parent is not None:
            if not self._is_traced_greenlet(parent):
                raise GreenTracerError()
            if parent not in self._tree:
                raise GreenTracerError()

        super(TreeGreenTracerDispatcher, self).add_greenlet(greenlet)

        self._tree.add(greenlet, parent)

    def remove_greenlet(self, greenlet):
        """Stop tracing ``greenlet`` and all of its descendants.

        Raises GreenTracerError if ``greenlet`` is not in the tree.
        """
        if greenlet not in self._tree:
            raise GreenTracerError()

        # Descendants come deepest first, so each one is a leaf when removed.
        descendants = self._tree.get_descendants(greenlet)

        for descendant in descendants:
            self._tree.remove(descendant)
            super(TreeGreenTracerDispatcher, self).remove_greenlet(descendant)

        self._tree.remove(greenlet)
        super(TreeGreenTracerDispatcher, self).remove_greenlet(greenlet)

    def get_greenlet_trace(self, greenlet):
        """Return a new list: ``greenlet``'s own trace followed by the traces
        of its descendants (deepest first), not merged by timestamp.

        Raises GreenTracerError if ``greenlet`` is not in the tree.
        """
        if greenlet not in self._tree:
            raise GreenTracerError()

        descendants = self._tree.get_descendants(greenlet)

        # Copy so that extending does not mutate the greenlet's live trace.
        trace = list(super(TreeGreenTracerDispatcher, self).get_greenlet_trace(greenlet))

        for descendant in descendants:
            descendant_trace = super(TreeGreenTracerDispatcher, self).get_greenlet_trace(descendant)
            trace.extend(descendant_trace)

        return trace


class TreeError(Exception):
    """Raised by _Tree on invalid structural operations.

    Examples: adding an existing node, adding under an unknown parent, or
    removing the implicit ``None`` root.
    """


class _Tree(object):
    def __init__(self):
        self._parent_to_child = {None: set()}
        self._child_to_parent = dict()

    def get_descendants(self, node):
        descendants = list()
        queue = [node]
        while queue:
            node = queue.pop(0)
            children = self._parent_to_child[node]
            for child in children:
                descendants.insert(0, child)
                queue.append(child)
        return descendants

    def add(self, node, parent=None):
        if node in self._parent_to_child:
            raise TreeError()
        if parent not in self._parent_to_child:
            raise TreeError()
        if node in self._parent_to_child[parent]:
            raise TreeError()

        self._parent_to_child[parent].add(node)
        self._parent_to_child[node] = set()
        self._child_to_parent[node] = parent

    def remove(self, node):
        if node is None:
            raise TreeError()

        descendants = self.get_descendants(node)
        for descendant in descendants:
            assert not self._parent_to_child[descendant]
            del self._child_to_parent[descendant]
            del self._parent_to_child[descendant]

        parent = self._child_to_parent[node]
        del self._child_to_parent[node]
        self._parent_to_child[parent].remove(node)
        del self._parent_to_child[node]

    def __contains__(self, node):
        return node in self._parent_to_child


# Module-level dispatcher instance, created at import time.
dispatcher = TreeGreenTracerDispatcher()
