#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import os
import argparse
import codecs
import itertools
import datetime
from collections import Counter, defaultdict
from nile.api.v1 import (
    clusters,
    aggregators as na,
    extractors as ne,
    with_hints,
    modified_schema,
    extended_schema,
    Record,
)
import nile.files
from yql.api.v1.client import YqlClient
from qb2.api.v1 import typing as qt
from videolog_common import (
    zen_ref_from_treatment_light,
    date_range,
    yt_get_date_from_table as get_date,
    get_dates_from_stat,
    get_stat_headers,
    get_cluster,
    get_driver,
    StatPusher,
    apply_replacements,
    wrap_ref_from,
)


# Title attached to every YQL operation this script launches.
TITLE = "VH Quality Dash 3 | YQL"


# Error categories broken out in the dash, in report order. Every name must
# be a column of the intermediate table together with its "<name>_bin" and
# "<name>_binf" companions (they are summed by ``aggregators``).
# Full names, for grepping: fatal, Stalled, Stalled_<reason>,
# Stalled_AdEnd_<position>, 113, bufferAppendError, fragLoadError,
# fragLoadTimeOut, internalException, manifestLoadTimeOut, fatal_other.
errs = (
    ["fatal", "Stalled"]
    + [
        "Stalled_" + reason
        for reason in (
            "Other",
            "Init",
            "MediaError",
            "NoFragLoad",
            "Offline",
            "Recover",
            "Seek",
            "SetSource",
            "VideoTrackChange",
        )
    ]
    + [
        "Stalled_AdEnd_" + position
        for position in (
            "preroll",
            "midroll",
            "postroll",
            "unknown",
            "notReplaced",
        )
    ]
    + [
        "113",
        "bufferAppendError",
        "fragLoadError",
        "fragLoadTimeOut",
        "internalException",
        "manifestLoadTimeOut",
        "fatal_other",
    ]
)


# Aggregations applied to every (fielddate, fieldgroup...) group.
# Shared counters come first; then, for each error category, three sums:
#   <err>_sum     - sum of the raw <err> column       -> "errors_total"
#   <err>_sumbin  - sum of the <err>_bin column       -> "vsids_with_error"
#   <err>_sumbinf - sum of the <err>_binf column      -> "vsids_with_error_fatal"
# (the right-hand names are what UnwrapErrors renames them to).
aggregators = {
    "vsids": na.count_distinct("vsid"),
    "tvt": na.sum("viewTime"),
    "refreshes": na.sum("refreshes"),
}
for err in errs:
    aggregators.update(
        {
            err + "_sum": na.sum(err),
            err + "_sumbin": na.sum(err + "_bin"),
            err + "_sumbinf": na.sum(err + "_binf"),
        }
    )


class Totalizer(object):
    """Mapper that fans each record out over its roll-up ("total") values.

    For every field named in ``totfields`` a record is given a tuple of
    alternative values: its own value plus one or more roll-up values
    (``"_total_"`` for ordinary fields; the groups produced by
    ``wrap_ref_from`` / ``wrap_channel`` for ``ref_from`` and ``channel``).
    One output record is yielded per element of the Cartesian product of
    those alternatives, with the fields overwritten accordingly, so that a
    subsequent ``groupby`` over ``totfields`` yields every subtotal at once.
    Missing values (``None`` / empty) are replaced by ``"-"`` first.
    """

    def __init__(self, *totfields):
        # Names of the fields to expand, in the same order as the groupby.
        self.totfields = totfields

    @staticmethod
    def wrap_ref_from(ref_from):
        """Return the list of roll-up values for ``ref_from``.

        Starts from ``videolog_common.wrap_ref_from(ref_from,
        treatment="light")`` and additionally appends ``"_native_"`` when the
        first value is one of the ids listed below.
        """
        wrapped = list(wrap_ref_from(ref_from, treatment="light"))
        if wrapped[0] in ("ru.yandex.quasar.app", "ru.kinopoisk"):
            wrapped.append("_native_")
        return wrapped

    @staticmethod
    def wrap_channel(channel):
        """Return the tuple of roll-up values for a channel name.

        Channels are grouped by name prefix ("Яндекс.", "Спецпроекты.",
        "Youtube.") and all of them also roll up into ``"_total_"``.
        Missing channels ("-", "" or None) are returned alone, i.e. they do
        NOT contribute to ``"_total_"`` or any group.
        """
        if channel in {"-", "", None}:
            return (channel,)
        # Must be checked before the generic "Яндекс." prefix below.
        if channel.startswith("Яндекс.Новогодний"):
            return (channel, "_yandex_", "Яндекс.Новогодний (все)", "_total_")
        if channel.startswith("Яндекс."):
            return (channel, "_yandex_", "_total_")
        if channel.startswith("Спецпроекты."):
            return (channel, "_special_", "_total_")
        if channel.startswith("Youtube."):
            return (channel, "_youtube_", "_total_")
        return (channel, "_total_")

    def __call__(self, recs):
        """Yield one record per combination of roll-up values of ``recs``."""
        totfields = self.totfields
        for rec in recs:
            alternatives = []
            for field in totfields:
                value = rec[field] or "-"
                if field == "ref_from":
                    alternatives.append(self.wrap_ref_from(value))
                elif field == "channel":
                    alternatives.append(self.wrap_channel(value))
                else:
                    alternatives.append((value, "_total_"))
            for comb in itertools.product(*alternatives):
                # Pair each expanded field with its chosen alternative.
                yield Record(rec, **dict(zip(totfields, comb)))


