# -*- coding: utf-8 -*-

from datetime import timedelta
from contextlib import suppress
from yt.wrapper import YtClient
from drive.library.py.time import now, from_string, to_timestamp, TZ_MSK


# Default value of the --log option: the YT directory of the backend events
# log whose "1d" and "stream/5min" subdirectories hold the tables to profile.
BACKEND_LOG_STREAM_PATH = "//logs/carsharing-backend-events-log"


def parse_time(s):
    """Parse *s* into a datetime tagged with the Moscow time zone.

    Two layouts are tried in order: a bare date ("%Y-%m-%d") and a
    second-resolution timestamp ("%Y-%m-%dT%H:%M:%S"). The first one that
    ``from_string`` accepts without raising ValueError wins.

    NOTE(review): the zone is attached with ``replace(tzinfo=TZ_MSK)``; if
    TZ_MSK is a pytz zone this yields its LMT offset rather than +03:00 —
    confirm what ``drive.library.py.time`` provides.

    Raises:
        ValueError: when *s* matches neither layout.
    """
    for layout in ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"):
        try:
            return from_string(s, layout).replace(tzinfo=TZ_MSK)
        except ValueError:
            # Not this layout; try the next one.
            pass
    raise ValueError("unsupported format for time: {}".format(s))


def register(group):
    """Register the ``yt-profile`` subcommand on *group*.

    *group* is presumably the object returned by argparse's
    ``add_subparsers()``; the subcommand stores ``main`` as the ``main``
    attribute of the parsed namespace.

    Options:
        --proxy          cluster passed to YtClient (default "hahn").
        --log            log root directory (default BACKEND_LOG_STREAM_PATH).
        --from-duration  minimum request duration to report, in the units
                         the report labels as ms (default 300).
        -B/--begin-time  range start parsed by parse_time (default None;
                         main then falls back to now - 2 hours).
        -E/--end-time    range end parsed by parse_time (default None;
                         main then falls back to now).
        --source         only profile rows with this source (default: all).
    """
    parser = group.add_parser("yt-profile")
    parser.set_defaults(main=main)
    option = parser.add_argument
    option("--proxy", default="hahn")
    option("--log", default=BACKEND_LOG_STREAM_PATH)
    option("--from-duration", type=float, default=300)
    option("-B", "--begin-time", type=parse_time, default=None)
    option("-E", "--end-time", type=parse_time, default=None)
    option("--source", default=None)


def get_tables(yc, log, begin_time, end_time):
    """Collect the log tables to read for the range [begin_time, end_time).

    Daily tables (``<log>/1d``) and 5-minute stream tables
    (``<log>/stream/5min``) are selected when the time parsed from their name
    (via ``parse_time``) satisfies ``begin_time <= start < end_time``; the
    mapper later filters rows by their exact unixtime.

    A 5-minute table that starts inside the day of an already selected daily
    table is skipped, so the same requests are not read (and averaged) twice.
    NOTE(review): this assumes a daily table contains everything that day's
    stream tables contain — confirm against the log's merge/retention setup.

    Args:
        yc: YT client used to list the directories.
        log: log root directory; a trailing "/" is ignored.
        begin_time: inclusive range start, comparable with parse_time output.
        end_time: exclusive range end, comparable with parse_time output.

    Returns:
        Full table paths: the selected daily tables first, then the selected
        5-minute tables, each group in ``yc.list(..., sort=True)`` order.
    """
    base = log.rstrip('/')
    log_1d = base + "/1d"
    log_5m = base + "/stream/5min"
    one_day = timedelta(days=1)

    def starting_in_range(directory):
        # Yield (start time, full path) for tables whose name-time is in range.
        for table in yc.list(directory, sort=True):
            start = parse_time(table)
            if begin_time <= start < end_time:
                yield start, directory + "/" + table

    daily = list(starting_in_range(log_1d))

    def covered_by_daily(start):
        return any(day <= start < day + one_day for day, _ in daily)

    res_tables = [path for _, path in daily]
    res_tables.extend(
        path
        for start, path in starting_in_range(log_5m)
        if not covered_by_daily(start)
    )
    return res_tables


