#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import os
import argparse
from nile.api.v1 import (
    clusters,
    aggregators as na,
    with_hints,
    extended_schema,
    Record,
)
import nile.files as nf
import itertools
from collections import Counter, defaultdict
try:
    from yql.api.v1.client import YqlClient
except ImportError:
    pass
import codecs
from videolog_common import (
    zen_ref_from_treatment,
    ref_from_treatment,
    wrap_ref_from,
    apply_replacements,
    date_range,
    yt_get_date_from_table as get_date,
    get_dates_from_stat,
    get_stat_headers,
    get_cluster,
    get_driver,
    StatPusher,
)


def wrap_channel(channel, path):
    """Expand a channel name into itself plus the roll-up buckets it feeds.

    A placeholder channel ("-", "" or None) is returned alone. Otherwise the
    first matching name prefix decides the roll-ups; if no prefix matches,
    a "2" item in the comma-separated ``path`` marks the channel as a TV
    channel. Every non-placeholder channel also lands in "_total_".

    :param channel: channel name, or a placeholder value.
    :param path: comma-separated string (or None); only consulted when no
        name prefix matches.
    :return: tuple starting with ``channel`` followed by its roll-up labels.
    """
    if channel in {"-", "", None}:
        return (channel,)

    # Order matters: the more specific prefix has to be tried first.
    prefix_rollups = (
        (
            "Яндекс.Новогодний",
            ("_yandex_", "Яндекс.Новогодний (все)", "_total_"),
        ),
        ("Яндекс.", ("_yandex_", "_total_")),
        ("Спецпроекты.", ("_special_", "_total_")),
        ("Youtube.", ("_youtube_", "_total_")),
    )
    for prefix, rollups in prefix_rollups:
        if channel.startswith(prefix):
            return (channel,) + rollups

    path_items = (path or "").split(",")
    if "2" in path_items:
        return (channel, "_tv_channels_", "_total_")
    return (channel, "_total_")


def get_unhappy_end(ev_pair, ad_is_bad=False, cp_is_good=False):
    """Return the share of events in ``ev_pair`` that are not acceptable.

    "heartbeat", "start", "PlayerFrameUnload" and None are always
    acceptable. "rtb-dsp" is acceptable unless ``ad_is_bad`` is set, and
    "create_player" is acceptable only when ``cp_is_good`` is set.

    :param ev_pair: sized iterable of event names (None allowed).
    :param ad_is_bad: treat "rtb-dsp" as a bad event.
    :param cp_is_good: treat "create_player" as an acceptable event.
    :return: fraction of bad events, between 0 and 1.
    :raises ZeroDivisionError: if ``ev_pair`` is empty.
    """
    acceptable = {"heartbeat", "start", None, "PlayerFrameUnload"}
    if cp_is_good:
        acceptable.add("create_player")
    if not ad_is_bad:
        acceptable.add("rtb-dsp")

    bad_count = 0
    for event in ev_pair:
        if event not in acceptable:
            bad_count += 1
    return bad_count / len(ev_pair)


@with_hints(
    output_schema=extended_schema(
        unhappy_end=float,
        unhappy_end_cp=float,
        unhappy_end_strict=float,
        unhappy_end_ad=float,
        unhappy_end_ad_strict=float,
    )
)
def totalize(recs):
    """Fan each record out into "_total_" roll-ups and attach the metrics.

    For every input record, each dimension is taken either as the record's
    own value or as "_total_" (``ref_from`` is expanded by
    ``wrap_ref_from`` instead). A combination is emitted only when at most
    two dimensions other than ``has_heartbeat`` keep a non-total value.
    Each emitted record is the input record with the dimension values
    overridden and the unhappy-end metrics added.

    :param recs: iterable of records exposing the dimension fields plus
        ``penultimate_event`` and ``last_event``.
    :return: generator of ``Record`` objects.
    """
    # The "connection" dimension is currently disabled.
    dimensions = (
        "browser",
        "os_family",
        "provider",
        "country",
        "ref_from",
        "player_version",
        "has_heartbeat",
    )
    # has_heartbeat never counts against the two-dimension limit.
    limited_dimensions = tuple(
        name for name in dimensions if name != "has_heartbeat"
    )

    for rec in recs:
        # The metrics depend only on the record, not on the roll-up
        # combination, so compute them once per record.
        ev_pair = (rec.penultimate_event, rec.last_event)
        metrics = {
            "unhappy_end": get_unhappy_end(ev_pair),
            "unhappy_end_ad": get_unhappy_end(ev_pair, ad_is_bad=True),
            "unhappy_end_cp": get_unhappy_end(ev_pair, cp_is_good=True),
        }
        for m in ("unhappy_end", "unhappy_end_ad"):
            # Strict flag: 1 if at least one of the events is bad.
            # NOTE(review): yielded as int although the schema above
            # declares float - kept as in the original; confirm.
            metrics["{}_strict".format(m)] = int(metrics[m] > 0)

        for comb in itertools.product(
            (rec.browser, "_total_"),
            (rec.os_family, "_total_"),
            (rec.provider, "_total_"),
            (rec.country, "_total_"),
            wrap_ref_from(rec.ref_from),
            (rec.player_version, "_total_"),
            (rec.has_heartbeat, "_total_"),
        ):
            dct = dict(zip(dimensions, comb))
            detailed = sum(
                1 for name in limited_dimensions if dct[name] != "_total_"
            )
            if detailed > 2:
                continue
            dct.update(metrics)
            yield Record(rec, **dct)