def totalizer(*args):
    """Return a ``Totalizer`` mapper over ``args`` with schema hints attached.

    The output schema is ``extended_schema()``: the mapper yields the input
    record with the totalized fields overwritten.
    """
    add_hints = with_hints(output_schema=extended_schema())
    return add_hints(Totalizer(*args))


def unwrap_errors(*totfields):
    """Return an ``UnwrapErrors`` mapper with an explicit output schema.

    Every grouping field in ``totfields`` is typed ``str``. The fixed columns
    emitted by ``UnwrapErrors`` are added afterwards, so they override a
    same-named grouping field.
    """
    schema = dict((field, str) for field in totfields)
    fixed_columns = (
        ("fielddate", str),
        ("error", str),
        ("vsids", int),
        ("tvt", int),
        ("refreshes", int),
        ("errors_total", int),
        ("vsids_with_error", int),
        ("vsids_with_error_fatal", int),
    )
    for column, column_type in fixed_columns:
        schema[column] = column_type
    add_hints = with_hints(output_schema=schema)
    return add_hints(UnwrapErrors(*totfields))


class UnwrapErrors(object):
    """Mapper turning one wide aggregated row into one row per error type.

    Each input record is expected to carry ``<err>_sum``, ``<err>_sumbin``
    and ``<err>_sumbinf`` columns for every name in the module-level
    ``errs`` list. For each such name one record is yielded with:

    * the grouping fields and ``fielddate`` (missing values -> ``"-"``),
    * ``error``: the error name,
    * ``vsids`` / ``tvt`` / ``refreshes``: the shared counters (None -> 0),
    * ``errors_total`` / ``vsids_with_error`` / ``vsids_with_error_fatal``:
      the three per-error sums (None -> 0).
    """

    def __init__(self, *totfields):
        # "vsids" and "tvt" are copied with the grouping fields here but are
        # always overwritten with their numeric values in __call__.
        self.totfields = list(totfields) + ["fielddate", "vsids", "tvt"]

    def __call__(self, recs):
        for rec in recs:
            result = {k: (rec[k] or "-") for k in self.totfields}
            # These counters do not depend on the error type: read them once
            # per record instead of once per entry of ``errs``.
            tvt = rec["tvt"] or 0
            refreshes = rec["refreshes"] or 0
            vsids = rec["vsids"] or 0
            for err in errs:
                result["error"] = err
                result["tvt"] = tvt
                result["refreshes"] = refreshes
                result["vsids"] = vsids
                result["errors_total"] = rec["{}_sum".format(err)] or 0
                result["vsids_with_error"] = rec["{}_sumbin".format(err)] or 0
                result["vsids_with_error_fatal"] = (
                    rec["{}_sumbinf".format(err)] or 0
                )
                # Record(**result) copies the dict, so reusing it is safe.
                yield Record(**result)


def _elapsed_minutes(start):
    """Return the minutes elapsed since ``start``, rounded for log output."""
    return round((datetime.datetime.now() - start).total_seconds() / 60.0)


def _build_tmp_table(yql_client, pool, date_from, date_to, tmp_table):
    """Run ``dash_3.sql`` through YQL, writing its result to ``tmp_table``.

    The query is read from ``dash_3.sql`` in the current working directory
    and its placeholders are substituted with ``apply_replacements``.

    Raises:
        RuntimeError: if the YQL operation does not finish successfully.
    """
    with open("dash_3.sql", "r") as f:
        query = f.read()
    query = apply_replacements(
        query,
        {
            "@date_from": date_from,
            "@date_to": date_to,
            "@pool": pool,
            "--@ref_from": zen_ref_from_treatment_light,
            "@output_table": tmp_table,
        },
    )
    req = yql_client.query(query, title=TITLE, syntax_version=1)
    req.run()
    req.wait_progress()
    # Without this check a failed query either crashes later on a missing
    # table or, with --redo, lets the job silently reuse a stale tmp table.
    if not req.is_success:
        raise RuntimeError(
            "YQL query for {} failed: {!r}".format(tmp_table, req.errors)
        )


