"""
Profiler for gevent
"""
import os
import sys
import time
import gevent
import signal
import inspect

# Reset by detach(); not otherwise used in the visible source.
_gls = {}
# Greenlet seen by the most recent _globaltrace call; used to detect switches.
_curr_gl = None
# greenlet -> root _State of that greenlet's call tree.
_states = {}
# greenlet -> _State of the call currently executing in that greenlet.
_curr_states = {}

# Output destinations; None disables the corresponding output.
_stats_output_file = sys.stdout
_summary_output_file = sys.stdout
_trace_output_file = sys.stdout

# Set by print_percentages(); not read elsewhere in the visible source.
_print_percentages = False
# When False, a greenlet's clock is paused when it calls the hub's 'switch'
# (see _globaltrace / time_blocking()).
_time_blocking = False

# Absolute time.time() deadline after which the trace hooks call detach();
# None means no deadline.
_attach_expiration = None

# time.time() when attach() started tracing; None when not attached.
_trace_began_at = None


class _State:
    """One node in a greenlet's call tree, built by the trace hooks.

    Each greenlet's root node is a bare ``_State()``; every traced call adds a
    child node under the node that was active when the call started.
    """

    def __init__(self):
        self.modulename = None    # file-name stem of the code (see _modname)
        self.co_name = None       # code.co_name of the called function
        self.filename = None      # code.co_filename
        self.line_no = None       # code.co_firstlineno
        self.start_time = None    # time.time() when the clock last started; None while paused
        self.full_class = None    # type of the frame's 'self' local, if it had one
        self.elapsed = 0.0        # accumulated seconds
        self.depth = 0            # NOTE(review): not updated in the visible source
        self.calls = []           # child _State nodes, in call order
        self.parent = None        # enclosing node; None for a greenlet's root
        self.suffix = ''          # 'R' markers appended by _sum_calls for recursion
        # Local trace function for this call; assigned by _globaltrace.  Set
        # here so every instance has the attribute.
        self.localtracefunc = None

    @property
    def name(self):
        """Display name of this call; identical to ``str(self)``."""
        return str(self)

    def _defining_class(self):
        """Return the class in ``full_class``'s MRO whose own function for
        ``co_name`` has this node's filename and first line number.

        Falls back to ``full_class`` itself when no class matches.
        """
        true_class = self.full_class
        for cls in inspect.getmro(self.full_class):
            fnc = cls.__dict__.get(self.co_name)
            # Functions expose __code__ on Python 2.6+ and 3; the original
            # only checked the Python-2-only func_code, so on Python 3 no
            # class ever matched.
            code = getattr(fnc, "__code__", None) or getattr(fnc, "func_code", None)
            if (code is not None
                    and code.co_filename == self.filename
                    and code.co_firstlineno == self.line_no):
                true_class = cls
        return true_class

    def __str__(self):
        """Return ``<module or Class path>.<co_name>[ <suffix>]``."""
        first = self.modulename
        if self.full_class:
            true_class = self._defining_class()
            first = "%s.%s" % (true_class.__module__, true_class.__name__)
        name = "%s.%s" % (first, self.co_name)
        if self.suffix:
            name = "%s %s" % (name, self.suffix)
        return name


def _modname(path):
    """Return a plausible module name for *path*: its base name minus the extension.

    Package directories are ignored, e.g. ``/x/gevent/hub.py`` -> ``hub``.
    """
    stem, _unused_ext = os.path.splitext(os.path.basename(path))
    return stem


