#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import division
import os
import argparse
import itertools
import codecs
import datetime
import requests
from nile.api.v1 import (
    clusters,
    filters as nf,
    extractors as ne,
    aggregators as na,
    statface as ns,
    with_hints,
    extended_schema,
    Record,
)
import nile.files as nfi
from qb2.api.v1 import typing as qt
from yql.api.v1.client import YqlClient
from videolog_common import (
    date_range,
    get_date,
    get_missing_dates_from_stat,
    get_stat_headers,
    get_cluster,
    get_driver,
    ref_from_treatment,
    wrap_ref_from,
    StatPusher,
    YqlRunner,
)

# Local module attached to the map operation (passed as ``files=`` in
# process_dates) so the mapper can use the videolog_common helpers.
files = [nfi.LocalFile("videolog_common.py")]
# Layer paths applied to both the mapper and the reducer through
# ``yt_spec_defaults`` in process_dates.
# NOTE(review): "lastest" is part of the stored path string and presumably
# matches the real layer name -- verify before "fixing" the spelling.
PYTHON3_LAYERS = [
    "//porto_layers/base/bionic/porto_layer_search_ubuntu_bionic_app_lastest.tar.gz",
    "//home/portolayer/bionic/python3/latest",
]

# Report that get_is_sla queries and that process_dates pushes the SLA row to.
SLA_REPORT = "Video/Others/Strm/strm_cube_2_SLA"
# Default parent directory of the per-date ``tmp`` / ``report`` tables.
ROOT = "//home/videoquality/vh_analytics/vh_cube_2"
# Default value substituted for "@[root]" in the SQL stub.
SESSIONS_ROOT = "//cubes/video-strm"


def apply_replacements(s, pairs):
    """Apply a sequence of substring substitutions to *s*, in order.

    Each item of *pairs* is indexable; item ``[0]`` is the text to look for
    and item ``[1]`` is what replaces it. Substitutions are chained, so a
    later pair sees the output of the earlier ones.

    Returns the resulting string (``s`` itself when *pairs* is empty).
    """
    text = s
    for pair in pairs:
        needle, substitute = pair[0], pair[1]
        text = text.replace(needle, substitute)
    return text


