#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import os
import argparse
from nile.api.v1 import (
    clusters,
    aggregators as na,
    extractors as ne,
    with_hints,
    modified_schema,
    extended_schema,
    Record,
)
import itertools
from collections import Counter, defaultdict
from yql.api.v1.client import YqlClient
from qb2.api.v1 import typing as qt
import codecs
from pytils import (
    date_range,
    yt_get_date_from_table as get_date,
    get_dates_from_stat,
    get_stat_headers,
    get_cluster,
    get_driver,
    StatPusher,
)

PAGE_IMP_TABLE = "//home/videolog/strm_meta/page_imp"

# Error ids ending with this suffix are counted as fatal events
# ("fatal_<kind>", where <kind> is the id without the suffix).
fatal = "_fatal"

# Columns that process_record adds or retypes: the three heartbeat buckets
# hold Json maps of per-event counters.
process_record_extend = {
    "heartbeats": int,
    "refreshes": int,
    "connection": str,
    "true": qt.Json,
    "false": qt.Json,
    "_total_": qt.Json,
}

# Columns that multiply_events adds when it flattens the Json buckets into one
# row per (event, has_heartbeat_after) pair.
totalize_extend = {
    "event": str,
    "has_heartbeat_after": str,
    "event_count": int,
    "event_share": int,
    "before_20sec": int,
    "before_1min": int,
    "before_10min": int,
    "before_30min": int,
    "after_30min": int,
}

# Ad types a Stalled/AdEnd error can be attributed to.
adend_types = ("midroll", "preroll", "postroll", "notReplaced", "unknown")

# Known Stalled reasons ("all" and "rest" are synthetic aggregate buckets),
# followed by one AdEnd_<ad type> entry per ad type.
stalled_types = [
    "Init",
    "MediaError",
    "Offline",
    "Other",
    "Recover",
    "Seek",
    "SetSource",
    "VideoTrackChange",
    "NoFragLoad",
    "all",
    "rest",
] + ["AdEnd_" + ad_type for ad_type in adend_types]

# Known fatal error kinds ("all" and "rest" are synthetic aggregate buckets).
fatal_types = (
    "bufferAppendError",
    "fragLoadError",
    "fragLoadTimeOut",
    "internalException",
    "manifestLoadTimeOut",
    "other",
    "all",
    "rest",
)

# Every event name tracked per session, in output order.
events = (
    ["Stalled_" + stalled_type for stalled_type in stalled_types]
    + ["fatal_" + fatal_type for fatal_type in fatal_types]
    + ["all_all"]
)


def apply_replacements(s, pairs):
    """Return ``s`` after applying ``str.replace`` for each pair, in order.

    Each element of ``pairs`` is indexable; its first item is the substring
    to find and its second item is the replacement. Later pairs see the
    output of earlier ones.
    """
    result = s
    for pair in pairs:
        old, new = pair[0], pair[1]
        result = result.replace(old, new)
    return result


def wrap_ref_from(ref_from):
    """Return the ref_from slices a row should be counted under.

    Values starting with "1" are collapsed into "other". "zen*" sources
    roll up only into "_total_". Every other source also rolls up into
    "_total_without_zen_".
    """
    label = "other" if ref_from.startswith("1") else ref_from
    if label.startswith("zen"):
        return label, "_total_"
    return label, "_total_without_zen_", "_total_"


def wrap_channel(channel, path=""):
    """Return the channel slices a row should be counted under.

    Placeholder channels ("-", "" or None) are returned alone and never roll
    up. Channels with a known name prefix also roll up into their group
    label. Otherwise a "2" among the comma-separated ``path`` ids marks a TV
    channel. Every non-placeholder channel ends with "_total_".
    """
    if channel is None or channel in ("-", ""):
        return (channel,)
    # Checked in order: the more specific Yandex prefix must come first.
    prefix_groups = (
        ("Яндекс.Новогодний", ("_yandex_", "Яндекс.Новогодний (все)")),
        ("Яндекс.", ("_yandex_",)),
        ("Спецпроекты.", ("_special_",)),
        ("Youtube.", ("_youtube_",)),
    )
    for prefix, groups in prefix_groups:
        if channel.startswith(prefix):
            return (channel,) + groups + ("_total_",)
    path_ids = (path or "").split(",")
    if "2" in path_ids:
        return (channel, "_tv_channels_", "_total_")
    return (channel, "_total_")