def _globaltrace(frame, event, arg):
    """Global trace function installed with ``sys.settrace``.

    Invoked when a new frame is entered.  It records a child _State for the
    call under the current greenlet's active _State and makes the child the
    active one.  It then returns a local trace function (from _getlocaltrace),
    which pops the child again on 'return'.  It also pauses and resumes timing
    when the running greenlet changes.

    Returns None, so the frame is not traced, after auto-detaching.
    """
    global _curr_gl

    # Auto-detach once the deadline set by attach(duration=...) has passed.
    if _attach_expiration is not None and time.time() > _attach_expiration:
        detach()
        return

    gl = gevent.greenlet.getcurrent()
    if gl not in _states:
        # First event seen on this greenlet: create the root of its call tree.
        _states[gl] = _State()
        _curr_states[gl] = _states[gl]

    if _curr_gl is not gl:
        # The running greenlet changed since the last event.  Stop the clock
        # on the previous greenlet's call chain and start it on this one's.
        if _curr_gl is not None:
            _stop_timing(_curr_gl)
        _curr_gl = gl
        _start_timing(_curr_gl)

    code = frame.f_code
    filename = code.co_filename
    # Bind unconditionally.  Previously this was only assigned when filename
    # was non-empty, so an empty co_filename raised NameError below.
    modulename = None
    if filename:
        modulename = _modname(filename)
    if modulename is not None:
        _print_trace("[%s] call: %s: %s\n" % (gl, modulename, code.co_name))

    # Push a new node for this call under the greenlet's active node.
    state = _State()
    _curr_states[gl].calls.append(state)
    state.parent = _curr_states[gl]
    _curr_states[gl] = state

    state.modulename = modulename
    state.filename = filename
    state.line_no = code.co_firstlineno
    state.co_name = code.co_name
    state.start_time = time.time()
    if 'self' in frame.f_locals:
        # Heuristic: a local named 'self' is taken to be the bound instance.
        state.full_class = type(frame.f_locals['self'])

    tracefunc = _getlocaltrace(state)
    state.localtracefunc = tracefunc

    # A call to 'switch' in a module named 'hub' is treated as the greenlet
    # yielding to gevent's hub.  Unless time_blocking() is enabled, pause the
    # greenlet's whole call chain here.
    if modulename == 'hub' and code.co_name == 'switch' and not _time_blocking:
        _stop_timing(gl)

    return tracefunc


def _getlocaltrace(state):
    """Return the local trace function for the call recorded in *state*.

    The returned function stays installed for all events in the frame.  On
    the frame's 'return' event it adds the time since ``state.start_time`` to
    ``state.elapsed``.  It then makes the parent of the greenlet's current
    _State the current one again.
    """
    def _localtrace(frame, event, arg):
        # Auto-detach once the deadline set by attach(duration=...) has passed.
        if _attach_expiration is not None and time.time() > _attach_expiration:
            detach()
            return

        if event == 'return':
            gl = gevent.greenlet.getcurrent()
            code = frame.f_code
            filename = code.co_filename
            modulename = None
            if filename:
                modulename = _modname(filename)
            if modulename is not None:
                _print_trace("[%s] return: %s: %s: %s\n" % (gl, modulename, code.co_name, code.co_firstlineno))
            # start_time is None while this greenlet's clock is paused
            # (_stop_timing clears it).
            if state.start_time is not None:
                state.elapsed += time.time() - state.start_time
            # NOTE(review): assumes _curr_states[gl] is this call's `state`
            # (never the greenlet's root); the assert is stripped under -O.
            assert _curr_states[gl].parent is not None
            # Pop: the caller's node becomes the greenlet's active node.
            _curr_states[gl] = _curr_states[gl].parent
            return None

        # Any other event: keep this function (stored by _globaltrace as
        # state.localtracefunc) installed for the frame.
        return state.localtracefunc

    return _localtrace


def _stop_timing(gl):
    """Pause the clock on every call in greenlet *gl*'s active call chain.

    Walks from the greenlet's current _State up through its parents.  For each
    running node it adds the time since ``start_time`` to ``elapsed`` and
    clears ``start_time``.  All nodes are paused at the same instant.

    Iterative on purpose: the chain is as long as the traced program's call
    depth, and this runs from inside that already-deep stack.  A recursive
    walk could therefore exceed the interpreter's recursion limit.
    """
    # If we're reattaching later, it's possible to be called without a full
    # set of current state for this greenlet.
    state = _curr_states.get(gl)
    if state is None:
        return
    now = time.time()
    while state is not None:
        if state.start_time is not None:
            state.elapsed += now - state.start_time
            state.start_time = None
        state = state.parent


def _start_timing(gl):
    """(Re)start the clock on every call in greenlet *gl*'s active call chain.

    Sets ``start_time`` to the current time on the greenlet's current _State
    and all of its parents, using one shared timestamp.  It is iterative for
    the same reason as _stop_timing: the chain can be as deep as the traced
    program's stack.
    """
    # If we're reattaching later, it's possible to be called without a full
    # set of current state for this greenlet.
    state = _curr_states.get(gl)
    if state is None:
        return
    now = time.time()
    while state is not None:
        state.start_time = now
        state = state.parent


