#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import os
import argparse
import codecs
import itertools
from collections import Counter, defaultdict
from nile.api.v1 import (
    clusters,
    aggregators as na,
    extractors as ne,
    with_hints,
    modified_schema,
    extended_schema,
    Record,
)
from yql.api.v1.client import YqlClient
from qb2.api.v1 import typing as qt
from pytils import (
    date_range,
    yt_get_date_from_table as get_date,
    get_dates_from_stat,
    get_stat_headers,
    get_cluster,
    get_driver,
    StatPusher,
)

# YT path of the table joined in first_step to map category_id -> category.
PAGE_IMP_TABLE = "//home/videolog/strm_meta/page_imp"
# Suffix literal; not referenced by any code visible in this file.
fatal = "_fatal"

# Extra output columns; only referenced from the commented-out schema hint
# above dash_reducer.
process_record_extend = {
    "heartbeats": int,
    "refreshes": int,
    "connection": str,
    "true": qt.Json,
    "false": qt.Json,
    "_total_": qt.Json,
}
# Columns added by multiply_events on top of its input schema.
totalize_extend = {
    "event": str,
    "has_heartbeat_after": str,
    "event_count": int,
    "event_share": int,
    "before_20sec": int,
    "before_1min": int,
    "before_10min": int,
    "before_30min": int,
    "after_30min": int,
}
# Ad kinds used to build the "AdEnd_<kind>" stall reasons.
adend_types = ("midroll", "preroll", "postroll", "notReplaced", "unknown")
# Stall reasons: the plain ones first, then one "AdEnd_<kind>" per ad kind.
stalled_types = [
    "Init",
    "MediaError",
    "Offline",
    "Other",
    "Recover",
    "Seek",
    "SetSource",
    "VideoTrackChange",
    "NoFragLoad",
    "all",
    "rest",
] + ["AdEnd_{}".format(kind) for kind in adend_types]
# Fatal error kinds.
fatal_types = (
    "bufferAppendError",
    "fragLoadError",
    "fragLoadTimeOut",
    "internalException",
    "manifestLoadTimeOut",
    "other",
    "all",
    "rest",
)
# Full list of event names: stalls, then fatals, then the overall bucket.
events = ["Stalled_{}".format(reason) for reason in stalled_types]
events.extend("fatal_{}".format(kind) for kind in fatal_types)
events.append("all_all")


def apply_replacements(s, pairs):
    """Return ``s`` after applying each replacement of ``pairs`` in order.

    Every item of ``pairs`` is indexed as ``item[0]`` (text to find) and
    ``item[1]`` (replacement); later pairs see the result of earlier ones.
    """
    result = s
    for pair in pairs:
        old, new = pair[0], pair[1]
        result = result.replace(old, new)
    return result


def wrap_ref_from(ref_from):
    """Return the tuple of aggregation labels for a ``ref_from`` value.

    A value starting with "1" is first renamed to "other". Values starting
    with "zen" map to themselves plus "_total_"; every other value also gets
    the "_total_without_zen_" label.
    """
    label = "other" if ref_from.startswith("1") else ref_from
    if not label.startswith("zen"):
        return (label, "_total_without_zen_", "_total_")
    return (label, "_total_")


def wrap_channel(channel, path=""):
    """Return the tuple of aggregation labels for a channel name.

    "-", "" and None are returned alone. Otherwise the first matching name
    prefix decides the extra group labels; if no prefix matches, a "2" item
    in the comma-separated ``path`` selects the "_tv_channels_" group. Every
    non-empty channel also gets "_total_".
    """
    if channel in {"-", "", None}:
        return (channel,)

    # Order matters: the longer "Яндекс.Новогодний" prefix must be tried
    # before the generic "Яндекс." one.
    prefix_groups = (
        ("Яндекс.Новогодний", ("_yandex_", "Яндекс.Новогодний (все)")),
        ("Яндекс.", ("_yandex_",)),
        ("Спецпроекты.", ("_special_",)),
        ("Youtube.", ("_youtube_",)),
    )
    for prefix, groups in prefix_groups:
        if channel.startswith(prefix):
            return (channel,) + groups + ("_total_",)

    path_items = (path or "").split(",")
    if "2" in path_items:
        return (channel, "_tv_channels_", "_total_")
    return (channel, "_total_")