def counter_mean(counter):
    """Return the weighted mean of a value -> weight mapping.

    :param counter: mapping (e.g. ``collections.Counter``) from a numeric
        value to the number of times it was observed.
    :return: sum(value * weight) divided by sum(weight).
    :raises ZeroDivisionError: if the weights sum to zero (for example,
        when the mapping is empty).
    """
    weighted_total = 0
    observations = 0
    for value, weight in counter.items():
        weighted_total += value * weight
        observations += weight
    return weighted_total / observations


@with_hints(output_schema=extended_schema())
def unhappy_end_reduce(groups):
    """Collapse each group into one record with count-weighted metric means.

    For every group the records' ``count`` values are used as weights. The
    output record carries the group key, the total ``count`` and the
    weighted mean of each metric.

    NOTE(review): the aggregate step that feeds this reducer in
    ``process_dates`` groups only by the three base metrics, so the
    ``*_strict`` fields are presumably absent from the incoming records.
    Falling back to 0 for them made both strict means constant zero. When a
    strict field is missing it is therefore derived from its base metric
    (1 if the base value is positive, else 0), which is how ``totalize``
    defines it. A strict field that is present is used as is.

    :param groups: iterable of ``(key, records)`` pairs.
    :return: generator of ``Record`` objects, one per group.
    """
    base_metrics = ("unhappy_end", "unhappy_end_ad", "unhappy_end_cp")
    # strict metric name -> base metric it is derived from
    strict_metrics = (
        ("unhappy_end_strict", "unhappy_end"),
        ("unhappy_end_ad_strict", "unhappy_end_ad"),
    )
    for key, recs in groups:
        cntr_dict = defaultdict(Counter)
        for rec in recs:
            weight = rec.count
            for name in base_metrics:
                cntr_dict[name][getattr(rec, name, 0)] += weight
            for name, base in strict_metrics:
                value = getattr(rec, name, None)
                if value is None:
                    value = int(getattr(rec, base, 0) > 0)
                cntr_dict[name][value] += weight
        add_result = {"count": sum(cntr_dict["unhappy_end"].values())}
        for name in cntr_dict:
            add_result[name] = counter_mean(cntr_dict[name])
        yield Record(key, **add_result)