class _CallSummary(object):
    """Running totals for one call name, or for one caller/callee edge.

    Counts the calls folded in.  It keeps the time spent in the function's
    own body separate from the time spent in its children.
    """

    def __init__(self, name):
        self.name = name
        self.count = 0
        self.own_cumulative = 0.0        # seconds in the function body itself
        self.children_cumulative = 0.0   # seconds in the functions it called

    @property
    def cumulative(self):
        """Total seconds: own time plus children's time."""
        total = self.own_cumulative + self.children_cumulative
        return total

    def add(self, own_time, children_time):
        """Fold one more call's own and children timings into the totals."""
        self.own_cumulative += own_time
        self.children_cumulative += children_time
        self.count += 1


class _PrimarySummary(_CallSummary):
    """Summary for a call name, plus breakdowns by neighbour.

    ``subroutines`` maps callee name -> edge summary of calls made from here.
    ``callers`` maps caller name -> edge summary of calls into here.
    """

    def __init__(self, name):
        _CallSummary.__init__(self, name)
        self.callers = {}
        self.subroutines = {}


class _SubroutineSummary(_CallSummary):
    """Summary of a single caller->callee edge.

    ``primary`` links to the _PrimarySummary of the function at the other end
    of the edge.  It is None for the synthetic '<spontaneous>' caller.
    """

    def __init__(self, name, primary):
        _CallSummary.__init__(self, name)
        self.primary = primary


# Caller edges record exactly the same data as subroutine edges (name, link to
# the other end's primary summary, timings), so one class serves both roles.
_CallerSummary = _SubroutineSummary


def _add_subroutine(primary_name, child_name, call_summaries):
    """Return the edge summary for calls from *primary_name* to *child_name*.

    The edge is created on first use and linked to the child's
    _PrimarySummary.  Both names must already be keys of *call_summaries*;
    otherwise KeyError is raised.
    """
    primary = call_summaries[primary_name]
    child_primary = call_summaries[child_name]
    subroutine = primary.subroutines.get(child_name)
    if subroutine is None:
        # Built lazily.  dict.setdefault would construct (and usually
        # discard) a new summary on every call.
        subroutine = _SubroutineSummary(child_name, child_primary)
        primary.subroutines[child_name] = subroutine
    return subroutine


def _add_caller(primary_name, parent_name, call_summaries):
    """Return the edge summary for calls into *primary_name* from *parent_name*.

    The edge is created on first use and linked to the parent's
    _PrimarySummary.  Both names must already be keys of *call_summaries*;
    otherwise KeyError is raised.
    """
    primary = call_summaries[primary_name]
    parent_primary = call_summaries[parent_name]
    caller = primary.callers.get(parent_name)
    if caller is None:
        # Built lazily.  dict.setdefault would construct (and usually
        # discard) a new summary on every call.
        caller = _CallerSummary(parent_name, parent_primary)
        primary.callers[parent_name] = caller
    return caller


def _sum_calls(state, call_summaries, parent=None, anticycle=tuple()):
    """Fold the call tree rooted at *state* into *call_summaries*.

    *call_summaries* maps call name -> _PrimarySummary and is updated in
    place.  Each node's own time (its elapsed minus its children's elapsed)
    and its children's time are added to its primary summary.  They are also
    added to the caller/subroutine edges linking it with *parent*, or to a
    '<spontaneous>' caller when there is no parent.

    *anticycle* holds the names on the path from the root down to *state*.  A
    node whose name is already on that path gets 'R' appended to its suffix
    until the name is unique.  Recursive calls are therefore summarised under
    distinct names (e.g. 'mod.f R').

    Returns ``state.elapsed``.
    """
    while state.name in anticycle:
        state.suffix += 'R'
    # state.name walks the class MRO on every access and is fixed from here
    # on (children only change their own suffixes), so compute it once.
    name = state.name

    call = call_summaries.get(name)
    if call is None:
        call = _PrimarySummary(name)
        call_summaries[name] = call

    path = anticycle + (name,)
    child_exec_time = 0.0
    for child in state.calls:
        child_exec_time += _sum_calls(child, call_summaries, state, path)

    own_time = state.elapsed - child_exec_time

    call.add(own_time=own_time, children_time=child_exec_time)
    if parent is not None:
        parent_name = parent.name
        caller = _add_caller(name, parent_name, call_summaries)
        caller.add(own_time=own_time, children_time=child_exec_time)

        subroutine = _add_subroutine(parent_name, name, call_summaries)
        subroutine.add(own_time=own_time, children_time=child_exec_time)
    else:
        caller = call.callers.get('<spontaneous>')
        if caller is None:
            caller = _CallerSummary('<spontaneous>', None)
            call.callers['<spontaneous>'] = caller
        caller.add(own_time=own_time, children_time=child_exec_time)

    return state.elapsed


