#!/usr/bin/env python
# -*- coding: utf-8 -*-

from collections import defaultdict, Counter
import json
import re
import urllib
import itertools
import random
import argparse
import string
import os
import sys
import time
import math
from datetime import datetime as dt, timedelta
import datetime
import nile
from nile.api.v1 import (
    clusters,
    Record,
    with_hints,
    aggregators as na,
    grouping as ng,
    filters as nf,
    extractors as ne,
)
from pytils import date_range
import requests

# Where the logs will live on the cluster (assigned in main() from --job_root).
job_root = None
# Cluster handle; created in main() and used by the table readers/writers.
cluster = None
# Parsed command-line arguments; stored by main() for make_cube_report().
g_args = None
# Name of the per-view-session field that SessionsMapper sums and
# make_cube_report() aggregates (presumably a TCP retransmission counter).
RETRANS = "tcpinfo_total_retrans"


def merge_names(a, b):
    """Choose between two candidate names and return the normalized winner.

    ``b`` is preferred whenever it is non-empty and either ``a`` is empty
    or the two differ; in every other case ``a`` is used.
    """
    if b and (not a or a != b):
        chosen = b
    else:
        chosen = a
    return normalize_name(chosen)


def merge_program_title(x):
    """Build a display title from a programme entry.

    Args:
        x: mapping that may carry ``program_title`` and ``title`` values,
            either as UTF-8 byte strings or as already-decoded text.

    Returns:
        ``u"<program_title>. <title>"`` when both parts are non-empty,
        otherwise just the (possibly empty) title.
    """

    def _text(key):
        # A missing key and an explicit None both count as "no value".
        value = x.get(key) or u""
        # Decode only byte strings: calling .decode() on text that is
        # already decoded fails for non-ASCII content.
        if isinstance(value, bytes):
            value = value.decode("utf8")
        return value

    pt = _text("program_title")
    t = _text("title")
    if pt and t:
        return u"{}. {}".format(pt, t)
    return t


def transform_programs_value(dct):
    """Map every entry of ``dct["programms"]`` to a shortened title.

    Returns a dict keyed by the UTF-8 decoded ``content_id`` of each entry;
    the value is the merged programme title passed through
    ``truncate_string``. A later entry with the same id replaces an
    earlier one.
    """
    titles = {}
    for entry in dct["programms"]:
        content_id = entry["content_id"].decode("utf8")
        titles[content_id] = truncate_string(merge_program_title(entry))
    return titles


def read_programs_table(table):
    """Load the programmes dictionary stored in the first row of ``table``.

    The row's ``programs`` field is expected to be a mapping; each of its
    values is converted with ``transform_programs_value``. Raises whatever
    the cluster read raises, and ``IndexError`` when the table is empty.
    """
    first_row = list(cluster.read(table))[0]
    raw_programs = first_row.to_dict()["programs"]
    converted = {}
    for key, value in raw_programs.items():
        converted[key] = transform_programs_value(value)
    return converted


class AddShares(object):
    """Reducer that adds share-of-total columns to every record of a group.

    Within each group the record whose ``error`` equals ``"_total_"``
    supplies the denominators (``sessions`` and ``chunks_count``). Every
    record gets ``error_share`` and ``error_chunk_share``; unless
    ``old_mode`` is set it also gets ``bits_per_sec``
    (``bits_sent / tvt`` rounded, or 0 when ``tvt`` is falsy).

    A share is ``-1`` when its denominator is missing or zero. This holds
    regardless of whether the record arrives before or after the
    ``_total_`` row, so the result does not depend on record order.

    Args:
        old_mode: when true, ``bits_per_sec`` is not computed.
        decimals: number of decimals the shares are rounded to.
    """

    def __init__(self, old_mode=False, decimals=6):
        self.old_mode = old_mode
        self.decimals = decimals

    def _add_bits_per_sec(self, row):
        """Add ``bits_per_sec`` to ``row`` in place (skipped in old mode)."""
        if not self.old_mode:
            row["bits_per_sec"] = (
                round(row["bits_sent"] / row["tvt"]) if row["tvt"] else 0
            )

    def _add_shares(self, row, total, total_chunks):
        """Add both share columns and ``bits_per_sec`` to ``row`` in place."""
        row["error_share"] = (
            round(row["sessions"] / total, self.decimals) if total else -1
        )
        row["error_chunk_share"] = (
            round(row["chunks_count"] / total_chunks, self.decimals)
            if total_chunks
            else -1
        )
        self._add_bits_per_sec(row)

    def __call__(self, groups):
        """Yield the annotated records of every ``(key, records)`` group.

        The ``_total_`` row and the rows that follow it are emitted in
        stream order; rows seen before the ``_total_`` row are buffered and
        emitted at the end of the group.
        """
        for _, records in groups:
            deferred = []  # rows seen before the denominators were known
            total = None
            total_chunks = None
            for rec in records:
                row = rec.to_dict()
                if rec.error == "_total_":
                    total = float(row["sessions"])
                    total_chunks = float(row["chunks_count"])
                    row["error_share"] = 1
                    row["error_chunk_share"] = 1
                    self._add_bits_per_sec(row)
                    yield Record(**row)
                elif total is None:
                    deferred.append(row)
                else:
                    self._add_shares(row, total, total_chunks)
                    yield Record(**row)
            for row in deferred:
                try:
                    self._add_shares(row, total, total_chunks)
                except Exception:
                    # Re-raise with the offending row attached for debugging.
                    raise Exception((json.dumps(row), total))
                yield Record(**row)