def process_dates(
    dates,
    cluster,
    yql_client,
    report,
    pool,
    replace_mask=None,
    debug=False,
    root=None,
    query_file=None,
    skip_push=False,
    skip_report=False,
):
    """Run the YQL pre-aggregation and the report job for ``dates``.

    Steps:

    1. Read ``query_file``, substitute the placeholders (date bounds, pool,
       output table paths, ref_from treatment) and run it through
       ``yql_client``, waiting for completion.
    2. Unless ``skip_report`` is set, map the pre-aggregated table through
       ``totalize``, sum the counts per dimension/metric combination and
       reduce with ``unhappy_end_reduce`` into the report table.
    3. Unless ``skip_push`` is set, push the report table with
       ``StatPusher``.
    4. Unless ``debug`` is set, remove the intermediate tables
       (best effort).

    :param dates: non-empty sequence of dates; first/last name the output
        directory, min/max bound the query.
    :param cluster: cluster object used to build the job and the driver.
    :param yql_client: client exposing ``query(...)``.
    :param report: report name passed to ``StatPusher``.
    :param pool: value substituted for ``@pool`` in the query.
    :param replace_mask: passed through to ``StatPusher``.
    :param debug: keep the intermediate tables when true.
    :param root: base path under which the per-date directory is created.
    :param query_file: path to the UTF-8 query template.
    :param skip_push: do not push the report.
    :param skip_report: do not build the report table.
    """
    if len(dates) == 1:
        date_s = str(dates[0])
    else:
        date_s = "{}_{}".format(dates[0], dates[-1])
    date_from = str(min(dates))
    date_to = str(max(dates))
    root_ = "{}/{}".format(root, date_s)
    tmp_table = "{}/preaggr".format(root_)
    by_vsid_table = "{}/by_vsid".format(root_)
    report_table = "{}/report".format(root_)

    with codecs.open(query_file, "r", "utf8") as f:
        query = f.read()
    base_replacements = [
        ("@date_from", date_from),
        ("@date_to", date_to),
        ("@pool", pool),
        ("@output_table", tmp_table),
        ("@by_vsid_table", by_vsid_table),
        ("--@ref_from", zen_ref_from_treatment),
    ]
    query = apply_replacements(query, base_replacements)

    req = yql_client.query(
        query, title="UnhappyEnd | YQL", syntax_version=1
    )
    req.run()
    req.wait_progress()

    job = cluster.job()

    if not skip_report:
        dimensions = (
            "fielddate",
            "browser",
            "os_family",
            "provider",
            "country",
            "ref_from",
            "player_version",
            "has_heartbeat",
        )
        base_metrics = ("unhappy_end", "unhappy_end_ad", "unhappy_end_cp")
        job.table(tmp_table).map(
            totalize, intensity="ultra_cpu", files=[
                nf.LocalFile("videolog_common.py")
            ]
        ).groupby(
            # First collapse identical dimension/metric rows into counts ...
            *(dimensions + base_metrics)
        ).aggregate(count=na.sum("count")).groupby(
            # ... then average the metrics per dimension combination.
            *dimensions
        ).reduce(
            unhappy_end_reduce
        ).put(
            report_table
        )

    job.run()

    if not skip_push:
        stat_pusher = StatPusher(
            cluster,
            report=report,
            replace_mask=replace_mask,
            remote_publish=True,
        )
        stat_pusher.push(report_table)

    if not debug:
        # Best-effort cleanup: each table is removed independently so one
        # failure neither skips the other table nor gets misreported.
        # ``Exception`` (not a bare except) lets Ctrl-C / SystemExit through.
        for table in (tmp_table, by_vsid_table):
            try:
                get_driver(cluster).remove(table)
            except Exception as exc:
                print("unable to remove {}: {!r}".format(table, exc))
            else:
                print("removed {}".format(table))


def main():
    """Command-line entry point: pick the dates to process and run them.

    When both ``--from`` and ``--to`` are given, that date range is used.
    Otherwise the dates are the ones found under ``//cubes/video-strm``
    (paths ending in "/preprocessed") that are later than the last date
    returned by ``get_dates_from_stat`` for the report. Nothing is
    processed when the resulting list is empty.
    """
    option_specs = (
        ("--report", {"default": "Video/Others/Strm/Stability/UnhappyEndHB2"}),
        ("--from", {"default": None}),
        ("--pool", {"default": None}),
        (
            "--root",
            {"default": "//home/videoquality/vh_analytics/unhappyend_test"},
        ),
        ("--query_file", {"default": "unhappyend_stub_2.sql"}),
        ("--debug", {"action": "store_true"}),
        ("--replace_mask", {"default": None}),
        ("--job_root", {"default": "//home/videoquality/vh_analytics/tmp"}),
        ("--to", {"default": None}),
        ("--title", {"default": "UnhappyEnd | YQL"}),
        ("--skip_push", {"action": "store_true"}),
        ("--skip_report", {"action": "store_true"}),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    args.templates = {"tmp_root": args.job_root}

    cluster = get_cluster(clusters, args)
    yql_client = YqlClient(
        db=os.environ["YT_PROXY"], token=os.environ["YQL_TOKEN"]
    )

    # "from" is a keyword, so it cannot be read as a plain attribute.
    range_start = getattr(args, "from")
    range_end = args.to

    if range_start and range_end:
        dates = date_range(range_start, range_end)
    else:
        stat_headers = get_stat_headers()
        published_dates = get_dates_from_stat(
            headers=stat_headers, report=args.report, dimensions=[]
        )
        last_published = published_dates[-1]

        found_paths = get_driver(cluster).client.search(
            root="//cubes/video-strm",
            path_filter=(
                lambda path: get_date(path) and path.endswith("/preprocessed")
            ),
        )
        available_dates = sorted(get_date(path) for path in found_paths)
        dates = [day for day in available_dates if day > last_published]

    print("processing {}".format(dates))
    if not dates:
        return

    process_dates(
        dates,
        cluster,
        yql_client,
        args.report,
        args.pool,
        replace_mask=args.replace_mask,
        debug=args.debug,
        query_file=args.query_file,
        root=args.root,
        skip_push=args.skip_push,
        skip_report=args.skip_report
    )


# Run the pipeline only when the file is executed as a script.
if __name__ == "__main__":
    main()