@with_hints(
    output_schema=modified_schema(
        exclude=["true", "false", "_total_"], extend=totalize_extend
    )
)
def multiply_events(recs):
    """Flatten per-session event counters into one row per event.

    Every input record carries three Json columns ("true", "false",
    "_total_") that map event names to counter dicts. For each name in
    ``events`` and each of those columns, one record is emitted. It has
    ``event`` and ``has_heartbeat_after`` set and that counter dict merged in
    as top-level columns. The three Json columns are dropped from the output.
    """
    heartbeat_keys = ("true", "false", "_total_")
    for rec in recs:
        row = rec.to_dict()
        for column in heartbeat_keys:
            del row[column]
        # Events vary slowest, heartbeat keys fastest.
        for event, hb_key in itertools.product(events, heartbeat_keys):
            row["event"] = event
            row["has_heartbeat_after"] = hb_key
            row.update(rec[hb_key][event])
            # The same dict is reused; ``**row`` hands Record a copy of the
            # current values.
            yield Record(**row)


@with_hints(output_schema=extended_schema())
def totalize(recs):
    """Fan every aggregated row out into "_total_" roll-up rows.

    Each of the nine slice dimensions either keeps its own value or is
    replaced by "_total_". The channel dimension instead takes every value
    returned by ``wrap_channel``. Combinations in which more than two
    dimensions are not "_total_" are skipped to bound the fan-out. The key
    and metric columns listed below are copied unchanged onto every emitted
    row.
    """
    # Order matters: it fixes the iteration order of the product below.
    dimension_names = (
        "browser",
        "os_family",
        "provider",
        "country",
        "player_version",
        "channel",
        "connection",
        "with_view",
        "device_type",
    )
    carried_columns = (
        "fielddate",
        "has_heartbeat_after",
        "event",
        "tvt",
        "sessions",
        "heartbeats",
        "refreshes",
        "event_count",
        "sessions_with_event",
        "before_20sec",
        "before_1min",
        "before_10min",
        "before_30min",
        "after_30min",
    )
    for rec in recs:
        options = [
            wrap_channel(rec[name]) if name == "channel" else (rec[name], "_total_")
            for name in dimension_names
        ]
        carried = {column: rec[column] for column in carried_columns}
        for values in itertools.product(*options):
            detailed = sum(1 for value in values if value != "_total_")
            if detailed > 2:
                continue
            row = dict(zip(dimension_names, values))
            row.update(carried)
            yield Record(**row)


def _new_event_stats():
    """Return zeroed per-event counters for a single session.

    ``first_appearance`` is scratch state: the smallest ``rel_time`` at which
    the event was seen, or None while it is unseen. ``_finalize_event_stats``
    removes it before the record is emitted.
    """
    return {
        "event_count": 0,
        "event_share": 0,
        "first_appearance": None,
        "before_20sec": 0,
        "before_1min": 0,
        "before_10min": 0,
        "before_30min": 0,
        "after_30min": 0,
    }


def _register_event(stats, rel_time):
    """Count one occurrence of an event observed at ``rel_time``."""
    stats["event_count"] += 1
    first = stats["first_appearance"]
    if first is None or rel_time < first:
        stats["first_appearance"] = rel_time


# (limit, column) pairs: the column is set when the event first appeared
# strictly before ``limit`` (same units as rel_time; the column names suggest
# seconds).
_FIRST_APPEARANCE_BUCKETS = (
    (20, "before_20sec"),
    (60, "before_1min"),
    (600, "before_10min"),
    (1800, "before_30min"),
)


def _finalize_event_stats(stats):
    """Replace ``first_appearance`` with cumulative time-bucket flags.

    The buckets are cumulative: an event first seen at 15 sets every
    ``before_*`` flag. ``after_30min`` is set whenever the event happened at
    all, so it acts as the "any time" bucket. Events that never happened
    keep every flag at 0.
    """
    first = stats.pop("first_appearance")
    stats["event_share"] = int(stats["event_count"] > 0)
    if first is None:
        return
    stats["after_30min"] = 1
    for limit, column in _FIRST_APPEARANCE_BUCKETS:
        if first < limit:
            stats[column] = 1


def _stalled_event_id(details):
    """Map the ``details`` dict of a Stalled error to its event name."""
    reason = details.get("reason") or "Other"
    if reason == "AdEnd":
        # stalled_types only lists the suffixed AdEnd_<adType> variants, so the
        # bare reason must be resolved here instead of falling into "rest".
        ad_type = (details.get("details") or {}).get("adType")
        if ad_type not in adend_types:
            ad_type = "unknown"
        return "Stalled_AdEnd_{}".format(ad_type)
    if reason == "all" or reason not in stalled_types:
        # "all" is the synthetic aggregate bucket; a raw reason with that name
        # would otherwise be counted twice.
        reason = "rest"
    return "Stalled_{}".format(reason)