def _maybe_open_file(f):
    """Open the file named *f* for writing (truncating it); None passes through."""
    return None if f is None else open(f, 'w')


def _maybe_write(output_file, message):
    """Write *message* to *output_file*; do nothing when output is disabled (None)."""
    if output_file is None:
        return
    output_file.write(message)


def _maybe_flush(f):
    """Flush *f*; do nothing when output is disabled (None)."""
    if f is None:
        return
    f.flush()


def _print_trace(msg):
    """Append *msg* to the execution-trace output (no-op if it is disabled)."""
    destination = _trace_output_file
    _maybe_write(destination, msg)


def _print_stats_header(header):
    """Write a fixed-width header row followed by an 86-character '=' rule.

    *header* must be a 5-tuple of strings, matching the row format used here.
    NOTE(review): not called anywhere in the visible source.
    """
    header_line = "%40s %5s %12s %12s %12s\n" % header
    _maybe_write(_stats_output_file, header_line)
    rule_line = "=" * 86 + "\n"
    _maybe_write(_stats_output_file, rule_line)


def _print_stats(stats):
    """Write one fixed-width stats row built from the 5-tuple *stats*.

    NOTE(review): not called anywhere in the visible source.
    """
    row = "%40s %5d %12f %12f %12f\n" % stats
    _maybe_write(_stats_output_file, row)


def _print_state(state, depth=0):
    """Recursively dump *state* and its callees to the summary output.

    Each line is ``<dots> <name> <elapsed>``.  Children get two more leading
    dots than their parent.
    """
    line = "%s %s %f\n" % ("." * depth, str(state), state.elapsed)
    _maybe_write(_summary_output_file, line)
    child_depth = depth + 2
    for child in state.calls:
        _print_state(child, child_depth)


def _make_call_line(call_stats, duration, call_indices, is_primary=False):
    """Build one row of the call-graph table as a list of six strings.

    Columns: index, % time, self, children, called, name.  The index and
    percentage columns are filled only for the primary row (*is_primary*).
    The synthetic '<spontaneous>' caller gets a row holding only its name.

    *duration* is the total traced time used for the percentage.  A
    non-positive duration, e.g. a trace too short for the clock to measure,
    gives 0.00 instead of raising ZeroDivisionError.
    """
    if call_stats.name == '<spontaneous>':
        return ["", "", "", "", "", call_stats.name]

    index = call_indices[call_stats.name]
    if is_primary:
        index_col = "[%d]" % index
        percent = call_stats.cumulative * 100 / duration if duration > 0 else 0.0
        percent_col = "%6.2f" % percent
    else:
        index_col = ""
        percent_col = ""
    return [
        index_col,
        percent_col,
        "%12f" % call_stats.own_cumulative,
        "%12f" % call_stats.children_cumulative,
        "%8d" % call_stats.count,
        "%s [%d]" % (call_stats.name, index),
    ]


def _format_table(rows, col_names):
    """Lay out *rows* (lists of strings) under centred *col_names*.

    Each column is padded to its widest cell, header included, so the header
    stays aligned.  The last column is left unpadded.  An empty row becomes a
    '-' rule.  Returns the lines without trailing newlines, header first.
    """
    widths = [len(col) for col in col_names]
    for row in rows:
        for i, cell in enumerate(row):
            widths[i] = max(widths[i], len(cell))
    widths[-1] = 0
    rule = '-' * sum(widths)

    lines = [" ".join(col.center(widths[i]) for i, col in enumerate(col_names))]
    for row in rows:
        if row:
            lines.append(" ".join(cell.ljust(widths[i]) for i, cell in enumerate(row)))
        else:
            lines.append(rule)
    return lines