def _run_multiply_job(cluster, root, tmp_table, fieldgroups):
    """Run one nile job writing ``<root>/<fieldgroup>`` per fieldgroup.

    For every fieldgroup the tmp table is totalized over the group's fields,
    aggregated per (fielddate, fields...) and unwrapped into one row per
    error type.
    """
    job = cluster.job()
    for fieldgroup in fieldgroups:
        output_table = "{}/{}".format(root, "_".join(fieldgroup))
        stream = job.table(tmp_table)
        if "with_view" in fieldgroup:
            # Derived dimension: "True"/"False" depending on viewTime > 0.
            stream = stream.project(
                ne.all(),
                with_view=ne.custom(
                    lambda x: str(x > 0), "viewTime"
                ).with_type(str),
            )
        (
            stream.map(
                totalizer(*fieldgroup),
                files=[nile.files.LocalFile("videolog_common.py")],
            )
            .groupby("fielddate", *fieldgroup)
            .aggregate(**aggregators)
            .map(unwrap_errors(*fieldgroup))
            .put(output_table)
        )
    start = datetime.datetime.now()
    job.run()
    print("multiplying took {} minutes".format(_elapsed_minutes(start)))


def _push_to_stat(cluster, driver, root, report, fieldgroups, replace_mask):
    """Push every ``<root>/<fieldgroup>`` table to ``<report>/<fieldgroup>``."""
    for fieldgroup in fieldgroups:
        fieldgroup_joined = "_".join(fieldgroup)
        table = "{}/{}".format(root, fieldgroup_joined)
        rows = driver.get_attribute(table, "row_count")
        # Tables with at least 1M rows use StatPusher's remote_publish mode
        # (its exact semantics live in videolog_common).
        remote_publish = rows >= 1000000
        report_ = "{}/{}".format(report, fieldgroup_joined)
        print(
            "pushing from {} ({} rows) to {}, remote_publish is {}".format(
                table, rows, report_, remote_publish
            )
        )
        stat_pusher = StatPusher(
            cluster,
            report=report_,
            replace_mask=replace_mask,
            remote_publish=remote_publish,
        )
        start = datetime.datetime.now()
        stat_pusher.push(table)
        print("pushing took {} minutes".format(_elapsed_minutes(start)))


def process_dates(
    dates,
    cluster,
    yql_client,
    report,
    pool,
    replace_mask=None,
    debug=False,
    skip_push=False,
    redo=False,
):
    """Compute and publish the dash 3 breakdowns for ``dates``.

    Steps:
      1. Find which per-fieldgroup output tables under the date directory
         still need computing (all of them when ``redo``).
      2. If any do: ensure the intermediate ``<root>/tmp`` table exists
         (running ``dash_3.sql`` when it is missing or ``redo``), then run
         one nile job producing every pending fieldgroup table.
      3. Unless ``skip_push``, push every fieldgroup table to Stat.
      4. Unless ``debug``, remove the intermediate table if it exists.

    Args:
        dates: non-empty sequence of dates; only ``min``/``max`` bound the
            query and name the output directory.
        cluster: nile cluster used for the job and the YT driver.
        yql_client: ``YqlClient`` that runs ``dash_3.sql``.
        report: Stat report prefix; fieldgroups go to ``<report>/<group>``.
        pool: substituted for ``@pool`` in the query.
        replace_mask: forwarded to ``StatPusher``.
        debug: keep the intermediate table when True.
        skip_push: skip pushing to Stat when True.
        redo: recompute the tmp table and every output even if present.

    Raises:
        ValueError: if ``dates`` is empty.
        RuntimeError: if the YQL query fails.
    """
    if not dates:
        raise ValueError("process_dates needs at least one date")
    date_from = str(min(dates))
    date_to = str(max(dates))
    # Name the directory by the same bounds the query uses, so it stays
    # correct even if ``dates`` is not sorted.
    if len(dates) == 1:
        date_s = date_from
    else:
        date_s = "{}_{}".format(date_from, date_to)
    root = "//home/videoquality/vh_analytics/vh_quality_dash_3/{}".format(
        date_s
    )
    tmp_table = "{}/tmp".format(root)
    driver = get_driver(cluster)

    fieldgroups = (
        ["ref_from", "player_version"],
        ["ref_from", "provider", "country"],
        ["ref_from", "device_type", "os_family", "browser_name"],
        ["ref_from", "with_view"],
        ["ref_from", "channel"],
        ["ref_from", "view_type", "os_family", "player_version"],
    )

    pending = []
    for fieldgroup in fieldgroups:
        output_table = "{}/{}".format(root, "_".join(fieldgroup))
        if driver.exists(output_table) and not redo:
            print("table {} already exists, skipping".format(output_table))
        else:
            pending.append(fieldgroup)

    # Only pay for the (expensive) YQL query and the job when some output
    # is actually missing, e.g. not on a re-push-only rerun.
    if pending:
        if redo or not driver.exists(tmp_table):
            _build_tmp_table(yql_client, pool, date_from, date_to, tmp_table)
        else:
            print("using existing {}".format(tmp_table))
        _run_multiply_job(cluster, root, tmp_table, pending)

    if not skip_push:
        _push_to_stat(cluster, driver, root, report, fieldgroups, replace_mask)

    # The tmp table may never have been created when nothing was pending.
    if not debug and driver.exists(tmp_table):
        driver.remove(tmp_table)