def is_fatal(error):
    """Return True when the error id ends with ``fatal``."""
    suffix = "fatal"
    return error.endswith(suffix)


def is_kpi(error):
    """Return True for fatal error ids and for ids starting with ``Stalled``."""
    if is_fatal(error):
        return True
    return error.startswith("Stalled")


def has_fatal(errors):
    """Return True when at least one id in ``errors`` is fatal.

    A generator is used so the scan stops at the first fatal id instead of
    building a full list first.
    """
    return any(is_fatal(x) for x in errors)


def has_kpi(errors):
    """Return True when at least one id in ``errors`` satisfies ``is_kpi``.

    A generator is used so the scan stops at the first match instead of
    building a full list first.
    """
    return any(is_kpi(x) for x in errors)


def no_errors_chunks_count(x):
    """Return the number of chunks of ``x`` not attributed to any error.

    This is ``chunks_count`` minus the sum of all per-error counters in
    ``errors_dict`` (a missing or empty dict counts as zero), floored at 0.
    """
    chunks = x["chunks_count"]
    per_error = x.get("errors_dict") or {}
    remaining = chunks - sum(per_error.values())
    return remaining if remaining > 0 else 0


def _chunks_count(x, predicate):
    dct = x.get("errors_dict") or {}
    errors = {e for e in dct if predicate(e)}
    return sum([dct[e] for e in errors])


def kpi_chunks_count(x):
    """Return the number of chunks of ``x`` counted under KPI error ids."""
    return _chunks_count(x, predicate=is_kpi)


def fatal_chunks_count(x):
    """Return the number of chunks of ``x`` counted under fatal error ids."""
    return _chunks_count(x, predicate=is_fatal)


def count_chunks(lst, error):
    """Count the chunks of the view sessions in ``lst`` for one error bucket.

    Args:
        lst: iterable of view-session dicts with ``chunks_count`` and an
            optional ``errors_dict`` (error id -> chunk counter).
        error: bucket name. The pseudo-buckets ``_total_``, ``_no_errors_``,
            ``_all_errors_``, ``_fatal_`` and ``_kpi_`` are aggregated
            specially; any other value is looked up as a concrete error id.

    Returns:
        The summed chunk count (0 for an empty ``lst``).
    """
    if error == "_total_":
        return sum(x["chunks_count"] for x in lst)
    if error == "_no_errors_":
        return sum(no_errors_chunks_count(x) for x in lst)
    if error == "_all_errors_":
        return sum(sum((x.get("errors_dict") or {}).values()) for x in lst)
    if error == "_fatal_":
        return sum(fatal_chunks_count(x) for x in lst)
    if error == "_kpi_":
        return sum(kpi_chunks_count(x) for x in lst)
    # A session without an errors_dict contributes 0, the same way the
    # pseudo-bucket branches treat it, instead of raising.
    return sum((x.get("errors_dict") or {}).get(error, 0) for x in lst)


def ensure_utf8(strng):
    """Decode a UTF-8 byte string to text; return anything else unchanged.

    Undecodable bytes are replaced with U+FFFD. The check is made against
    ``bytes`` (the same type as ``str`` on Python 2) so that text which is
    already decoded is never passed to ``.decode``.
    """
    if isinstance(strng, bytes):
        return strng.decode("utf8", errors="replace")
    return strng