def _fatal_event_id(error_id):
    """Map a ``<kind>_fatal`` error id to its ``fatal_<kind>`` event name."""
    kind = error_id[: -len(fatal)]
    if kind == "all" or kind not in fatal_types:
        kind = "rest"
    return "fatal_{}".format(kind)


def _min_heartbeats(heartbeats):
    """Return the smallest per-series heartbeat total, or 0 if there is none.

    Assumes ``heartbeats`` is a mapping whose values are mappings of numbers
    (or empty/None). TODO confirm against the sessions schema.
    """
    totals = [sum(series.values()) for series in (heartbeats or {}).values()]
    return min(totals) if totals else 0


@with_hints(output_schema=extended_schema(**process_record_extend))
def process_record(recs):
    """Attach per-event error statistics to every (merged) session.

    Each session gets three Json columns, "true", "false" and "_total_". Each
    maps every name in ``events`` to counters built by ``_new_event_stats``
    and finalized by ``_finalize_event_stats``. An error is counted in
    "true" when its ``rel_time`` is below the session ``view_time``,
    otherwise in "false"; "_total_" receives both.

    Stalled errors count as ``Stalled_<reason>``, ``Stalled_all`` and
    ``all_all``. ``*_fatal`` errors at or after ``view_time`` count as
    ``fatal_<kind>``, ``fatal_all`` and ``all_all``. All other errors are
    ignored.

    ``heartbeats`` is collapsed to an int and ``refreshes`` to a 0/1 flag.
    ``connection`` becomes the most common known connection reported by
    Stalled errors, or "UNKNOWN" if there is none.
    """
    buckets = ("true", "false", "_total_")
    for rec in recs:
        rec = rec.to_dict()
        view_time = rec.get("view_time") or 0
        stats = {
            bucket: {event: _new_event_stats() for event in events}
            for bucket in buckets
        }
        rec["heartbeats"] = _min_heartbeats(rec.get("heartbeats"))
        rec["refreshes"] = int((rec.get("refreshes") or 0) > 0)
        connections = Counter()

        for err in rec.get("errors") or ():
            if err.get("id_raw") == "Stalled":
                details = err.get("details") or {}
                connection = details.get("connection")
                if connection is not None:
                    connections[connection] += 1
                event_ids = (_stalled_event_id(details), "Stalled_all", "all_all")
            elif (err.get("id") or "").endswith(fatal) and err["rel_time"] >= view_time:
                # Only fatals at or after the end of viewing are counted, so
                # they always land in the "false" bucket.
                event_ids = (_fatal_event_id(err["id"]), "fatal_all", "all_all")
            else:
                continue
            rel_time = err["rel_time"]
            bucket = str(rel_time < view_time).lower()
            for target in (bucket, "_total_"):
                for event_id in event_ids:
                    _register_event(stats[target][event_id], rel_time)

        for per_event in stats.values():
            for event_stats in per_event.values():
                _finalize_event_stats(event_stats)
        rec.update(stats)

        connections.pop("UNKNOWN", None)
        if connections:
            rec["connection"] = connections.most_common(1)[0][0]
        else:
            rec["connection"] = "UNKNOWN"
        yield Record(**rec)


@with_hints(output_schema=extended_schema(refreshes=int))
def dash_reducer(groups):
    """Glue page refreshes into the session they interrupted.

    Expects the records of each group to be sorted by ``timestamp``; the
    caller groups by fielddate, yu_hash, ip and user_agent. A record is
    merged into the previous one when both conditions hold:

    - it has the same ``video_content_id``;
    - it starts less than ``60 * 5`` after the latest session (merged or
      not) ended, where end = ``timestamp + view_time``.

    Units are presumably seconds; confirm against the sessions cube.

    Merging appends the later record's errors, with ``rel_time`` shifted by
    the view time accumulated so far, and adds its ``view_time``. Input
    records are not mutated. Each emitted record carries ``refreshes``, the
    number of records glued onto it.
    """
    max_gap = 60 * 5
    for _, recs in groups:
        current = None
        current_end = 0
        refreshes = 0
        for rec in recs:
            view_time = rec["view_time"] or 0
            is_refresh = (
                current is not None
                and rec["video_content_id"] == current["video_content_id"]
                and rec["timestamp"] - current_end < max_gap
            )
            if is_refresh:
                refreshes += 1
                offset = current["view_time"] or 0
                shifted = [
                    dict(err, rel_time=err["rel_time"] + offset)
                    for err in rec["errors"] or ()
                ]
                current["errors"] = list(current["errors"] or ()) + shifted
                current["view_time"] = offset + view_time
            else:
                if current is not None:
                    current["refreshes"] = refreshes
                    yield Record(**current)
                current = rec.to_dict()
                refreshes = 0
            # Measure the next gap from the end of the latest session, merged
            # or not, so chains of quick refreshes keep being glued.
            current_end = rec["timestamp"] + view_time
        if current is not None:
            current["refreshes"] = refreshes
            yield Record(**current)