@with_hints(
    output_schema=modified_schema(
        exclude=["true", "false", "_total_"], extend=totalize_extend
    )
)
def multiply_events(recs):
    """Yield one record per (event, heartbeat flag) for every input record.

    The nested "true"/"false"/"_total_" columns are removed from the output;
    for each event name and each of those three keys the values found at
    ``rec[key][event]`` are merged into the flat record.
    """
    heartbeat_keys = ("true", "false", "_total_")
    for rec in recs:
        fields = rec.to_dict()
        for nested in heartbeat_keys:
            del fields[nested]
        for event in events:
            fields["event"] = event
            for flag in heartbeat_keys:
                fields["has_heartbeat_after"] = flag
                # The same dict is reused, so values from a previous
                # combination stay unless this update overwrites them.
                fields.update(rec[flag][event])
                yield Record(**fields)


def totals(keys):
    """Return a dict mapping every item of ``keys`` to "_total_"."""
    return dict.fromkeys(keys, "_total_")


@with_hints(output_schema=extended_schema())
def totalize(recs):
    """Yield each record plus its "_total_" roll-ups.

    For every input record this yields, in order: the record itself, a copy
    with all dimension columns set to "_total_", and one copy per dimension
    where every dimension except that one is set to "_total_".
    """
    keys = {
        "browser_name",
        "channel",
        "connection",
        "country",
        "device_type",
        "os_family",
        "player_version",
        "provider",
        "ref_from",
        "with_view",
    }
    # The override dicts do not depend on the record, so build them once:
    # first "everything totalled", then "everything but one key totalled".
    overrides = [totals(keys)]
    overrides.extend(totals(keys - {key}) for key in keys)
    for rec in recs:
        yield rec
        for override in overrides:
            yield Record(rec, **override)


def get_real_id(err):
    """Return the normalized event name for an error dict.

    * ids starting with "Stalled" become "Stalled_<reason>", where the reason
      comes from ``err["details"]["reason"]`` ("Other" when the key is
      missing); the "AdEnd" reason is refined to "Stalled_AdEnd_<adType>"
      ("unknown" when the ad type is missing);
    * ids ending with "_fatal" become "fatal_<id without the suffix>";
    * any other id is returned unchanged.
    """
    raw_id = err["id"]
    if raw_id.startswith("Stalled"):
        try:
            reason = err["details"]["reason"]
        except KeyError:
            reason = "Other"
        if reason == "AdEnd":
            try:
                ad_type = err["details"]["details"]["adType"]
            except KeyError:
                ad_type = "unknown"
            return "Stalled_AdEnd_{}".format(ad_type)
        return "Stalled_{}".format(reason)

    suffix = "_fatal"
    if raw_id.endswith(suffix):
        return "fatal_{}".format(raw_id[: -len(suffix)])
    return raw_id


def get_err_list(real_id):
    """Return every counter name that an event with ``real_id`` contributes to.

    The list always starts with the id itself and "_all_errors_"; stall and
    fatal ids additionally feed their group counters.
    """
    buckets = [real_id, "_all_errors_"]
    if real_id == "Stalled_Other":
        buckets.append("_stalled_other_or_fatal_")
    if real_id.startswith("Stalled"):
        buckets += ["_stalled_", "_stalled_or_fatal_"]
    elif real_id.startswith("fatal"):
        buckets += [
            "_fatal_",
            "_stalled_or_fatal_",
            "_stalled_other_or_fatal_",
        ]
    return buckets


def update_counts(info, rec, err, connections):
    """Fold one error into the per-event stats and the connection counter.

    ``info`` maps a counter name to a dict with "count", "first_appearance"
    and "fatal"; every counter returned by get_err_list for this error is
    updated in place. An error whose ``rel_time`` is not below the record's
    ``view_time`` marks the counter as fatal. When the error carries
    ``details.connection`` that value is counted in ``connections``.
    """
    rel_time = err_rel_time = None  # placeholders, assigned below
    real_id = get_real_id(err)
    for name in get_err_list(real_id):
        err_rel_time = err["rel_time"]
        stats = info[name]
        stats["first_appearance"] = min(
            err_rel_time, stats["first_appearance"]
        )
        stats["count"] += 1
        rel_time = err["rel_time"]
        if rel_time >= rec["view_time"]:
            stats["fatal"] = 1
    try:
        details = err["details"]
        connections[details["connection"]] += 1
    except KeyError:
        # No connection info on this error: nothing to count.
        pass