class SessionsMapper(object):
    """Mapper that expands one session record into pre-aggregated cube rows.

    For every input record of the configured date it enumerates
    channel x error x program x view_type buckets (each with a ``_total_``
    roll-up) combined with browser / os_family / ref_from / provider values
    or their ``_total_`` roll-ups, keeps only the whitelisted dimension
    combinations, and yields one ``Record`` per surviving combination with
    ``sessions=1`` and the summed durations / chunk counts of the matching
    view-session entries.
    """

    def __init__(self, programs, date, top30, old_mode=False):
        """Store the mapper configuration.

        Args:
            programs: mapping used only by ``_get_p`` (channel -> content id
                -> title).
            date: records whose ``date`` field differs are skipped.
            top30: stored on the instance; not read by the code in this
                class (the provider filtering that used it is commented out).
            old_mode: when true, ``bits_sent`` and the ``RETRANS`` counter
                are neither computed nor emitted.
        """
        self.programs = programs
        self.date = date
        self.top30 = top30
        self.old_mode = old_mode

    def _get_p(self, channel, content_id, content_video_title):
        """Resolve a program name for a piece of content.

        Returns ``content_video_title`` when it is truthy, otherwise the
        normalized title from ``self.programs[channel][content_id]``, and
        ``"UNKNOWN"`` when neither is available.

        NOTE(review): not called from ``__call__``; rows use the
        ``computed_program`` field of the view-session entries instead.
        """
        if content_video_title:
            return content_video_title
        elif self.programs.get(channel, {}).get(content_id):
            return normalize_name(
                self.programs.get(channel, {}).get(content_id) or ""
            )
        return "UNKNOWN"

    def __call__(self, records):
        """Yield cube rows for every record whose ``date`` equals ``self.date``.

        Note that the ``errors_all`` and ``view_session`` entries of the
        input record are modified in place while processing.
        """
        for rec in records:
            if rec["date"] != self.date:
                continue
            # Disabled top-30 provider bucketing, kept for reference:
            # cntr = defaultdict(lambda: Counter())
            # if rec['provider'] in self.top30:
            #     provider = rec['provider']
            # else:
            #     provider = 'UNKNOWN'
            provider = rec["provider"]
            # Distinct channels and paths seen in this session, sorted.
            view_channels = sorted(
                {x["view_channel"] for x in rec["view_session"]}
            )
            paths = sorted(
                {(x.get("path") or "") for x in rec["view_session"]}
            )
            # NOTE(review): testids is built here but never read afterwards
            # in this method; the appends also modify the list returned by
            # rec.get() when the field is present.
            testids = rec.get("slots_arr", [])
            if not testids:
                testids.append("_no_experiments_")
            testids.append("_total_")
            # Synthetic channel groups, added on top of the real channels.
            # "_nonfirst_" is added unless the session's only channel is
            # "1tv" or "Первый".
            if view_channels != ["1tv"] and view_channels != ["Первый"]:
                view_channels.append("_nonfirst_")
            if any(x.startswith("Яндекс.") for x in view_channels):
                view_channels.append("_yandex_")
            if any(x.startswith("Яндекс.Новогодний") for x in view_channels):
                view_channels.append("Яндекс.Новогодний (все)")
            if any(x.startswith("Youtube.") for x in view_channels):
                view_channels.append("_youtube_")
            if any(x.startswith("Спецпроекты.") for x in view_channels):
                view_channels.append("_special_")
            # A path whose comma-separated parts contain "2" marks the
            # "_tv_channels_" group.
            if any(("2" in x.split(",")) for x in paths):
                view_channels.append("_tv_channels_")
            view_channels.append("_total_")
            vs = rec["view_session"]
            errors_all = rec.get("errors_all") or []
            # video_content_id -> Counter(error_id -> occurrences)
            ci = defaultdict(Counter)

            to_add = []
            for err in errors_all:
                # A fatal error timestamped before start_time + duration
                # loses its "_fatal" suffix (the entry is edited in place).
                if is_fatal(err["error_id"]) and (
                    int(err["timestamp"])
                    < int(rec["start_time"]) + int(rec["duration"])
                ):
                    err["error_id"] = err["error_id"][: -len("_fatal")]
                # A fatal or "Stalled*" error at or after that moment is
                # counted once more under "_stalled_or_fatal_".
                elif (
                    is_fatal(err["error_id"])
                    or err["error_id"].startswith("Stalled")
                ) and (
                    int(err["timestamp"])
                    >= int(rec["start_time"]) + int(rec["duration"])
                ):
                    err1 = err.copy()
                    err1["error_id"] = "_stalled_or_fatal_"
                    to_add.append(err1)
                ci[err["video_content_id"]][err["error_id"]] += 1
            for err in to_add:
                ci[err["video_content_id"]][err["error_id"]] += 1
            # Attach the error counters to the FIRST view-session entry of
            # each content id only; every other entry (including repeated
            # content ids) is marked "_no_errors_".
            already_processed_errors = set()
            for x in vs:
                if (
                    x["content_id"] in ci
                    and x["content_id"] not in already_processed_errors
                ):
                    dct = ci[x["content_id"]]
                    x["errors_dict"] = dct
                    # NOTE(review): dct always comes from the
                    # defaultdict(Counter) above, so the basestring branch
                    # looks unreachable here.
                    if isinstance(dct, dict):  # new style (2018-03-28)
                        x["errors"] = sorted(dct.keys())
                    elif isinstance(dct, basestring):  # fallback to old style
                        x["errors"] = [dct]
                    already_processed_errors.add(x["content_id"])
                else:
                    x["errors"] = ["_no_errors_"]
            for channel in view_channels:
                # vs_: the view-session entries that belong to this channel
                # (or synthetic channel group).
                if channel == "_total_":
                    vs_ = vs
                elif channel == "_nonfirst_":
                    vs_ = [
                        x
                        for x in vs
                        if x["view_channel"] not in {"1tv", "Первый"}
                    ]
                elif channel == "Яндекс.Новогодний (все)":
                    vs_ = [
                        x
                        for x in vs
                        if x["view_channel"].startswith("Яндекс.Новогодний")
                    ]
                elif channel == "_yandex_":
                    vs_ = [
                        x
                        for x in vs
                        if x["view_channel"].startswith("Яндекс.")
                    ]
                elif channel == "_tv_channels_":
                    vs_ = [
                        x
                        for x in vs
                        if "2" in (x.get("path") or "").split(",")
                    ]
                elif channel == "_youtube_":
                    vs_ = [
                        x
                        for x in vs
                        if x["view_channel"].startswith("Youtube.")
                    ]
                elif channel == "_special_":
                    vs_ = [
                        x
                        for x in vs
                        if x["view_channel"].startswith("Спецпроекты.")
                    ]
                else:
                    vs_ = [x for x in vs if x["view_channel"] == channel]
                # Error buckets present in this channel, plus roll-ups.
                errors = set()
                for x in vs_:
                    errors |= set(x["errors"])
                errors = sorted(errors)
                if errors != ["_no_errors_"]:
                    errors.append("_all_errors_")
                if {x for x in errors if is_fatal(x)}:
                    errors.append("_fatal_")
                if {x for x in errors if is_kpi(x)}:
                    errors.append("_kpi_")
                errors.append("_total_")
                for error in errors:
                    # vs__: entries of the channel that fall in this bucket.
                    if error == "_total_":
                        vs__ = vs_
                    elif error == "_fatal_":
                        vs__ = [x for x in vs_ if has_fatal(x["errors"])]
                    elif error == "_kpi_":
                        # NOTE(review): the "_kpi_" bucket is created when
                        # is_kpi() matches (fatal or startswith "Stalled"),
                        # but entries are selected here by fatal or the
                        # exact id "PlayedStalled" - confirm this is
                        # intended.
                        vs__ = [
                            x
                            for x in vs_
                            if has_fatal(x["errors"])
                            or "PlayedStalled" in x["errors"]
                        ]
                    elif error == "_all_errors_":
                        vs__ = [
                            x for x in vs_ if x["errors"] != ["_no_errors_"]
                        ]
                    else:
                        vs__ = [x for x in vs_ if error in x["errors"]]
                    programs = sorted(
                        {x.get("computed_program") or "UNKNOWN" for x in vs__}
                    )
                    programs.append("_total_")
                    for program in programs:
                        if program == "_total_":
                            vs___ = vs__
                        else:
                            vs___ = [
                                x
                                for x in vs__
                                if (x.get("computed_program") or "UNKNOWN")
                                == program
                            ]
                        view_types = sorted({x["view_type"] for x in vs___})
                        view_types.append("_total_")
                        for view_type in view_types:
                            if view_type == "_total_":
                                vs_4 = vs___
                            else:
                                vs_4 = [
                                    x
                                    for x in vs___
                                    if x["view_type"] == view_type
                                ]
                            # Metrics of the innermost selection.
                            tvt = sum([x["view_duration"] for x in vs_4])
                            lvt = sum([x["log_view_duration"] for x in vs_4])
                            chunks_count = count_chunks(vs_4, error)
                            if not self.old_mode:
                                bits_sent = (
                                    sum([x["bytes_sent"] for x in vs_4]) * 8
                                )
                                retrans = sum([x[RETRANS] for x in vs_4])
                            # Each record-level dimension is taken either at
                            # its concrete value or rolled up to "_total_".
                            for comb in itertools.product(
                                (rec["browser_name"] or "unknown", "_total_"),
                                (rec["os_family"], "_total_"),
                                (rec["ref_from"], "_total_"),
                                (provider, "_total_"),
                            ):
                                ok = False
                                browser_name = comb[0]
                                os_family = comb[1]
                                ref_from = comb[2]
                                provider_ = comb[3]
                                # One flag per dimension: 1 when it holds a
                                # concrete value, 0 when rolled up.
                                non_total = [
                                    int(x != "_total_")
                                    for x in (
                                        comb[0],
                                        comb[1],
                                        comb[2],
                                        view_type,
                                        comb[3],
                                        channel,
                                        error,
                                        program,
                                    )
                                ]
                                # Whitelist of dimension combinations that
                                # are emitted: at most one concrete
                                # dimension, or one of the explicitly listed
                                # pairs / triples / quadruples below.
                                if sum(non_total) <= 1:
                                    ok = True
                                # NOTE(review): this rule tests `provider`
                                # (the record's own provider), not
                                # `provider_` (the combination value), so it
                                # accepts any pair that contains a concrete
                                # error - confirm whether `provider_` was
                                # meant.
                                elif sum(non_total) == 2 and (
                                    (
                                        error != "_total_"
                                        and provider != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 2 and (
                                    (
                                        ref_from != "_total_"
                                        and provider_ != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 2 and (
                                    (
                                        view_type != "_total_"
                                        and channel != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 2 and (
                                    (
                                        ref_from != "_total_"
                                        and view_type != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 2 and (
                                    (
                                        error != "_total_"
                                        and view_type != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 2 and (
                                    (
                                        os_family != "_total_"
                                        and channel != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 2 and (
                                    (
                                        os_family != "_total_"
                                        and view_type != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 2 and (
                                    (
                                        os_family != "_total_"
                                        and error != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 2 and (
                                    (
                                        os_family != "_total_"
                                        and ref_from != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 2 and (
                                    (
                                        error != "_total_"
                                        and channel != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 2 and (
                                    (
                                        ref_from != "_total_"
                                        and channel != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 2 and (
                                    (
                                        program != "_total_"
                                        and channel != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 3 and (
                                    (
                                        os_family != "_total_"
                                        and error != "_total_"
                                        and ref_from != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 3 and (
                                    (
                                        program != "_total_"
                                        and channel != "_total_"
                                        and view_type != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 3 and (
                                    (
                                        view_type != "_total_"
                                        and ref_from != "_total_"
                                        and channel != "_total_"
                                    )
                                ):
                                    ok = True
                                # NOTE(review): duplicate of the earlier
                                # view_type + channel pair rule; never
                                # reached with a new outcome.
                                elif sum(non_total) == 2 and (
                                    (
                                        view_type != "_total_"
                                        and channel != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 3 and (
                                    (
                                        view_type != "_total_"
                                        and error != "_total_"
                                        and channel != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 2 and (
                                    (
                                        provider_ != "_total_"
                                        and channel != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 3 and (
                                    (
                                        provider_ != "_total_"
                                        and error != "_total_"
                                        and channel != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 2 and (
                                    (
                                        provider_ != "_total_"
                                        and os_family != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 3 and (
                                    (
                                        provider_ != "_total_"
                                        and error != "_total_"
                                        and os_family != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 3 and (
                                    (
                                        os_family != "_total_"
                                        and error != "_total_"
                                        and channel != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 3 and (
                                    (
                                        os_family != "_total_"
                                        and error != "_total_"
                                        and view_type != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 3 and (
                                    (
                                        os_family != "_total_"
                                        and channel != "_total_"
                                        and view_type != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 4 and (
                                    (
                                        os_family != "_total_"
                                        and error != "_total_"
                                        and channel != "_total_"
                                        and view_type != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 3 and (
                                    (
                                        os_family != "_total_"
                                        and browser_name != "_total_"
                                        and view_type != "_total_"
                                    )
                                ):
                                    ok = True
                                elif sum(non_total) == 4 and (
                                    (
                                        os_family != "_total_"
                                        and error != "_total_"
                                        and browser_name != "_total_"
                                        and view_type != "_total_"
                                    )
                                ):
                                    ok = True
                                if not ok:
                                    continue
                                # One output row per accepted combination;
                                # the session itself counts as sessions=1.
                                vrs = dict(
                                    fielddate=rec["date"],
                                    browser=comb[0],
                                    os_family=comb[1],
                                    ref_from=comb[2],
                                    view_type=view_type,
                                    provider=comb[3],
                                    channel=channel,
                                    error=error,
                                    program=program,
                                    tvt=tvt,
                                    lvt=lvt,
                                    sessions=1,
                                    chunks_count=chunks_count,
                                    # bits_sent=bits_sent,
                                    yandexuid=rec["yandexuid"],
                                )
                                if not self.old_mode:
                                    vrs["bits_sent"] = bits_sent
                                    vrs[RETRANS] = retrans
                                yield Record(**vrs)


def extract_parsed_as(field):
    """Return the sole element of ``field``, or ``"UNKNOWN"``.

    ``"UNKNOWN"`` is returned when ``field`` is empty/None or holds more
    than one element.
    """
    values = field or []
    return values[0] if len(values) == 1 else "UNKNOWN"


# Regex fragments that normalize_name() deletes from a name. Each one is
# tried with NUMBER_PREFIX in front, with NUMBER_POSTFIX behind, and alone.
BAD_SHIT = [
    # Russian words for "part", "episode", "series", "issue" and "season",
    # each required to be followed by a character outside а-я.
    u"част[ьи](?=[^а-я])",
    u"эпизод(?=[^а-я])",
    u"сери[яи](?=[^а-я])",
    u"выпуск(?=[^а-я])",
    u"сезон(?=[^а-я])",
    # "<number> <Russian month stem + optional ending> <number>", optionally
    # followed by " года" ("of the year").
    u"[0-9]+ (январ|феврал|март|апрел|ма|июн|июл|август|сентябр|октябр|ноябр|декабр).? [0-9]+( года)?",
    # Numeric date: D.M.YYYY with one- or two-digit day and month.
    u"[0-9]{1,2}\.[0-9]{1,2}\.[0-9]{4}",
    # Time of day: two digits, colon, two digits.
    u"[0-9]{2}:[0-9]{2}",
]
# One number, or two numbers joined by an optional dash or "и" ("and"), each
# with an optional ordinal suffix ("-й", "-ый", "-я", "-ая", "-яя"), followed
# by a space.
NUMBER_PREFIX = u"([0-9]+(-(ы?й|[ая]?я))? [-–—и]? ?)?[0-9]+(-(ы?й|[ая]?я))? "
# A space followed by one number with the same optional ordinal suffix.
NUMBER_POSTFIX = u" [0-9]+(-(ы?й|[ая]?я))?"


def normalize_name(name):
    """Reduce a raw programme name to a short canonical form.

    Steps, in order: decode byte strings as UTF-8 (bad bytes replaced and
    then dropped), strip whitespace, keep only Latin/Cyrillic letters,
    digits, space and ``. , !``, remove the episode/date/time fragments from
    ``BAD_SHIT`` (with and without surrounding numbers), shorten with
    ``truncate_string``, trim to the first..last letter and cut at the first
    punctuation mark. Returns ``u"BAD_NAME"`` when nothing is left.

    NOTE(review): the character whitelist runs first and deletes ``:``,
    ``-``, dashes and quotes, so the quote-removal step and the parts of the
    later patterns that rely on those characters cannot match any more.
    """
    flags = re.I | re.U
    text = name or u""
    if not isinstance(text, unicode):
        text = text.decode("utf8", errors="replace")
    text = text.strip().replace(u"\ufffd", u"")
    text = re.sub(u"[^a-zа-яё .,0-9!]", u"", text, flags=flags)
    text = re.sub(u'["«»“”]', u"", text)
    for fragment in BAD_SHIT:
        # Most specific variant first: number before, number after, bare.
        for pattern in (
            NUMBER_PREFIX + fragment,
            fragment + NUMBER_POSTFIX,
            fragment,
        ):
            text = re.sub(pattern, u"", text, flags=flags)
    text = till_first_punctuation(intelligent_strip(truncate_string(text)))
    return text or u"BAD_NAME"


def intelligent_strip(name):
    """Trim ``name`` to the span between its first and last letter.

    Letters are Latin or Cyrillic, matched case-insensitively. Returns
    ``u"BAD_NAME"`` when ``name`` contains no letter at all.
    """
    letter_positions = [
        match.start()
        for match in re.finditer(u"[а-яa-zё]", name, flags=(re.I | re.U))
    ]
    if not letter_positions:
        return u"BAD_NAME"
    first, last = letter_positions[0], letter_positions[-1]
    return name[first : last + 1]


def truncate_string(st_, thr=30):
    """Shorten ``st_`` to at most ``thr`` characters, preferring word breaks.

    Args:
        st_: the text to shorten.
        thr: maximum length of the result.

    Returns:
        ``st_`` unchanged when it already fits. Otherwise the longest
        prefix of whole words (joined by single spaces, last word always
        dropped) that fits. When not even the first word fits, the first
        ``thr - 3`` characters followed by ``...``.
    """
    if len(st_) <= thr:
        return st_
    words = st_.split()
    # Stop at one word: the empty zero-word prefix always "fits" and would
    # turn an over-long first word into an empty result, hiding the
    # ellipsis fallback below.
    for count in range(len(words) - 1, 0, -1):
        candidate = u" ".join(words[:count])
        if len(candidate) <= thr:
            return candidate
    return u"{}...".format(st_[: max(thr - 3, 0)])


def till_first_punctuation(name):
    """Return the part of ``name`` before its first ``,``, ``.``, ``?`` or ``!``.

    ``name`` is returned unchanged when it contains none of them.
    """
    first_mark = re.search(u"[,\\.\\?!]", name)
    if first_mark is None:
        return name
    return name[: first_mark.start()]


def _sessions_table_ready(sessions_table, min_date):
    """Tell whether ``sessions_table`` is complete enough to be processed.

    The table is ready when it exists and the date found in its
    ``finished_time`` attribute is not earlier than ``min_date``. A missing,
    empty or unparsable attribute counts as "not ready".
    """
    client = cluster.driver.client
    if not client.exists(sessions_table):
        return False
    finished_time = client.get_attribute(sessions_table, "finished_time", "")
    if not finished_time:
        return False
    finished_date = get_date(finished_time.split("T")[0])
    return finished_date is not None and finished_date >= min_date


def make_cube_report(date):
    """Build the ``cube_report`` table for one date.

    Reads ``<job_root>/<date>/sessions``, expands every session with
    ``SessionsMapper``, sums the metrics per key combination, adds the share
    columns with ``AddShares`` and writes the result, sorted by
    ``fielddate``, to ``<job_root>/<date>/cube_report``. Afterwards the
    output table gets a ``finished_time`` attribute with the current local
    time.

    When ``--wait`` was given, the function first blocks (polling once an
    hour) until the sessions table exists and carries a ``finished_time``
    not earlier than that date.

    Args:
        date: the date to process, already formatted as the path component
            used under ``job_root``.
    """
    sessions_table = "{}/{}/sessions".format(job_root, date)
    programs_table = "{}/{}/programs".format(job_root, date)
    cube_report_table = "{}/{}/cube_report".format(job_root, date)

    if g_args.wait:
        while not _sessions_table_ready(sessions_table, g_args.wait):
            print("waiting for {}...".format(sessions_table))
            time.sleep(3600)

    # The programmes table is optional: fall back to an empty mapping.
    try:
        programs = read_programs_table(programs_table)
    except Exception as e:
        programs = {}
        print("error reading programs table: {}".format(e))

    key_fields = [
        "fielddate",
        "browser",
        "os_family",
        "ref_from",
        "view_type",
        "provider",
        "channel",
        "error",
        "program",
    ]

    aggregators = dict(
        tvt=na.sum("tvt"),
        lvt=na.sum("lvt"),
        sessions=na.sum("sessions"),
        chunks_count=na.sum("chunks_count"),
    )
    if not g_args.old_mode:
        aggregators["bits_sent"] = na.sum("bits_sent")
        aggregators[RETRANS] = na.sum(RETRANS)

    job = cluster.job()
    (
        job.table(sessions_table)
        .map(
            SessionsMapper(
                programs, format(date), set(), old_mode=g_args.old_mode
            ),
            intensity="ultra_cpu",
        )
        .groupby(*key_fields)
        .aggregate(**aggregators)
        # Regroup without "error" so AddShares sees every error bucket of a
        # slice together with that slice's "_total_" row.
        .groupby(*[x for x in key_fields if x != "error"])
        .reduce(AddShares(old_mode=g_args.old_mode))
        .sort("fielddate")
        .put(cube_report_table)
    )

    job.run()

    now = datetime.datetime.now()
    cluster.driver.client.set_attribute(
        cube_report_table, "finished_time", now.strftime("%Y-%m-%dT%H:%M:%S")
    )


def get_date(s):
    """Extract the first ``YYYY-MM-DD`` date found in ``s``.

    Returns a ``datetime.date``, or ``None`` when ``s`` is not a string,
    contains no such pattern, or the matched digits are not a valid date.
    """
    try:
        found = re.search(r"[0-9]{4}-[0-9]{2}-[0-9]{2}", s)
        parsed = datetime.datetime.strptime(found.group(0), "%Y-%m-%d")
    except (ValueError, TypeError, AttributeError):
        return None
    return parsed.date()


def main():
    """Parse the command line and build the cube report for each date.

    With both ``--from`` and ``--to`` the dates come from ``date_range``.
    Otherwise the job root is scanned: every date that has a ``sessions``
    table (after 2018-03-14) and is newer than the latest date with a
    non-empty ``cube_report`` table is processed.

    Sets the module globals ``cluster``, ``g_args`` and ``job_root``.
    Requires the ``YT_TOKEN`` environment variable.
    """
    global cluster
    global g_args
    global job_root
    parser = argparse.ArgumentParser()
    parser.add_argument("--pool")
    parser.add_argument("--parallel_operations_limit", type=int, default=10)
    parser.add_argument("--from", default=None)
    parser.add_argument("--to", default=None)
    parser.add_argument("--wait", default=None)
    parser.add_argument("--job_root", default="//home/videolog/strm_video")
    parser.add_argument("--old_mode", action="store_true")
    args = parser.parse_args()

    # "from" is a keyword, so these two cannot be read as plain attributes.
    from_ = getattr(args, "from")
    to_ = getattr(args, "to")
    job_root = args.job_root
    if args.wait:
        args.wait = get_date(args.wait)
    g_args = args

    cluster = clusters.yt.Hahn(
        token=os.environ["YT_TOKEN"], pool=args.pool
    ).env(
        templates=dict(job_root=job_root),
        parallel_operations_limit=args.parallel_operations_limit,
    )

    if from_ and to_:
        dates_to_process = date_range(from_, to_)
    else:
        client = cluster.driver.client

        # Dates that already have a non-empty cube_report table.
        processed_dates = sorted(
            get_date(s)
            for s in client.search(
                root=job_root,
                node_type="table",
                path_filter=(lambda x: x.endswith("/cube_report")),
            )
            if get_date(s) and client.get_attribute(s, "row_count", 0) > 0
        )

        if processed_dates:
            print("last date: {}".format(processed_dates[-1]))
        else:
            print("no last date")

        # Dates that have a sessions table to build the report from.
        available_dates = sorted(
            get_date(s)
            for s in client.search(
                root=job_root,
                node_type="table",
                path_filter=(lambda x: x.endswith("/sessions")),
            )
            if get_date(s) and get_date(s) > datetime.date(2018, 3, 14)
        )

        # Guard the empty case instead of failing on available_dates[-1].
        if available_dates:
            print("last available date: {}".format(available_dates[-1]))
        else:
            print("no available dates")

        if processed_dates:
            dates_to_process = [
                x for x in available_dates if x > processed_dates[-1]
            ]
        else:
            dates_to_process = available_dates

    print("dates to process: {}".format(dates_to_process))

    for date in dates_to_process:
        print("running for {}".format(date))
        make_cube_report(date=format(date))

# Run the report builder only when the file is executed as a script.
if __name__ == "__main__":
    main()