def wrap_channel(
    channel, path, ref_from, channel_old, program
):  # TODO: remove zen hack
    """Expand one channel value into every channel label it is counted under.

    Args:
        channel: channel name (str). Known prefixes ("YANDEXSHOW__",
            "Яндекс.Персональный канал.", "Яндекс.Моя музыка.") are stripped
            before the name itself is added to the result.
        path: comma-separated string or None; only the presence of the
            item "2" is checked.
        ref_from: str or None; a value starting with "zen" forces "_total_"
            even for an empty / "-" channel (the "zen hack").
        channel_old: compared with "zen" and embedded in "ugc_bucket_<value>".
        program: str or None; a "UGCLive." prefix adds "_ugc_live_".

    Returns:
        list of labels in the order they were appended. The list is not
        de-duplicated.

    NOTE(review): ``channel`` itself is still required to be a string; a
    None channel fails on the first ``startswith`` even though None is
    listed in the emptiness check below.
    """
    result = []
    if channel.startswith("YANDEXSHOW__"):
        channel = channel.split("YANDEXSHOW__")[1]
        result.append("_yandex_show_")
    for spec_channel in ("Персональный канал", "Моя музыка"):
        prefix = "Яндекс.{}.".format(spec_channel)
        if channel.startswith(prefix):
            channel = channel.split(prefix)[1]
            result.append("Яндекс.{}".format(spec_channel))
    result.append(channel)
    # ref_from is None-guarded the same way path and program are below, so a
    # missing ref_from on an empty channel no longer raises AttributeError.
    if channel not in {"-", "", None} or (ref_from or "").startswith(
        "zen"
    ):  # TODO: remove zen hack
        result.append("_total_")
    if channel == "ott.trailers":
        result.append("ott")
    if "НХЛ" in channel or "NHL" in channel:
        result.append("_nhl_")
    if "ФНЛ" in channel:
        result.append("_fnl_")
    if "Яндекс.Уроки" in channel:
        result.append("_yandex_education_")
    if channel.startswith(("UGC", "Zen.")):
        # "UGC.Канал Ether" (clips) is kept out of the UGC aggregates.
        is_ether_clips = channel == "UGC.Канал Ether"
        if not is_ether_clips:
            result.append("_ugc_")
        if channel_old == "zen" or channel.startswith("Zen."):
            result.append("_zen_ugc_")
            # The original name is already in ``result``; the renamed value
            # only affects the prefix checks that follow.
            channel = "ZenUGC.{}".format(channel[4:])
        elif not is_ether_clips:
            result.append("_vh_ugc_")
        if (program or "").startswith("UGCLive."):
            result.append("_ugc_live_")
        result.append("ugc_bucket_{}".format(channel_old))
    if channel.startswith("fm_"):
        result.append("_fm_")
    if channel.startswith("Яндекс.Новогодний"):
        result.append("_yandex_")
        result.append("Яндекс.Новогодний (все)")
    elif channel.startswith("Яндекс."):
        result.append("_yandex_")
    elif channel.startswith("Спецпроекты."):
        result.append("_special_")
    elif channel.startswith("Youtube."):
        result.append("_youtube_")
    elif (
        # TV-channel bucket: path contains "2" and no other special label
        # (anything "_..._" except "_total_") or excluded name was collected.
        "2" in (path or "").split(",")
        and not any(x.startswith("_") and x != "_total_" for x in result)
        and not any(x.startswith("Яндекс") for x in result)
        and not any(x.startswith("Погода") for x in result)
        and not any(x.startswith("Я.Новости") for x in result)
        and not any("VOD" in x for x in result)
        and not any("ya-news" in x for x in result)
    ):
        result.append("_tv_channels_")
    if "_youtube_" in result or "_vh_ugc_" in result:
        result.append("_vh_ugc_youtube_")
    return result


class BytesStructWrapper:
    """Attribute proxy that returns ``bytes`` attributes of ``obj`` as ``str``.

    Attribute reads are forwarded to the wrapped object; ``bytes`` values are
    decoded as UTF-8 with undecodable bytes replaced, every other value is
    returned unchanged.
    """

    def __init__(self, obj):
        self.obj = obj

    def __getattr__(self, name, *default):
        """Return attribute *name* of the wrapped object, decoding bytes.

        An optional positional default is forwarded to ``getattr`` for
        callers that invoke ``__getattr__`` explicitly.

        Raises:
            AttributeError: the wrapped object has no such attribute, or
                the wrapper itself was never initialised.
        """
        # __getattr__ only runs after normal lookup failed, so being asked
        # for "obj" means __init__ never ran (e.g. copy/pickle rebuilt the
        # instance). Fail fast instead of recursing through self.obj.
        if name == "obj":
            raise AttributeError(name)
        result = getattr(self.obj, name, *default)
        if isinstance(result, bytes):
            return result.decode("utf8", errors="replace")
        return result