@with_hints(
    output_schema=modified_schema(
        exclude=[
            "ip",
            "errors",
            "category_id",
            "timestamp",
            "video_content_id",
            "heartbeats",
        ],
        extend=dict(
            event=str,
            connection=str,
            event_count=int,
            first_appearance=int,
            before_20sec=int,
            before_1min=int,
            before_10min=int,
            before_30min=int,
            after_30min=int,
            fatal=int,
            fatal_vsid=str,
            fatal_yu_hash=str,
            with_view=str,
        ),
    )
)
def process_record(recs):
    """Expand each session record into one output record per event counter.

    For every input record the errors are folded into per-counter stats
    (count, first appearance time, fatal flag) by update_counts, a synthetic
    "_total_" counter is added, and one record is yielded per counter. The
    session's connection is the most frequent ``details.connection`` among
    its errors, ignoring "UNKNOWN"; "UNKNOWN" is used when nothing is left.

    The ``before_*`` / ``after_30min`` flags compare the counter's first
    appearance with 20, 60, 600 and 1800.
    """
    for rec in recs:
        # The initial value is only a starting point for min() in
        # update_counts, so it must be larger than any real rel_time. A
        # finite sentinel such as 999 would cap every first appearance at 999
        # and make after_30min (> 1800) impossible to reach.
        info = defaultdict(
            lambda: {
                "count": 0,
                "first_appearance": float("inf"),
                "fatal": 0,
            }
        )
        connections = Counter()
        for err in sorted(rec["errors"], key=lambda x: x["rel_time"]):
            update_counts(info, rec, err, connections)
        # Every session contributes exactly once to the "_total_" counter.
        info["_total_"]["count"] = 1
        info["_total_"]["first_appearance"] = 0

        connections.pop("UNKNOWN", None)
        top = connections.most_common(1)
        connection = top[0][0] if top else "UNKNOWN"

        with_view = str(rec["view_time"] > 0).lower()
        for event, stats in info.items():
            first = stats["first_appearance"]
            is_fatal = stats["fatal"]
            yield Record(
                browser_name=rec["browser_name"],
                category=rec["category"],
                channel=rec["channel"],
                country=rec["country"],
                device_type=rec["device_type"],
                fielddate=rec["fielddate"],
                os_family=rec["os_family"],
                player_version=rec["player_version"],
                provider=rec["provider"],
                ref_from=rec["ref_from"],
                refreshes=rec["refreshes"],
                view_type=rec["view_type"],
                vsid=rec["vsid"],
                yu_hash=rec["yu_hash"],
                view_time=rec["view_time"],
                # new fields
                after_30min=int(first > 1800),
                before_10min=int(first <= 600),
                before_1min=int(first <= 60),
                before_20sec=int(first <= 20),
                before_30min=int(first <= 1800),
                connection=connection,
                event=event,
                event_count=stats["count"],
                fatal=is_fatal,
                # Ids are only filled for fatal counters so that distinct
                # counts over these columns cover fatal sessions only.
                fatal_vsid=rec["vsid"] if is_fatal else "",
                fatal_yu_hash=rec["yu_hash"] if is_fatal else "",
                first_appearance=first,
                with_view=with_view,
            )


# @with_hints(output_schema=modified_schema(
#     exclude=['heartbeats'],
#     extend=process_record_extend,
# ))
@with_hints(output_schema=extended_schema(refreshes=int))
def dash_reducer(groups):
    """Merge page refreshes of the same video into one session record.

    Within each group a record is folded into the current session when it
    has the same ``video_content_id`` and starts less than 5 minutes after
    the stored end time of the session. On a merge its errors are shifted by
    the view time accumulated so far and appended, its view time is added,
    and the refresh counter grows. Each resulting session is yielded with a
    ``refreshes`` column.
    """
    for _, recs in groups:
        session = {}
        session_end = 0
        refresh_count = 0
        for rec in recs:
            is_refresh = (
                session
                and rec["video_content_id"] == session["video_content_id"]
                and rec["timestamp"] - session_end < 60 * 5
            )
            if is_refresh:
                refresh_count += 1
                # Make error times relative to the start of the merged
                # session before the view time itself is extended.
                for err in rec["errors"]:
                    err["rel_time"] += session["view_time"]
                session["errors"] += rec["errors"]
                session["view_time"] += rec["view_time"]
                # NOTE(review): session_end is not advanced here, so later
                # records are still compared with the end of the FIRST record
                # of the session - confirm this is intended.
                continue
            if session:
                # A new session starts: flush the previous one.
                session["refreshes"] = refresh_count
                yield Record(**session)
                refresh_count = 0
            session = rec.to_dict()
            session_end = rec["timestamp"] + rec["view_time"]
        # Flush the last session of the group.
        session["refreshes"] = refresh_count
        yield Record(**session)