def _push_report(cluster, report, replace_mask, report_table):
    """Upload ``report_table`` to the Stat report ``report``."""
    stat_pusher = StatPusher(
        cluster,
        report=report,
        replace_mask=replace_mask,
        remote_publish=True,
    )
    stat_pusher.push(report_table)


def _remove_tables(cluster, tables):
    """Best-effort removal of intermediate tables; failures are only printed."""
    for table in tables:
        try:
            get_driver(cluster).remove(table)
            print("removed {}".format(table))
        except Exception as e:
            print("unable to remove {}: {}".format(table, e))


def process_dates(
    dates,
    cluster,
    yql_client,
    report,
    pool,
    replace_mask=None,
    debug=False,
    dev_take=False,
    only_push=False
):
    """Build the VH quality dashboard for ``dates`` and publish it to Stat.

    Results live under
    ``//home/videoquality/vh_analytics/vh_quality_dash/<range>``. Three nile
    jobs are run:

    1. Read ``//cubes/video-strm/<date>/sessions`` for every date from
       min(dates) to max(dates). Glue refreshes (``dash_reducer``), compute
       per-session event counters (``process_record``) and write ``tmp``.
    2. Flatten to one row per event (``multiply_events``), aggregate, and
       write ``tmp2``.
    3. Add ``_total_`` roll-ups (``totalize``), re-aggregate, and write
       ``report``.

    Args:
        dates: non-empty list of dates to process.
        cluster: nile cluster the jobs run on.
        yql_client: currently unused.
        report: Stat report path to push to.
        pool: currently unused.
        replace_mask: forwarded to StatPusher.
        debug: skip pushing to Stat and keep the intermediate tables.
        dev_take: process only 1000 sessions and skip pushing to Stat.
        only_push: skip computation and push an existing ``report`` table.
    """
    date_from = str(min(dates))
    date_to = str(max(dates))
    # Derive the directory name from the same bounds as the processed range,
    # so a later only_push run over that range finds the same report table.
    if date_from == date_to:
        date_s = date_from
    else:
        date_s = "{}_{}".format(date_from, date_to)
    root = "//home/videoquality/vh_analytics/vh_quality_dash/{}".format(date_s)
    tmp_table = "{}/tmp".format(root)
    tmp_table_2 = "{}/tmp2".format(root)
    report_table = "{}/report".format(root)

    if only_push:
        _push_report(cluster, report, replace_mask, report_table)
        return

    job = cluster.job()

    tables = [
        job.table("//cubes/video-strm/{}/sessions".format(date)).project(
            "yu_hash",
            "ip",
            "user_agent",
            "os_family",
            "view_time",
            "errors",
            "provider",
            "device_type",
            "player_version",
            "country",
            "category_id",
            "timestamp",
            "vsid",
            "video_content_id",
            "heartbeats",
            browser="browser_name",
            channel=ne.custom(
                lambda x: x if x else "NO_CHANNEL", "channel"
            ).with_type(str),
            with_view=ne.custom(
                lambda x: str((x or 0) > 0).lower(), "view_time"
            ).with_type(str),
            fielddate=ne.const(str(date)).with_type(str),
        )
        for date in date_range(date_from, date_to)
    ]

    aggregators = dict(
        intensity="data",
        sessions=na.count(),
        heartbeats=na.sum("heartbeats"),
        refreshes=na.sum("refreshes"),
        tvt=na.sum("view_time"),
        event_count=na.sum("event_count"),
        sessions_with_event=na.sum("event_share"),
    )
    for column in (
        "before_20sec",
        "before_1min",
        "before_10min",
        "before_30min",
        "after_30min",
    ):
        aggregators[column] = na.sum(column)

    vh_categories = (
        job.table(PAGE_IMP_TABLE)
        .unique("category_id")
        .project(
            "category",
            category_id=ne.custom(str, "category_id").with_type(str),
        )
    )

    stream = job.concat(*tables)

    if dev_take:
        stream = stream.take(1000)

    # NOTE(review): "category" is joined in here but is not in groupby_args
    # below, so the aggregation drops it. Confirm whether the join is still
    # needed.
    (
        stream.join(vh_categories, by="category_id", type="left")
        .project(
            ne.all(exclude=["category"]),
            category=ne.custom(
                lambda x: (x if x else "other"), "category"
            ).with_type(str),
        )
        .groupby("fielddate", "yu_hash", "ip", "user_agent")
        .sort("timestamp")
        .reduce(dash_reducer)
        .map(process_record, intensity="cpu")
        .put(tmp_table)
    )

    job.run()

    groupby_args = [
        "fielddate",
        "browser",
        "channel",
        "os_family",
        "provider",
        "country",
        "player_version",
        "device_type",
        "connection",
        "with_view",
        "has_heartbeat_after",
        "event",
    ]

    job = cluster.job()

    (
        job.table(tmp_table)
        .map(multiply_events)
        .groupby(*groupby_args)
        .aggregate(**aggregators)
        .put(tmp_table_2)
    )

    job.run()

    # The second pass re-aggregates already aggregated rows, so session counts
    # and renamed metrics must now be summed from the columns produced above.
    aggregators["sessions"] = na.sum("sessions")
    aggregators["tvt"] = na.sum("tvt")
    aggregators["sessions_with_event"] = na.sum("sessions_with_event")

    job = cluster.job()

    (
        job.table(tmp_table_2)
        .map(totalize, intensity="cpu")
        .groupby(*groupby_args)
        .aggregate(**aggregators)
        .put(report_table)
    )

    job.run()

    if not debug and not dev_take:
        _push_report(cluster, report, replace_mask, report_table)

    if not debug:
        # Both intermediate tables are scratch data; previously tmp2 leaked.
        _remove_tables(cluster, (tmp_table, tmp_table_2))