@with_hints(output_schema=extended_schema())
def totalize(recs):
    """Mapper: fan each input record out into its allowed "_total_" roll-ups.

    For every dimension the record's own value and "_total_" are both
    candidates (channel and ref_from get their candidate lists from
    wrap_channel / wrap_ref_from). A combination is emitted when at most one
    dimension is left non-total, or when the set of non-total dimensions is
    one of the whitelisted slices below. ``provider`` is always "_total_".

    Yields the original record with the dimension fields overridden.
    """
    dimension_names = (
        "browser",
        "os_family",
        "channel",
        "program",
        "ref_from",
        "view_type",
        "country",
    )
    # Whitelisted multi-dimension slices (names sorted alphabetically).
    allowed_slices = {
        ("channel", "view_type"),
        ("ref_from", "view_type"),
        ("channel", "os_family"),
        ("os_family", "view_type"),
        ("os_family", "ref_from"),
        ("channel", "ref_from"),
        ("country", "ref_from"),
        ("channel", "program"),
        ("program", "ref_from"),
        ("channel", "country"),
        ("channel", "program", "view_type"),
        ("channel", "ref_from", "view_type"),
        ("channel", "os_family", "view_type"),
        ("browser", "os_family", "view_type"),
    }
    for raw in recs:
        # Read string fields through the wrapper so bytes come back as str.
        rec = BytesStructWrapper(raw)
        candidates = (
            (rec.browser, "_total_"),
            (rec.os_family, "_total_"),
            wrap_channel(
                rec.channel, rec.path, rec.ref_from, rec.channel_old, rec.program
            ),
            (rec.program, "_total_"),
            wrap_ref_from(rec.ref_from, treatment="mild"),
            (rec.view_type, "_total_"),
            (rec.country, "_total_"),
        )
        for combination in itertools.product(*candidates):
            fields = dict(zip(dimension_names, combination))
            fields["provider"] = "_total_"
            detailed = tuple(
                sorted(
                    key for key, value in fields.items() if value != "_total_"
                )
            )
            if len(detailed) <= 1 or detailed in allowed_slices:
                yield Record(raw, **fields)


def get_is_sla(headers, dates):
    """Tell whether an SLA row still has to be written for *dates*.

    Downloads the daily "stat_cube" rows of SLA_REPORT and collects the date
    part (text before the first space) of every ``fielddate``.

    Args:
        headers: HTTP headers sent with the request.
        dates: dates to check; each is compared by its ``str()`` form.

    Returns:
        ``str(max(dates))`` when at least one date has no row in the report,
        otherwise ``None``.

    NOTE(review): the request disables TLS verification and sets no
    timeout -- confirm both are intended.
    """
    url = "https://upload.stat.yandex-team.ru/_api/statreport/json/{}/?scale=d&sla_type=stat_cube".format(
        SLA_REPORT
    )
    response = requests.get(url, verify=False, headers=headers)
    reported_days = set()
    for row in response.json()["values"]:
        reported_days.add(row["fielddate"].split(" ")[0])
    for day in dates:
        if str(day) not in reported_days:
            return str(max(dates))
    return None