def _print_output(duration):
    """Write the call-graph statistics and the per-greenlet call trees.

    The statistics table goes to the stats output.  Each greenlet's indented
    call tree goes to the summary output.  *duration* is the total traced
    time in seconds and is used for the "% time" column.  With no recorded
    calls, only the table header is written.
    """
    call_summaries = {}
    for root in _states.values():
        _sum_calls(root, call_summaries)

    # Sort on cumulative time only.  Sorting (cumulative, summary) tuples
    # compares summary objects on ties, which is arbitrary on Python 2 and a
    # TypeError on Python 3.
    call_list = sorted(call_summaries.values(),
                       key=lambda summary: summary.cumulative, reverse=True)
    call_indices = dict((c.name, index) for index, c in enumerate(call_list, 1))
    col_names = ["index", "% time", "self", "children", "called", "name"]

    # For each function: its callers, itself (primary row), its callees,
    # then an empty row that is rendered as a separator.
    rows = []
    for c in call_list:
        for caller in c.callers.values():
            rows.append(_make_call_line(caller, duration, call_indices))
        rows.append(_make_call_line(c, duration, call_indices, is_primary=True))
        for subroutine in c.subroutines.values():
            rows.append(_make_call_line(subroutine, duration, call_indices))
        rows.append([])

    # Explicit loop: map() for side effects writes nothing on Python 3.
    for line in _format_table(rows, col_names):
        _maybe_write(_stats_output_file, "%s\n" % line)

    _maybe_write(_stats_output_file, '\f\n')
    _maybe_flush(_stats_output_file)

    for gl in _states:
        _maybe_write(_summary_output_file, "%s\n" % gl)
        _print_state(_states[gl])
        _maybe_write(_summary_output_file, "\n")
    _maybe_flush(_summary_output_file)


def attach(duration=0):
    """
    Start execution tracing.
    If *duration* is non-zero, tracing detaches automatically on the first
    traced event after *duration* seconds.  If duration is zero, the trace
    won't stop until detach is called.  Calling attach() while already
    attached does nothing.
    """
    global _attach_expiration
    global _trace_began_at
    # _trace_began_at is set on every attach.  _attach_expiration stays None
    # when duration == 0, so it cannot tell whether tracing is already on.
    if _trace_began_at is not None:
        return
    now = time.time()
    if duration != 0:
        _attach_expiration = now + duration
    _trace_began_at = now
    sys.settrace(_globaltrace)


def detach():
    """
    Finish execution tracing, print the results and reset internal state.
    Does nothing when tracing is not attached.  The internal state is reset
    even if writing the results raises.
    """
    global _gls
    global _curr_gl
    global _states
    global _curr_states
    global _attach_expiration
    global _trace_began_at

    # Not attached: never attached, or already detached.
    if _trace_began_at is None:
        return

    duration = time.time() - _trace_began_at
    _attach_expiration = None
    sys.settrace(None)
    try:
        _maybe_flush(_trace_output_file)
        _print_output(duration)
    finally:
        # Reset in `finally` so a failed report cannot leave stale call trees
        # and a stale start time behind for the next attach().
        _gls = {}
        _curr_gl = None
        _states = {}
        _curr_states = {}
        _trace_began_at = None


def profile(func, *args, **kwargs):
    """
    Run ``func(*args, **kwargs)`` with profiling enabled.  Once it returns,
    print the profiling results and return its return value.

    The trace hook is removed even if *func* raises.  In that case nothing is
    printed and the exception propagates.  Collected call trees are then
    discarded, so a later profile() call reports only its own run.
    """
    global _gls
    global _curr_gl
    global _states
    global _curr_states

    sys.settrace(_globaltrace)
    trace_began_at = time.time()
    try:
        try:
            retval = func(*args, **kwargs)
        finally:
            # Never leave the hook installed, even if func raises.
            sys.settrace(None)
        _maybe_flush(_trace_output_file)
        _print_output(time.time() - trace_began_at)
    finally:
        # Otherwise the next profile() would report this run's calls again.
        _gls = {}
        _curr_gl = None
        _states = {}
        _curr_states = {}

    return retval


def set_stats_output(f):
    """
    Send the call-timing statistics to the file named *f*, which is opened
    for writing (truncated).  Passing None disables statistics output.
    """
    global _stats_output_file
    new_target = _maybe_open_file(f)
    _stats_output_file = new_target


def set_summary_output(f):
    """
    Send the per-greenlet execution summary to the file named *f*, which is
    opened for writing (truncated).  Passing None disables summary output.
    """
    global _summary_output_file
    new_target = _maybe_open_file(f)
    _summary_output_file = new_target