def main(opts):
    """Run the ``yt-profile`` command and print one event tree per handler.

    The range defaults to the last two hours. A map-reduce over the selected
    log tables produces, for every (source, event name, parent event), the
    duration averaged by ``reducer`` over requests at least
    ``opts.from_duration`` long; the result is printed as an indented tree.

    Args:
        opts: namespace produced by the parser configured in ``register``
            (proxy, log, from_duration, begin_time, end_time, source).
    """
    yc = YtClient(opts.proxy)
    now_time = now().replace(tzinfo=TZ_MSK)
    begin_time = opts.begin_time or (now_time - timedelta(hours=2))
    end_time = opts.end_time or now_time
    tables = get_tables(yc, opts.log, begin_time, end_time)
    if not tables:
        # Nothing to profile; do not launch a map-reduce without input.
        print("No tables found in {} for [{}, {})".format(
            opts.log, begin_time, end_time))
        return
    # source -> parent event name ("" for top level) -> child event names.
    edges = {}
    # source -> (parent, name) -> duration of `name` when run under `parent`.
    # Keyed per parent so an event reached from several parents shows its
    # own figure under each one instead of the sum over all parents.
    times = {}
    with yc.TempTable() as temp:
        yc.run_map_reduce(
            make_mapper(
                opts.from_duration, begin_time, end_time, opts.source
            ), reducer,
            source_table=tables,
            destination_table=temp,
            reduce_by=("source", "name", "parent"),
            sort_by=("source", "name", "parent"),
        )
        for row in yc.read_table(temp):
            src, name, parent = row["source"], row["name"], row["parent"]
            edges.setdefault(src, {}).setdefault(parent, []).append(name)
            src_times = times.setdefault(src, {})
            key = (parent, name)
            src_times[key] = src_times.get(key, 0) + row["duration"]
    for src, src_times in times.items():
        src_edges = edges[src]
        print("Handler: {}".format(src))
        for edge in src_edges.get("", []):
            print_rec(src_edges, src_times, edge, [], 1)


def print_rec(edges, times, name, stack, depth, parent=""):
    """Print event *name* and, recursively, its children as an indented tree.

    Args:
        edges: parent event name -> list of child event names.
        times: (parent, name) -> duration of *name* when run under *parent*.
        name: event to print.
        stack: names on the current path; children already on it are skipped
            to break cycles. *name* is pushed before and popped after
            visiting its children.
        depth: nesting level; the line is prefixed with ``"> " * depth``.
        parent: event *name* is printed under ("" for top-level events).
    """
    print("{}Event: {} ({:.4f} ms)".format(
        "> " * depth, name, times[(parent, name)]))
    stack.append(name)
    for child in edges.get(name, []):
        if child in stack:
            continue
        print_rec(edges, times, child, stack, depth + 1, parent=name)
    stack.pop()


# Prefixes of event names that end in a variable suffix. A name starting with
# one of them is collapsed to "<prefix>*" so all its variants aggregate into a
# single profile node. Built once at import time rather than on every
# fix_name call (it runs for every event of every mapped row); a tuple keeps
# the match order deterministic.
_GENERIC_EVENT_PREFIXES = (
    "prepare_offer_", "build_offer_", "prefetch_offer_check_",
    "restore_offer:", "prefetch_offer_start_", "corrector_",
    "build_report_", "pack_offer_", "fix_point_",
    "Sessions:", "GetReport:", "WaitRoute:DestinationSuggest:",
    "WaitGeoFeatures:DestinationSuggest:",
    "WaitGeobaseFeatures:DestinationSuggest:",
    "WaitUserGeoFeatures:DestinationSuggest:",
    "WaitUserGeobaseFeatures:DestinationSuggest:",
    "WaitUserDoubleGeoFeatures:DestinationSuggest:",
    "WaitUserDoubleGeobaseFeatures:DestinationSuggest:",
    "build_complementary_",
    "wait_previous_offer:",
)

# Names that match a generic prefix but must keep their exact name.
_GENERIC_EVENT_BANS = frozenset({
    "fix_point_features",
})