def _add_shares_yield(rec, total):
    """Return ``rec`` extended with its shares of the group's total record.

    ``total`` is the "_total_" record of the same group. Shares are the
    record's ``sessions`` / ``vsids`` / ``yu_hashes`` divided by the matching
    totals (true division, see the ``__future__`` import); the totals
    themselves are copied into the ``total_*`` columns.
    """
    return Record(
        rec,
        session_share=rec["sessions"] / total["sessions"],
        vsid_share=rec["vsids"] / total["vsids"],
        yu_hash_share=rec["yu_hashes"] / total["yu_hashes"],
        total_vsids=total["vsids"],
        total_yu_hashes=total["yu_hashes"],
        # Must come from "sessions" (not "yu_hashes") to agree with the
        # value add_shares writes for the "_total_" record itself.
        total_sessions=total["sessions"],
    )


@with_hints(output_schema=extended_schema(
    session_share=float,
    vsid_share=float,
    yu_hash_share=float,
    total_vsids=int,
    total_sessions=int,
    total_yu_hashes=int
))
def add_shares(groups):
    """Add share and total columns to every record of each group.

    The record whose ``event`` is "_total_" is the reference for its group
    and gets shares of 1.0. Records seen before the reference are held back
    and emitted at the end of the group; records seen after it are emitted
    immediately.
    """
    for _, recs in groups:
        pending = []
        total = None
        for rec in recs:
            if rec["event"] != "_total_":
                if total:
                    yield _add_shares_yield(rec, total)
                else:
                    # Reference record not seen yet: emit later.
                    pending.append(rec)
                continue
            total = rec
            yield Record(
                rec,
                session_share=1.0,
                vsid_share=1.0,
                yu_hash_share=1.0,
                total_vsids=rec["vsids"],
                total_sessions=rec["sessions"],
                total_yu_hashes=rec["yu_hashes"],
            )
        for rec in pending:
            yield _add_shares_yield(rec, total)


def first_step(cluster, date_from, date_to, tmp_table, dev_take):
    """Build per-session event records for a date range into ``tmp_table``.

    Reads the daily ``sessions`` tables for every date returned by
    ``date_range(date_from, date_to)``, attaches the category name from
    PAGE_IMP_TABLE, merges refreshes with dash_reducer and expands events
    with process_record. When ``dev_take`` is truthy only 1000 input records
    are taken.
    """
    job = cluster.job()

    # Columns copied from the sessions tables without change.
    session_columns = (
        "yu_hash",
        "ip",
        "user_agent",
        "os_family",
        "view_time",
        "ref_from",
        "view_type",
        "errors",
        "provider",
        "device_type",
        "player_version",
        "country",
        "category_id",
        "timestamp",
        "vsid",
        "video_content_id",
        "heartbeats",
        "browser_name",
    )

    tables = []
    for date in date_range(date_from, date_to):
        sessions = job.table("//cubes/video-strm/{}/sessions".format(date))
        # An empty channel is replaced with a placeholder; the table's date
        # becomes the fielddate column.
        channel = ne.custom(lambda x: x or "NO_CHANNEL", "channel")
        fielddate = ne.const(str(date))
        tables.append(
            sessions.project(
                *session_columns,
                channel=channel.with_type(str),
                fielddate=fielddate.with_type(str)
            )
        )

    categories = job.table(PAGE_IMP_TABLE).unique("category_id")
    vh_categories = categories.project(
        "category",
        category_id=ne.custom(str, "category_id").with_type(str),
    )

    stream = job.concat(*tables)
    if dev_take:
        stream = stream.take(1000)

    joined = stream.join(vh_categories, by="category_id", type="left")
    # Sessions without a matching category get the "other" category.
    with_category = joined.project(
        ne.all(exclude=["category"]),
        category=ne.custom(lambda x: x or "other", "category").with_type(str),
    )
    grouped = with_category.groupby("fielddate", "yu_hash", "ip", "user_agent")
    grouped.sort("timestamp").reduce(dash_reducer).map(process_record).put(
        tmp_table
    )

    job.run()