def process_dates(
    dates,
    cluster,
    yql_client,
    report,
    pool,
    replace_mask=None,
    debug=False,
    redo=False,
    root_=None,
    sessions_root_=None,
    ibfix=False,
):
    """Build the cube report for *dates* and publish it.

    Steps:
      1. Run the YQL stub (``vh_cube_2_stub.sql``) into ``<root>/<dates>/tmp``
         unless that table already exists and *redo* is false.
      2. Map the tmp table through ``totalize``, aggregate per dimension
         combination, filter small rows and write ``<root>/<dates>/report``.
      3. Unless *debug*, push the report table to the stat report.
      4. Push an SLA row when get_is_sla reports one is missing.
      5. Unless *debug*, remove the tmp table (best effort).

    Args:
        dates: non-empty sequence of dates; ``str()`` of its items is used
            in table paths and in the query.
        cluster: cluster object used to create the job and the driver.
        yql_client: client handed to YqlRunner.
        report: stat report the result is pushed to.
        pool: value substituted for "@pool" in the query.
            NOTE(review): must be a string -- ``str.replace`` raises
            TypeError on None, which is main()'s default for ``--pool``.
        replace_mask: forwarded to StatPusher.
        debug: when truthy, skip the report push and keep the tmp table.
        redo: when truthy, rerun the query even if the tmp table exists.
        root_: parent directory for output tables; defaults to ROOT.
        sessions_root_: value substituted for "@[root]"; defaults to
            SESSIONS_ROOT.
        ibfix: when truthy, enable the ``$iron_branch`` join and the
            computed channel/program fallbacks in the query.

    Raises:
        Exception: the tmp table does not exist after the query step.
    """
    if len(dates) == 1:
        date_s = str(dates[0])
    else:
        date_s = "{}_{}".format(dates[0], dates[-1])
    date_from = str(min(dates))
    date_to = str(max(dates))
    if not root_:
        root_ = ROOT
    if not sessions_root_:
        sessions_root_ = SESSIONS_ROOT
    root = "{}/{}".format(root_, date_s)
    tmp_table = "{}/tmp".format(root)
    report_table = "{}/report".format(root)
    yr = YqlRunner(yql_client, title="VH Cube 2 | YQL")

    if not get_driver(cluster).client.exists(tmp_table) or redo:
        with codecs.open("vh_cube_2_stub.sql", "r", "utf8") as f:
            query = f.read()
        # Placeholders of the SQL stub and the values they are filled with.
        base_replacements = [
            ("@date_from", date_from),
            ("@date_to", date_to),
            ("@pool", pool),
            ("@[root]", sessions_root_),
            ("@output_table", tmp_table),
            (
                "@[ibfix_join]",
                "left join $iron_branch as ib on (s.video_content_id == ib.JoinKey)"
                if ibfix
                else "",
            ),
            ("@[ibfix_channel]", "ib.computed_channel ?? " if ibfix else ""),
            ("@[ibfix_program]", "ib.computed_program ?? " if ibfix else ""),
        ]
        query = apply_replacements(query, base_replacements)

        yr.run(
            query,
            attachments=[
                {
                    "path": "arcadia/quality/mstand_metrics/users/24julia/yandex_shows.sql"
                },
                {
                    "path": "arcadia/analytics/videolog/strm-stats/strm_cube_2/vh_cube_2/vh_cube_2_common.sql"
                }
            ],
        )
    else:
        print("using existing table")
    if not get_driver(cluster).client.exists(tmp_table):
        raise Exception("table {} does not exist".format(tmp_table))

    job = cluster.job().env(yt_spec_defaults={
        "mapper": {"layer_paths": PYTHON3_LAYERS},
        "reducer": {"layer_paths": PYTHON3_LAYERS}
    })

    job.table(tmp_table).map(
        totalize, intensity="ultra_cpu", files=files
    ).groupby(
        "fielddate",
        "browser",
        "country",
        "os_family",
        "channel",
        "program",
        "ref_from",
        "view_type",
        "provider",
    ).aggregate(
        sessions=na.count(),
        good_sessions=na.sum("good_sessions"),
        good_tvt=na.sum("good_tvt"),
        microsessions_with_view=na.sum("with_view"),
        microsessions_with_view_30s=na.sum("with_view_30s"),
        microsessions_with_money=na.sum("with_money"),
        price=na.sum("price"),
        partner_price=na.sum("partner_price"),
        vsids=na.count_distinct("vsid"),
        video_content_ids=na.count_distinct("video_content_id"),
        users=na.count_distinct("yu_hash"),
        users_30s=na.count_distinct("yu_hash_30s"),
        tvt=na.sum("view_time"),
        tvt_non_muted=na.sum("view_time_non_muted"),
        lvt=na.sum("lvt"),
        gbytes_sent=na.sum("gbytes_sent"),
    ).filter(
        # Keep only rows with enough watch time or enough revenue.
        nf.or_(
            nf.custom(lambda x: x > 1800, "tvt"),
            nf.custom(lambda x: x > 1000 * 1000 * 10, "price")
        )
    ).put(
        report_table
    )

    job.run()

    if not debug:
        stat_pusher = StatPusher(
            cluster,
            report=report,
            replace_mask=replace_mask,
            remote_publish=True,
        )

        stat_pusher.push(report_table)

    finish_time = datetime.datetime.now()

    # NOTE(review): the SLA row is pushed even when ``debug`` is set --
    # confirm this is intended.
    headers = get_stat_headers()
    date_for_sla = get_is_sla(headers, dates)
    if date_for_sla:
        data = [
            {
                "fielddate": date_for_sla,
                "sla_type": "stat_cube",
                # Local finish time encoded as the integer HHMM.
                "time": int(finish_time.strftime("%H%M")),
            }
        ]
        sp = StatPusher(cluster, report=SLA_REPORT)
        sp.push(data)

    if not debug:
        # Cleanup is best effort: a failure is reported but must not fail the
        # run. Only Exception is caught so Ctrl-C / SystemExit still propagate.
        try:
            get_driver(cluster).remove(tmp_table)
            print("removed {}".format(tmp_table))
        except Exception as err:
            print("unable to remove {}: {}".format(tmp_table, err))


