
import ast
import sys
import re
from StringIO import StringIO

import click
from click import style

from direct.direct_log.lib.simple_eval import SimpleEval
from direct.direct_log.lib.trace_format import Trace


# Matches the leading run of one or more characters that precede the first
# '[' at the start of a string.
# NOTE(review): not referenced anywhere in the visible code (cli() locates the
# prefix with str.find instead) -- confirm whether it is still needed.
PREFIX_RE = re.compile(r'^[^\[]+')


@click.command('traces', help="Formatter for trace.log", context_settings=dict(help_option_names=['-h', '--help']))
@click.option('--profiles-count', '-p', default=1000, type=int,
              help="maximum number of longest profile records to show for each trace")
@click.option('--strict/--no-strict', default=False,
              help="all line must be trace-log json array (overwise some prefix allowed)")
@click.option('--filter-expr', "-f", default=None,
              help="""filter expression, f.e. times.ela > 1 and not method.startswith('m') or profile_match_sum("ppc:", "all_ela") > times.ela/4""")
def cli(profiles_count, strict, filter_expr):
    """Read trace-log lines from stdin and echo each one in formatted form.

    Args:
        profiles_count: maximum number of profile records printed per trace;
            passed through to format_trace.
        strict: when true the whole line is handed to Trace.decode; otherwise
            everything before the first '[' is split off as a prefix and
            echoed in front of the formatted trace.
        filter_expr: optional expression source. It is parsed once with
            ast.parse(mode="eval"); traces for which
            SimpleEval(expr, trace).eval() is falsy are skipped.

    Raises:
        click.BadParameter: if filter_expr is not a syntactically valid
            expression.
    """
    expr = None
    if filter_expr is not None:
        try:
            expr = ast.parse(filter_expr, mode="eval")
        except SyntaxError as e:
            # A malformed filter is a usage error: report it through click
            # instead of letting a raw traceback reach the user.
            raise click.BadParameter(
                "invalid filter expression: {}".format(e),
                param_hint="--filter-expr")

    for line in sys.stdin:
        try:
            prefix, trace_str = '', line
            if not strict:
                # Allow arbitrary text in front of the trace array.
                pos = line.find('[')
                if pos >= 0:
                    prefix, trace_str = line[:pos], line[pos:]

            trace = Trace.decode(trace_str)
            if expr is not None:
                if not SimpleEval(expr, trace).eval():
                    continue
            click.echo(prefix + format_trace(trace, profiles_count))
        except Exception as e:
            # Best effort: a line that cannot be decoded, filtered or formatted
            # is reported and then passed through unchanged, so one bad record
            # does not stop the stream.
            click.echo(style("error", fg='red') + ": " + str(e))
            sys.stdout.write(line)


def format_trace(trace, profiles_count):
    """Render one trace as a header line followed by its profile records.

    The header holds the span start timestamp, ``service/method[/tags]``,
    the ela/cpu/mem figures and the ids produced by format_ids.  A
    `` chunk:<index>/<FIN|?>`` suffix is added when the trace is not the last
    chunk or its chunk index is greater than 1.  After the header, at most
    ``profiles_count`` profiles are listed, ordered by ``all_ela`` descending,
    each on its own indented line.

    Returns:
        The formatted text, without a trailing newline.
    """
    span_start = trace.span_start.strftime("%Y-%m-%d %H:%M:%S.%f")

    cmd = trace.service + "/" + style(trace.method, fg='blue')
    if trace.tags:
        cmd += "/" + trace.tags

    ids = format_ids(trace)

    times = "ela:{},cpu:{}/{},mem:{}".format(
        style(flt(trace.times.ela), fg='red'),
        flt(trace.times.cu),
        flt(trace.times.cs),
        format_mem(trace.times.mem),
    )

    # Collect output fragments and join them once at the end.
    parts = ["{span_start}   {cmd}  {times}  {ids}".format(
        span_start=span_start, cmd=cmd, times=times, ids=ids)]

    if not trace.chunk_last or trace.chunk_index > 1:
        chunk_state = "FIN" if trace.chunk_last else "?"
        parts.append(" chunk:{}/{}".format(trace.chunk_index, chunk_state))

    ranked = sorted(trace.profiles, key=lambda p: p.all_ela, reverse=True)

    for prof in ranked[:profiles_count]:
        label = prof.func
        if prof.tags:
            label += "/" + prof.tags
        parts.append("\n    " + label)

        timing = "\t" + flt(prof.all_ela)
        if prof.child_ela:
            timing += "(-{})".format(flt(prof.child_ela))
        parts.append(timing)

        if prof.calls > 1:
            parts.append("\tcalls:{}".format(prof.calls))
        if prof.obj_num > 0:
            parts.append("\tobj:{}".format(prof.obj_num))

    return "".join(parts)


def flt(ela):
    """Format a number as fixed-point text.

    Values that are >= 0.001 or < 1e-6 are printed with three decimals;
    values in between are printed with six decimals so they do not collapse
    to ``0.000``.
    """
    template = "{:.3f}" if ela >= 0.001 or ela < 1e-6 else "{:.6f}"
    return template.format(ela)


def format_mem(mem):
    """Format a signed memory figure followed by an ``m`` suffix.

    The sign is always shown.  Values whose magnitude is below 10 keep three
    decimals; larger magnitudes are rounded to a whole number.

    The threshold is applied to ``abs(mem)`` so that a large negative value
    (e.g. ``-500``) is rendered as compactly as its positive counterpart
    instead of as ``-500.000m``.
    """
    # NOTE(review): the "m" suffix presumably means megabytes -- confirm.
    if abs(mem) < 10:
        return "{:+.3f}m".format(mem)
    return "{:+.0f}m".format(mem)


def format_ids(trace):
    """Return the shortest id path that identifies the span.

    * ``trace_id`` alone when it equals ``span_id``;
    * ``trace_id/span_id`` when ``trace_id`` equals ``parent_id``;
    * ``trace_id/parent_id/span_id`` otherwise.
    """
    root = trace.trace_id
    span = trace.span_id
    if root == span:
        return str(root)

    parent = trace.parent_id
    if root == parent:
        return "{}/{}".format(root, span)

    return "{}/{}/{}".format(root, parent, span)


# Invoke the click command when this file is executed as a script.
if __name__ == '__main__':
    cli()