def _dates_to_process(cluster, report):
    """Return sorted dates that are ready in the cubes but not yet in Stat.

    A date is ready when both a ``.../preprocessed`` and a ``.../sessions``
    node for it are found under ``//cubes/video-strm``. It is new when it is
    greater than the last date already present in the
    ``<report>/ref_from_player_version`` Stat report.

    Raises:
        RuntimeError: if that Stat report has no dates at all.
    """
    stat_report = report + "/ref_from_player_version"
    stat_dates = get_dates_from_stat(
        headers=get_stat_headers(),
        report=stat_report,
        dimensions=[],
    )
    if not stat_dates:
        raise RuntimeError(
            "no dates found in Stat report {}; pass --from and --to "
            "explicitly".format(stat_report)
        )
    last_date_from_stat = stat_dates[-1]

    # A single recursive walk collects both kinds of nodes (previously the
    # tree was searched twice, once per suffix).
    preprocessed, sessions = set(), set()
    for path in get_driver(cluster).client.search(
        root="//cubes/video-strm",
        path_filter=(
            lambda x: get_date(x)
            and x.endswith(("/preprocessed", "/sessions"))
        ),
    ):
        if path.endswith("/preprocessed"):
            preprocessed.add(get_date(path))
        else:
            sessions.add(get_date(path))

    return sorted(
        x for x in preprocessed & sessions if x > last_date_from_stat
    )


def main():
    """Command-line entry point.

    Processes the explicit ``--from``..``--to`` range when both are given,
    otherwise every ready date newer than what Stat already has.
    Requires the ``YT_PROXY`` and ``YQL_TOKEN`` environment variables.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--report", default="Video/Others/Strm/Stability/Dash3"
    )
    parser.add_argument("--from", default=None)
    parser.add_argument("--pool", default=None)
    parser.add_argument("--redo", action="store_true")
    parser.add_argument("--debug", action="store_true")
    # NOTE(review): --dev_take, --only_push and --title are parsed but not
    # used anywhere below; kept only for command-line compatibility.
    parser.add_argument("--dev_take", action="store_true")
    parser.add_argument("--only_push", action="store_true")
    parser.add_argument("--skip_push", action="store_true")
    parser.add_argument("--replace_mask", default=None)
    parser.add_argument(
        "--job_root", default="//home/videoquality/vh_analytics/tmp"
    )
    parser.add_argument("--to", default=None)
    parser.add_argument("--title", default="VH Quality Dash | YQL")
    args = parser.parse_args()
    # Consumed by get_cluster (presumably as nile path templates).
    args.templates = {"tmp_root": args.job_root}

    cluster = get_cluster(clusters, args)
    yql_client = YqlClient(
        db=os.environ["YT_PROXY"], token=os.environ["YQL_TOKEN"]
    )

    # "from" is a Python keyword, so it cannot be read as args.from.
    from_ = getattr(args, "from")
    to_ = args.to

    if from_ and to_:
        dates = date_range(from_, to_)
    else:
        if from_ or to_:
            print(
                "warning: --from and --to must be given together; "
                "ignoring them and auto-detecting dates"
            )
        dates = _dates_to_process(cluster, args.report)

    print("processing {}".format(dates))
    if dates:
        process_dates(
            dates,
            cluster,
            yql_client,
            args.report,
            args.pool,
            replace_mask=args.replace_mask,
            debug=args.debug,
            redo=args.redo,
            skip_push=args.skip_push,
        )


if __name__ == "__main__":
    main()