def set_trace_output(f):
    """
    Send the call/return execution trace to the file named *f*, which is
    opened for writing (truncated).  Passing None disables trace output.
    """
    global _trace_output_file
    new_target = _maybe_open_file(f)
    _trace_output_file = new_target


def print_percentages(enabled=True):
    """
    Choose whether statistics are reported as percentages of total run time
    (True) instead of absolute measurements (False).
    NOTE(review): the stored flag, _print_percentages, is not read anywhere
    else in the visible source, so this currently has no effect on output.
    """
    global _print_percentages
    _print_percentages = enabled


def time_blocking(enabled=True):
    """
    Choose whether time spent blocked (while a greenlet has switched into
    the gevent hub) counts toward each function's execution totals.
    When disabled (the default), a greenlet's clock is paused at the hub
    switch, which is usually what you want.
    """
    global _time_blocking
    _time_blocking = enabled


def set_attach_duration(attach_duration=60):
    """
    Store a default attach duration in seconds (0 is intended to mean "no
    limit") in the module-level ``_attach_duration``.

    NOTE(review): nothing in the visible source reads ``_attach_duration``.
    attach() uses only its own ``duration`` argument (default 0, i.e. no
    automatic detach), so calling this currently does not affect tracing.
    Pass the duration to attach() or attach_on_signal() instead.
    """
    global _attach_duration
    _attach_duration = attach_duration


def attach_on_signal(signum=getattr(signal, "SIGUSR1", None), duration=60):
    """
    Install a signal handler that calls ``attach(duration=duration)`` when
    *signum* is received.  That starts a full execution trace, which detaches
    and prints a summary of all greenlet activity on the first traced event
    after *duration* seconds.
    See set_summary_output and set_trace_output for information about how
    to configure where the output goes.
    By default, the signal is SIGUSR1.  Where SIGUSR1 does not exist
    (e.g. Windows), pass *signum* explicitly; otherwise ValueError is raised.
    The default is looked up with getattr so that importing this module no
    longer fails on such platforms.
    """
    if signum is None:
        raise ValueError("SIGUSR1 is unavailable on this platform; pass signum explicitly")

    def _on_signal(received_signum, frame):
        attach(duration=duration)

    signal.signal(signum, _on_signal)


if __name__ == "__main__":
    # Command-line entry point: profile a script, e.g.
    #   python <this module> [options] SCRIPT [SCRIPT_ARGS...]
    from optparse import OptionParser

    parser = OptionParser(usage="%prog [options] SCRIPT [SCRIPT_ARGS...]")
    parser.add_option("-a", "--stats", dest="stats",
                      help="write the stats to a file",
                      metavar="STATS_FILE")
    parser.add_option("-s", "--summary", dest="summary",
                      help="write the summary to a file",
                      metavar="SUMMARY_FILE")
    parser.add_option("-t", "--trace", dest="trace",
                      help="write the trace to a file",
                      metavar="TRACE_FILE")
    # store_true: the flag's presence enables the feature.  The previous
    # store_false stored False when given and only worked through an
    # "is not None" check.
    parser.add_option("-p", "--percentages", dest="percentages",
                      action="store_true", default=False,
                      help="print stats as percentages of total runtime")
    parser.add_option("-b", "--blocking", dest="blocking",
                      action="store_true", default=False,
                      help="count blocked time toward execution totals")
    # Options after the script name belong to the script, not the profiler.
    parser.disable_interspersed_args()
    (options, args) = parser.parse_args()
    if options.stats is not None:
        set_stats_output(options.stats)
    if options.summary is not None:
        set_summary_output(options.summary)
    if options.trace is not None:
        set_trace_output(options.trace)
    if options.percentages:
        print_percentages()
    if options.blocking:
        time_blocking()
    if not args:
        sys.stderr.write("what file should i be profiling?\n")
        sys.exit(1)
    script_path = args[0]

    # Run the script as if it had been invoked directly: give it its own argv
    # and a fresh __main__ namespace, not this profiler module's globals.
    sys.argv = args
    script_globals = {"__name__": "__main__", "__file__": script_path}
    with open(script_path) as script_file:
        script_code = compile(script_file.read(), script_path, "exec")

    trace_began_at = time.time()
    sys.settrace(_globaltrace)
    try:
        exec(script_code, script_globals)
    finally:
        # Uninstall the hook and report even if the script raised or called
        # sys.exit().
        sys.settrace(None)
        _maybe_flush(_trace_output_file)
        _print_output(time.time() - trace_began_at)