def main():
    """Command-line entry point: choose the dates and process them one by one.

    With both ``--from`` and ``--to`` the dates come from ``date_range``.
    Otherwise the dates are those of the last week (ending yesterday) that
    are missing from the stat report and for which a ``.../sessions`` path
    exists under the sessions root.

    Requires the ``YT_PROXY`` and ``YQL_TOKEN`` environment variables.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--report", default="Video/Others/Strm/strm_cube_2")
    parser.add_argument("--from", default=None)
    parser.add_argument("--pool", default=None)
    parser.add_argument("--ibfix", action="store_true")
    # NOTE(review): unlike --redo/--ibfix this option takes a value; any
    # non-empty string (including "0") is truthy in process_dates.
    parser.add_argument("--debug", default=None)
    parser.add_argument("--redo", action="store_true")
    parser.add_argument("--replace_mask", default=None)
    parser.add_argument("--root", default=None)
    parser.add_argument("--sessions_root", default=None)
    parser.add_argument(
        "--job_root", default="//home/videoquality/vh_analytics/tmp"
    )
    parser.add_argument("--to", default=None)
    parser.add_argument("--title", default="VH Cube 2 | YQL")
    args = parser.parse_args()
    args.templates = {"tmp_root": args.job_root}

    cluster = get_cluster(clusters, args)
    yql_client = YqlClient(
        db=os.environ["YT_PROXY"], token=os.environ["YQL_TOKEN"]
    )

    # "from" is a Python keyword, so the attribute is read via getattr.
    from_ = getattr(args, "from")
    to_ = getattr(args, "to")

    if from_ and to_:
        dates = date_range(from_, to_)
    else:
        stat_headers = get_stat_headers()
        yesterday = datetime.date.today() - datetime.timedelta(days=1)
        weekago = yesterday - datetime.timedelta(days=7)
        missing_dates = set(
            get_missing_dates_from_stat(
                stat_headers, args.report, weekago, yesterday
            )[0]
        )
        print(
            "missing dates: {}".format(
                ",".join(sorted(map(str, missing_dates)))
            )
        )

        # Look for sessions under the same root the query will read from, so
        # a --sessions_root override is honoured here as well (the default
        # is unchanged).
        available_dates = set(
            get_date(x)
            for x in get_driver(cluster).client.search(
                root=args.sessions_root or SESSIONS_ROOT,
                path_filter=lambda x: get_date(x) and x.endswith("/sessions"),
            )
            if get_date(x) >= weekago
        )
        print(
            "available dates: {}".format(
                ",".join(sorted(map(str, available_dates)))
            )
        )

        dates = sorted(missing_dates & available_dates)

    print("processing {}".format(dates))
    if dates:
        # Each date is processed on its own, producing one report per day.
        for date in dates:
            process_dates(
                [date],
                cluster,
                yql_client,
                args.report,
                args.pool,
                replace_mask=args.replace_mask,
                debug=args.debug,
                root_=args.root,
                redo=args.redo,
                sessions_root_=args.sessions_root,
                ibfix=args.ibfix,
            )


if __name__ == "__main__":
    main()