def process_dates(
    dates,
    cluster,
    yql_client,
    report,
    pool,
    replace_mask=None,
    debug=False,
    dev_take=False,
    only_push=False,
    skip_first=False
):
    """Compute the report table for ``dates`` and push it to the report.

    Steps: build the per-session table (first_step, unless ``skip_first``),
    aggregate it with "_total_" roll-ups, add shares, push the result with
    StatPusher and remove the temporary tables.

    * ``only_push`` - just push the already existing report table.
    * ``debug`` - neither push nor remove the temporary tables.
    * ``dev_take`` - forwarded to first_step; the result is not pushed.

    ``yql_client`` and ``pool`` are accepted but not used here.
    """
    date_s = (
        str(dates[0])
        if len(dates) == 1
        else "{}_{}".format(dates[0], dates[-1])
    )
    date_from = str(min(dates))
    date_to = str(max(dates))
    root = "//home/videoquality/vh_analytics/vh_quality_dash_lite/{}".format(
        date_s
    )
    tmp_table = "{}/tmp".format(root)
    tmp_table_2 = "{}/tmp2".format(root)
    report_table = "{}/report".format(root)

    def push_report():
        # Publish the report table to the configured report.
        pusher = StatPusher(
            cluster,
            report=report,
            replace_mask=replace_mask,
            remote_publish=True,
        )
        pusher.push(report_table)

    if only_push:
        push_report()
        return

    if not skip_first:
        first_step(cluster, date_from, date_to, tmp_table, dev_take)

    dimensions = [
        "browser_name",
        "channel",
        "connection",
        "country",
        "device_type",
        "fielddate",
        "os_family",
        "player_version",
        "provider",
        "ref_from",
        "with_view",
    ]
    # "event" is part of the first grouping only; the second grouping uses
    # the dimensions alone so add_shares sees all events of a slice together.
    group_keys = dimensions + ["event"]

    # NOTE(review): "intensity" is passed to aggregate() together with the
    # aggregators - presumably a job hint rather than a column; confirm.
    aggregators = {
        "intensity": "data",
        "sessions": na.count(),
        "vsids": na.count_distinct_estimate("vsid"),
        "fatal_vsids": na.count_distinct_estimate("fatal_vsid"),
        "yu_hashes": na.count_distinct_estimate("yu_hash"),
        "fatal_yu_hashes": na.count_distinct_estimate("fatal_yu_hash"),
        "refreshes": na.sum("refreshes"),
        "tvt": na.sum("view_time"),
    }
    summed_columns = (
        "refreshes",
        "event_count",
        "before_20sec",
        "before_1min",
        "before_10min",
        "before_30min",
        "after_30min",
    )
    for column in summed_columns:
        aggregators[column] = na.sum(column)

    aggregate_job = cluster.job()
    totalized = aggregate_job.table(tmp_table).map(totalize, intensity="cpu")
    totalized.groupby(*group_keys).aggregate(**aggregators).put(tmp_table_2)
    aggregate_job.run()

    shares_job = cluster.job()
    aggregated = shares_job.table(tmp_table_2)
    aggregated.groupby(*dimensions).reduce(add_shares).put(report_table)
    shares_job.run()

    if debug:
        # Keep everything in place for inspection.
        return

    if not dev_take:
        push_report()

    for table in (tmp_table, tmp_table_2):
        if get_driver(cluster).exists(table):
            get_driver(cluster).remove(table)
            print("removed {}".format(table))


def main():
    """Parse the command line and process every requested date.

    With both ``--from`` and ``--to`` the dates come from ``date_range``.
    Otherwise every date that has a ``sessions`` table under
    //cubes/video-strm and is later than the last date already present in
    the report is processed. Each date is handled by its own process_dates
    call.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--report", default="Video/Others/Strm/Stability/vh_quality_dash_lite")
    add("--from", default=None)
    add("--pool", default=None)
    add("--debug", action="store_true")
    add("--dev_take", action="store_true")
    add("--only_push", action="store_true")
    add("--replace_mask", default=None)
    add("--job_root", default="//home/videoquality/vh_analytics/tmp")
    add("--to", default=None)
    add("--title", default="VH Quality Dash | YQL")
    args = parser.parse_args()
    args.templates = {"tmp_root": args.job_root}

    cluster = get_cluster(clusters, args)
    yql_client = YqlClient(
        db=os.environ["YT_PROXY"], token=os.environ["YQL_TOKEN"]
    )

    # "from" is a keyword, so that attribute can only be read via getattr.
    from_ = getattr(args, "from")
    to_ = args.to

    if from_ and to_:
        dates = date_range(from_, to_)
    else:
        published = get_dates_from_stat(
            headers=get_stat_headers(), report=args.report, dimensions=[]
        )
        last_date_from_stat = published[-1]

        def is_sessions_table(path):
            return get_date(path) and path.endswith("/sessions")

        found = get_driver(cluster).client.search(
            root="//cubes/video-strm", path_filter=is_sessions_table
        )
        available_dates = sorted(get_date(path) for path in found)
        dates = [
            day for day in available_dates if day > last_date_from_stat
        ]

    print("processing {}".format(dates))
    if not dates:
        return
    for date in dates:
        process_dates(
            [date],
            cluster,
            yql_client,
            args.report,
            args.pool,
            replace_mask=args.replace_mask,
            debug=args.debug,
            dev_take=args.dev_take,
            only_push=args.only_push,
        )


# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