def fix_name(name):
    """Collapse a generic event name to its ``<prefix>*`` form.

    Args:
        name: raw event name taken from the log.

    Returns:
        ``prefix + "*"`` for the first entry of ``_GENERIC_EVENT_PREFIXES``
        that *name* starts with, unless *name* is listed in
        ``_GENERIC_EVENT_BANS``; otherwise *name* unchanged.
    """
    if name in _GENERIC_EVENT_BANS:
        return name
    for prefix in _GENERIC_EVENT_PREFIXES:
        if name.startswith(prefix):
            return prefix + "*"
    return name


def _aggregate_events(source, data):
    """Fold one request's start/finish event trace into per-event durations.

    Args:
        source: the request's ``source`` value, used as the first key part.
        data: iterable of event dicts read via their ``source`` (event name,
            normalized with fix_name), ``event`` (``"start"``, ``"finish"``
            or anything else) and ``_ts`` keys; presumably in time order.

    Returns:
        ``(events, req_duration)``: ``events`` maps ``(source, name, parent)``
        to the summed duration of ``name`` directly under ``parent`` ("" for
        top-level events), and ``req_duration`` spans the first to the last
        timestamp. ``None`` if the trace is malformed, i.e. a ``finish`` that
        does not close the innermost open event. Durations are ``_ts``
        differences divided by 1000 (the report labels them ms).
    """
    open_events = []  # stack of (name, start_ts), innermost last
    events = {}
    first_ts = 0
    last_ts = 0

    def record(name, start_ts, end_ts):
        # Called right after popping `name`, so the stack top is its parent.
        parent = open_events[-1][0] if open_events else ""
        key = (source, name, parent)
        events[key] = events.get(key, 0) + (end_ts - start_ts) / 1000

    for event in data:
        name = fix_name(event.get("source", ""))
        kind = event.get("event", "")
        ts = event.get("_ts", 0)
        if not first_ts:
            first_ts = ts
        last_ts = ts
        if kind == "start":
            # The start time lives on the stack (not in a dict keyed by name)
            # so nested events sharing a collapsed name each use their own.
            open_events.append((name, ts))
        elif kind == "finish":
            if not open_events or open_events[-1][0] != name:
                # Invalid row: unmatched or out-of-order finish.
                return None
            _, start_ts = open_events.pop()
            record(name, start_ts, ts)
    # Events that never finished are closed at the last seen timestamp.
    while open_events:
        name, start_ts = open_events.pop()
        record(name, start_ts, last_ts)
    return events, (last_ts - first_ts) / 1000


def make_mapper(from_duration, begin_time, end_time, source):
    """Build the YT mapper that extracts event timings from backend log rows.

    Args:
        from_duration: minimum request duration (first-to-last ``_ts``
            difference divided by 1000) for a request's events to be emitted.
        begin_time: inclusive lower bound on ``row["unixtime"]`` (converted
            with to_timestamp).
        end_time: exclusive upper bound on ``row["unixtime"]``.
        source: if truthy, only rows whose ``source`` equals it are used.

    Returns:
        A generator function ``mapper(row)`` that, for an ``EventLog`` row
        passing the filters with a well-formed trace, yields one dict with
        ``source``, ``name``, ``parent`` and ``duration`` per event key.
    """
    begin_unix = to_timestamp(begin_time)
    end_unix = to_timestamp(end_time)

    def mapper(row):
        if row["event"] != "EventLog":
            return
        if not begin_unix <= row["unixtime"] < end_unix:
            return
        if source and row["source"] != source:
            return
        result = _aggregate_events(row["source"], row["data"])
        if result is None:
            return
        events, req_duration = result
        if req_duration < from_duration:
            return
        for (src, name, parent), duration in events.items():
            yield dict(
                source=src,
                name=name,
                parent=parent,
                duration=duration,
            )

    return mapper


def reducer(key, rows):
    """Average the per-request durations of one (source, name, parent) group.

    Args:
        key: reduce key; unused because the key columns are carried by rows.
        rows: mapper output rows of the group, each with a ``duration``.

    Yields:
        The last row of the group with ``duration`` replaced by the mean
        duration over the group's rows; nothing for an empty group.
    """
    last_row = None
    duration_sum = 0
    row_count = 0
    for row_count, row in enumerate(rows, 1):
        last_row = row
        duration_sum += row["duration"]
    if last_row is None:
        # Empty group: there is no row to carry the key columns.
        return
    last_row["duration"] = duration_sum / row_count
    yield last_row
