from collections import defaultdict
import math
import datetime
import json
from yt import yson
from yt.yson.yson_types import YsonEntity


def get_resolution_cat(h):
    """Bucket a video height into a coarse resolution category.

    Boundaries: h <= 360 -> "ld"; h < 720 -> "sd"; h < 1080 -> "hd";
    h < 2160 -> "fhd"; h >= 2160 -> "4k". A value that satisfies none of
    these comparisons (e.g. NaN) yields "other".
    """
    if h <= 360:
        return "ld"
    for upper_bound, category in ((720, "sd"), (1080, "hd"), (2160, "fhd")):
        if h < upper_bound:
            return category
    return "4k" if h >= 2160 else "other"


def get_resolution(k, v):
    """Store k under the "resolution" key of dict v (in place) and return v."""
    v["resolution"] = k
    return v


def is_(x):
    """Return True when x holds a value: it is neither None nor a YsonEntity."""
    if x is None:
        return False
    return not isinstance(x, YsonEntity)


def round_ts(ts):
    """Convert ts to whole seconds and floor it to the start of its hour.

    ts is divided by 1000, so it is presumably in milliseconds. The result is
    an int number of seconds that is a multiple of 3600.
    """
    whole_seconds = int(ts / 1000.0)  # int() truncates toward zero
    return (whole_seconds // 3600) * 3600


def wrap_time(time_):
    """Sanitise a duration value.

    Falsy values (None, 0, ...), negatives and anything above 86400
    (presumably one day in seconds) become 0.0; other values are returned
    unchanged.
    """
    if not time_:
        return 0.0
    out_of_range = time_ < 0 or time_ > 86400
    return 0.0 if out_of_range else time_


class StructWrapper:
    """Attribute overlay: names are looked up in ``dct`` first, then on ``obj``.

    Used to emit a record that exposes the original object's attributes
    with some fields added or overridden by ``dct``.
    """

    def __init__(self, obj, dct):
        self.obj = obj  # fallback object for attributes not in dct
        self.dct = dct  # overriding / extra attribute values

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If "obj"/"dct"
        # themselves are missing (the instance was created without __init__,
        # e.g. by copy or pickle), raise immediately instead of re-entering
        # __getattr__ through self.dct and recursing forever.
        if name in ("obj", "dct"):
            raise AttributeError(name)
        try:
            return self.dct[name]
        except KeyError:
            pass
        # Outside the except block so an AttributeError from obj is not
        # reported as "during handling of KeyError".
        return getattr(self.obj, name)


def generate_resolutions_values(dct):
    """Flatten per-resolution stats into flat ``p_<res>_<metric>`` columns.

    For each resolution in ("ld", "sd", "hd", "fhd", "4k") this emits
    watched_time, stalled_time and stalled_count columns, taken from
    dct[res]["watchedTime"/"stalledTime"/"stalledCount"]. A missing
    resolution or metric gives None.
    """
    metric_columns = (
        ("watched_time", "watchedTime"),
        ("stalled_time", "stalledTime"),
        ("stalled_count", "stalledCount"),
    )
    return {
        f"p_{res}_{column}": dct.get(res, {}).get(metric)
        for res in ("ld", "sd", "hd", "fhd", "4k")
        for column, metric in metric_columns
    }


def get_stalled_init_length(stalled_infos):
    """Return the length of the initial ("Init") stall, or 0.

    Primary rule: take the id of the first record whose reason is b"Init"
    and whose id is truthy. Then return the duration (0 if falsy) of the
    first later b"StalledEnd" record carrying that id.

    Fallback when no such pair exists: the largest truthy duration among
    records with reason b"Init", or 0 if there is none.
    """
    if not stalled_infos:
        return 0
    init_id = None
    for info in stalled_infos:
        if init_id is None:
            if info.reason == b"Init" and info.id:
                init_id = info.id
        elif info.ev == b"StalledEnd" and info.id == init_id:
            return info.duration or 0
    init_durations = [
        info.duration
        for info in stalled_infos
        if info.reason == b"Init" and info.duration
    ]
    return max(init_durations) if init_durations else 0


def get_start_time(event_infos, stalled_infos):
    """Estimate player start-up time for a session.

    The start-up part is the first b"Start" event ts minus the first
    b"CreatePlayer" event ts, divided by 1000 (presumably ms -> s) and
    clamped at 0. It is 0 when either event is absent. The initial stall
    length from get_stalled_init_length(stalled_infos) is then added.
    """
    init_stall = get_stalled_init_length(stalled_infos)
    events = event_infos or []
    create_stamps = [e.ts for e in events if e.ev == b"CreatePlayer"]
    start_stamps = [e.ts for e in events if e.ev == b"Start"]
    startup = 0
    if create_stamps and start_stamps:
        startup = max((start_stamps[0] - create_stamps[0]) / 1000.0, 0)
    return startup + init_stall


def s(bytes_):
    """Return bytes_ as text.

    bytes are decoded as UTF-8, with undecodable sequences replaced by
    U+FFFD. Any other falsy value becomes "". Anything else is returned
    unchanged.
    """
    if isinstance(bytes_, bytes):
        return bytes_.decode("utf8", errors="replace")
    return bytes_ if bytes_ else ""


def get_total_values(dctx, val):
    """Sum metric ``val`` over every resolution bucket of a category mapping.

    dctx maps a category key to a dict whose "resolutions" entry maps a
    resolution to per-metric stats; missing or falsy metric values count as 0.

    process_list adds every record both to its mute category and to the
    "_total_" aggregate. When "_total_" is present it therefore already
    holds the grand total, and summing every key would count each record
    twice. In that case only "_total_" is summed; otherwise all keys are.
    """
    if "_total_" in dctx:
        dctx = {"_total_": dctx["_total_"]}
    return sum(
        stats.get(val) or 0
        for entry in dctx.values()
        for stats in entry["resolutions"].values()
    )


def mapper(rec):
    """Yield per-category quality rows for a single session record.

    rec.unprocessed_states is aggregated with process_list(mode="new"). If
    that produces nothing, no rows are yielded. Otherwise the following are
    computed once for the session:
      * tvt / tst / tsc: get_total_values over the aggregate for
        "watchedTime" / "stalledTime" / "stalledCount";
      * start_time: get_start_time (presumably seconds);
      * len_category: "long" if tvt >= 600.0 else "short";
      * a traffic-light colour with its reason.

    Colour rules:
      * rec.fatal -> "red" / "red_has_fatal".
      * Otherwise the (green, yellow) start-time limits are (10, 20) when
        s(rec.platform) contains "Windows" or "tv", and (30, 50) elsewhere.
        Green needs start_time <= green limit, tsc <= 2 and tst <= 5.
        Yellow needs start_time <= yellow limit, tsc <= 5 and tst <= 10.
        Anything else is red.

    One StructWrapper(rec, row) is yielded per key of the aggregate
    (including "_total_"). Fields not set in the row fall back to rec.
    """
    pl_output = process_list(rec.unprocessed_states, mode="new")
    if not pl_output:
        return
    tvt = get_total_values(pl_output, "watchedTime")
    tst = get_total_values(pl_output, "stalledTime")
    tsc = get_total_values(pl_output, "stalledCount")
    start_time = get_start_time(rec.event_infos, rec.stalled_infos)
    len_category = "long" if tvt >= 600.0 else "short"

    if rec.fatal:
        color, color_reason = "red", "red_has_fatal"
    else:
        platform = s(rec.platform)
        # Stricter start-up limits for platforms matching "Windows" / "tv".
        if "Windows" in platform or "tv" in platform:
            green_limit, yellow_limit = 10, 20
        else:
            green_limit, yellow_limit = 30, 50
        if start_time <= green_limit and tsc <= 2 and tst <= 5:
            color, color_reason = "green", "green_thresholds"
        elif start_time <= yellow_limit and tsc <= 5 and tst <= 10:
            color, color_reason = "yellow", "yellow_thresholds"
        else:
            color, color_reason = "red", "red_thresholds"

    for key, value in pl_output.items():
        result = generate_resolutions_values(value["resolutions"])
        result["color"] = color
        result["color_reason"] = color_reason
        result["start_time"] = start_time
        result["tvt"] = tvt
        result["tsc"] = tsc
        result["tst"] = tst
        result["len_category"] = len_category
        result["timestamp"] = value["timestamp"]
        result["avglogs"] = value["avglogs"]
        result["muted_cat"] = key
        yield StructWrapper(rec, result)


def process_list(list_, mode="old"):
    """Aggregate a sequence of player state records into per-category stats.

    Records are processed in order. For each usable record, watched time,
    stalled time and stalled count are taken as the difference from the
    previous record (raw values for the first one). This suggests the
    counters are cumulative — TODO confirm against the producer. Times go
    through wrap_time() first.

    Non-negative contributions are added twice: under the record's mute
    category ("muted" / "non_muted" / "unknown") and under "_total_". They
    are bucketed by get_resolution_cat(rec.height). For records with
    capHeight >= height, log(capHeight + 1) - log(height + 1) is appended
    to the "avglogs" list of the same two categories.

    Skipped records: falsy ones, ones with a falsy timestamp, and ones whose
    watchedTime / stalledCount / stalledTime is None. Ad records
    (rec.is_ad) are not aggregated but become the baseline for the next
    record's deltas.

    Args:
        list_: iterable of state records.
        mode: "old" -> return only the "_total_" entry, with its
            "resolutions" serialised by yson.dumps as a list of
            per-resolution dicts, each tagged with a "resolution" key.
            Any other value -> return a dict keyed by category, including
            "_total_".

    Returns:
        None when no record became a baseline, or (in "old" mode) when there
        is no "_total_" category. Otherwise each entry is a dict with
        "avglogs" (first 2000 values), "timestamp" (round_ts of the last
        processed record's timestamp) and "resolutions".
    """
    dd = defaultdict(
        lambda: defaultdict(
            lambda: {"watchedTime": 0, "stalledCount": 0, "stalledTime": 0}
        )
    )
    avglogs = defaultdict(list)
    prev = None
    for rec in list_:
        if not rec or not rec.timestamp:
            continue
        if rec.is_ad:
            # Not aggregated, but still the baseline for the next delta.
            prev = rec
            continue
        if (
            rec.watchedTime is None
            or rec.stalledCount is None
            or rec.stalledTime is None
        ):
            continue
        if prev is None:
            wt = wrap_time(rec.watchedTime)
            sc = rec.stalledCount
            st = wrap_time(rec.stalledTime)
        else:
            # prev may be an ad record that never passed the None check
            # above. Its times are None-safe through wrap_time(), so its
            # missing stalledCount is treated as 0 the same way instead of
            # raising TypeError.
            wt = wrap_time(rec.watchedTime) - wrap_time(prev.watchedTime)
            sc = rec.stalledCount - (prev.stalledCount or 0)
            st = wrap_time(rec.stalledTime) - wrap_time(prev.stalledTime)
        resolution = rec.height
        if rec.isMuted == True:
            muted_cat = "muted"
        elif rec.isMuted == False:
            muted_cat = "non_muted"
        else:
            muted_cat = "unknown"
        if is_(resolution):
            res_cat = get_resolution_cat(resolution)
            for mc_ in (muted_cat, "_total_"):
                # Negative deltas (counter went backwards) are ignored.
                if wt >= 0:
                    dd[mc_][res_cat]["watchedTime"] += wt
                if st >= 0:
                    dd[mc_][res_cat]["stalledTime"] += st
                if sc >= 0:
                    dd[mc_][res_cat]["stalledCount"] += sc
        if (
            is_(rec.capHeight)
            and is_(rec.height)
            and (rec.capHeight - rec.height) >= 0
        ):
            # Same value for both categories, so compute it once. Given the
            # guard above, max(...) is always capHeight.
            try:
                avglog = math.log(max(rec.height, rec.capHeight) + 1) - math.log(
                    rec.height + 1
                )
            except ValueError:
                # log of a non-positive argument (height <= -1): skip sample.
                pass
            else:
                for mc_ in (muted_cat, "_total_"):
                    avglogs[mc_].append(avglog)
        prev = rec
    if not prev or not prev.timestamp:
        return
    result = {}
    keys = set(dd.keys()) | set(avglogs.keys())
    if mode == "old" and "_total_" not in keys:
        return
    for key in keys:
        dd_ = dd[key]
        avglogs_ = avglogs[key]
        dct = {
            "avglogs": avglogs_[:2000],
            "timestamp": round_ts(prev.timestamp),
            "resolutions": dd_,
        }
        if mode == "old" and key == "_total_":
            dct["resolutions"] = yson.dumps(
                [get_resolution(k, v) for k, v in dd_.items()]
            )
            return dct
        else:
            result[key] = dct
    return result