def _parse_args():
    """Parse the command line and attach the nile job templates."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--report", default="Video/Others/Strm/Stability/vh_quality_dash"
    )
    parser.add_argument("--from", default=None)
    parser.add_argument("--pool", default=None)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--dev_take", action="store_true")
    parser.add_argument("--only_push", action="store_true")
    parser.add_argument("--replace_mask", default=None)
    parser.add_argument(
        "--job_root", default="//home/videoquality/vh_analytics/tmp"
    )
    parser.add_argument("--to", default=None)
    parser.add_argument("--title", default="VH Quality Dash | YQL")
    parsed = parser.parse_args()
    parsed.templates = {"tmp_root": parsed.job_root}
    return parsed


def _pending_dates(cluster, report):
    """Return the dates that have input data but are newer than the report.

    A date counts as available when a ``.../preprocessed`` table for it exists
    under ``//cubes/video-strm``. It is pending when it is later than the last
    date already present in the Stat report.
    """
    # NOTE(review): availability is checked on "/preprocessed" while
    # process_dates reads "<date>/sessions". Confirm both appear together.
    headers = get_stat_headers()
    published = get_dates_from_stat(
        headers=headers, report=report, dimensions=[]
    )
    last_published = published[-1]
    found = get_driver(cluster).client.search(
        root="//cubes/video-strm",
        path_filter=(
            lambda path: get_date(path) and path.endswith("/preprocessed")
        ),
    )
    available = sorted(get_date(path) for path in found)
    return [day for day in available if day > last_published]


def main():
    """Entry point: pick the dates to process and run the dashboard build.

    With both --from and --to the explicit range is used. Otherwise every
    available date newer than the report's last published date is processed.
    """
    args = _parse_args()

    cluster = get_cluster(clusters, args)
    yql_client = YqlClient(
        db=os.environ["YT_PROXY"], token=os.environ["YQL_TOKEN"]
    )

    range_start = getattr(args, "from")
    range_end = args.to
    if range_start and range_end:
        dates = date_range(range_start, range_end)
    else:
        dates = _pending_dates(cluster, args.report)

    print("processing {}".format(dates))
    if not dates:
        return
    process_dates(
        dates,
        cluster,
        yql_client,
        args.report,
        args.pool,
        replace_mask=args.replace_mask,
        debug=args.debug,
        dev_take=args.dev_take,
        only_push=args.only_push
    )


if __name__ == "__main__":
    main()
