from collections import defaultdict, Counter
import math
import json
import datetime
from yt import yson
from yt.yson.yson_types import YsonEntity

# Quality labels paired with the numeric size each label stands for,
# ordered from the largest size to the smallest.
QUALITIES = [
    ["highres", 4320],
    ["hd2880", 2880],
    ["hd2160", 2160],
    ["hd1440", 1440],
    ["hd1080", 1080],
    ["hd720", 720],
    ["large576", 576],
    ["large", 480],
    ["medium", 360],
    ["small", 240],
    ["tiny", 144],
]

# Reverse lookup: numeric size -> quality label (same order as QUALITIES).
QS_REVERSE = {size: label for label, size in QUALITIES}


def get_resolution_cat(minSide):
    """Return the quality label whose size is nearest to ``minSide``.

    When two sizes are equally close, the one that comes first in
    ``QS_REVERSE`` is chosen.
    """
    nearest = min(QS_REVERSE, key=lambda size: abs(size - minSide))
    return QS_REVERSE[nearest]


def get_resolution(k, v):
    """Store the resolution label ``k`` in the mapping ``v`` and return ``v``.

    The mapping is modified in place; the same object is returned so the
    call can be used inside an expression.

    Plain item assignment is used rather than ``v.update(...)`` because
    ``v`` may be a ``collections.Counter``: ``Counter.update`` *adds* the
    given values to the existing counts, so passing a string label to a
    non-empty counter raised ``TypeError`` instead of storing the label.
    """
    v["resolution"] = k
    return v


def is_(x):
    """Tell whether ``x`` carries an actual value.

    ``None`` and instances of ``YsonEntity`` are both treated as absent.
    """
    if x is None:
        return False
    return not isinstance(x, YsonEntity)


def round_ts(ts):
    """Divide ``ts`` by 1000, truncate to an int and floor to a multiple of 3600.

    Presumably ``ts`` is a millisecond timestamp, which makes the result
    the start of the hour in seconds — confirm against the callers.
    """
    seconds = int(ts / 1000.0)
    return seconds // 3600 * 3600


def wrap_time(time_):
    """Return ``time_`` unchanged when it is usable, otherwise ``0.0``.

    A value is replaced by ``0.0`` when it is falsy (``None``, ``0``),
    negative, or greater than 86400.
    """
    unusable = not time_ or time_ < 0 or time_ > 86400
    return 0.0 if unusable else time_


class StructWrapper:
    """Attribute view that layers a dict of overrides on top of an object.

    Attribute access first looks the name up in ``dct``; when the key is
    absent the attribute is read from ``obj``, and ``None`` is returned if
    ``obj`` does not have it either.
    """

    def __init__(self, obj, dct):
        # obj: the wrapped record used as a fallback for attribute reads.
        # dct: mapping whose keys take precedence over attributes of obj.
        self.obj = obj
        self.dct = dct

    def __getattr__(self, name):
        """Resolve ``name`` from ``dct`` first, then from ``obj``.

        Raises:
            AttributeError: for ``obj``/``dct`` themselves when they were
                never set, and for dunder names that are not keys of
                ``dct``.
        """
        # __getattr__ only runs after normal lookup failed, so reaching it
        # for "obj"/"dct" means __init__ did not run (copy and pickle create
        # instances that way). Reading self.dct below would then re-enter
        # this method forever.
        if name in ("obj", "dct"):
            raise AttributeError(name)
        try:
            return self.dct[name]
        except KeyError:
            # Protocol probes such as hasattr(x, "__setstate__") must see a
            # missing attribute rather than None, otherwise the caller
            # tries to use None as the hook.
            if name.startswith("__") and name.endswith("__"):
                raise AttributeError(name) from None
            return getattr(self.obj, name, None)


def generate_resolutions_values(dct):
    """Flatten per-resolution stats into ``p_<res>_*`` keys.

    For each of ``ld``, ``sd``, ``hd`` and ``fhd`` the nested mapping
    ``dct[res]`` is read (an empty mapping when the key is absent) and its
    ``watchedTime``, ``stalledTime`` and ``stalledCount`` values are copied
    into the result; values that are not present become ``None``.
    """
    flat = {}
    for res in ("ld", "sd", "hd", "fhd"):
        stats = dct.get(res, {})
        flat.update(
            {
                f"p_{res}_watched_time": stats.get("watchedTime"),
                f"p_{res}_stalled_time": stats.get("stalledTime"),
                f"p_{res}_stalled_count": stats.get("stalledCount"),
            }
        )
    return flat


# Canonical spellings of the stall reasons reported per record, plus the
# "_total_" aggregate bucket.
STALLED_TYPES = [
    "_total_",
    "Init",
    "SetSource",
    "AdEnd",
    "Seek",
    "VideoTrackChange",
    "Recover",
    "MediaError",
    "Offline",
    "Other",
    "RepresentationsChange",
    "LiveEdge",
]

# Lower-cased spelling -> canonical spelling, for case-insensitive matching.
TO_LOWER = {canonical.lower(): canonical for canonical in STALLED_TYPES}


def process_reason(reason):
    """Map a raw stall reason onto its canonical spelling when it is known.

    The reason is lower-cased and stripped of underscores before being
    looked up in ``TO_LOWER``. Unknown reasons are returned unchanged and a
    falsy reason yields ``None``.
    """
    if not reason:
        return None
    normalized = reason.lower().replace("_", "")
    return TO_LOWER.get(normalized, reason)


def convert_stalled_info(stalled_info_dict):
    """Turn a reason -> stats mapping into a list with one row per stall type.

    A row is produced for every entry of ``STALLED_TYPES``, in that order.
    Missing or falsy counters are reported as ``0``.
    """
    rows = []
    for reason in STALLED_TYPES:
        # Indexing (not .get) is intentional: every reason is looked up
        # directly, so an auto-creating mapping fills in missing entries.
        stats = stalled_info_dict[reason]
        rows.append(
            {
                "stalledReason": reason,
                "stalledCount": stats.get("stalledCount") or 0,
                "stalledTime": stats.get("stalledTime") or 0,
            }
        )
    return rows


def convert_dict(d):
    """Recursively rebuild ``d`` out of plain ``dict`` and ``list`` objects.

    Every ``dict`` (including subclasses such as ``defaultdict``) becomes a
    plain ``dict`` and every ``list`` a new ``list``; anything else is
    returned as is.
    """
    if isinstance(d, list):
        return [convert_dict(item) for item in d]
    if not isinstance(d, dict):
        return d
    plain = {}
    for key, value in d.items():
        plain[convert_dict(key)] = convert_dict(value)
    return plain


def mapper_old(rec):
    """Yield one wrapped row per (muted category, resolution) of ``rec``.

    ``rec.unprocessed_states`` is aggregated with ``process_list`` in
    "new" mode; nothing is yielded when that produces no output. Each
    yielded ``StructWrapper`` exposes the row's fields first and falls back
    to the attributes of ``rec``.
    """
    pl_output = process_list(rec.unprocessed_states, mode="new")
    if not pl_output:
        return
    for muted_cat, value in pl_output.items():
        stalled_by_res = value["stalledInfo"]
        for res, res_stats in value["resolutions"].items():
            row = {
                "timestamp": value["timestamp"],
                # The avglog list is only attached to the aggregate row.
                "avglogs": value["avglogs"] if res == "_total_" else None,
                "muted_cat": muted_cat,
                "resolution": res,
            }
            row.update(res_stats)
            row["stalledInfo"] = convert_stalled_info(stalled_by_res[res])
            # Serialized after the call above on purpose: the raw dump
            # reflects the mapping as it is at this point.
            row["stalledInfoRaw"] = json.dumps(convert_dict(stalled_by_res))
            yield StructWrapper(rec, row)


def mapper(unprocessed_states):
    """Aggregate states with ``process_list_v2`` and flatten the result.

    Returns ``None`` when the aggregation yields nothing (or not exactly
    three parts); otherwise a tuple of

    * a list of resolution rows, each extended with ``muted_cat`` and
      ``res_cat`` taken from its key,
    * a list of stall rows, each extended with ``stalled_reason``,
    * the first-avglog info, passed through unchanged.

    The row dicts produced by ``process_list_v2`` are extended in place.
    """
    processed = process_list_v2(unprocessed_states, mode="new")
    if not processed or len(processed) != 3:
        return None
    resolution_info, stalled_info, first_avglog_info = processed

    resolution_rows = []
    for key, row in resolution_info.items():
        row["muted_cat"] = key[0]
        row["res_cat"] = key[1]
        resolution_rows.append(row)

    stalled_rows = []
    for reason, row in stalled_info.items():
        row["stalled_reason"] = reason
        stalled_rows.append(row)

    return (resolution_rows, stalled_rows, first_avglog_info)


def try_get_resolution(rec, cap=False):
    """Return the smaller side of the record's frame, if it is known.

    With ``cap=True`` the ``capHeight``/``capWidth`` attributes are used
    instead of ``height``/``width``. When only one of the two sides holds a
    value that side is returned; when neither does, the result is ``None``.
    """
    if cap:
        height_attr, width_attr = "capHeight", "capWidth"
    else:
        height_attr, width_attr = "height", "width"
    height = getattr(rec, height_attr)
    width = getattr(rec, width_attr)
    has_height = is_(height)
    has_width = is_(width)
    if has_height and has_width:
        return min(height, width)
    if has_height:
        return height
    if has_width:
        return width
    return None


def process_list(list_, mode="old"):
    """Aggregate watched time and stall statistics over a list of records.

    Each usable record contributes the difference between its
    ``watchedTime`` / ``stalledCount`` / ``stalledTime`` and those of the
    previously kept record (the first one contributes its own values).
    Negative differences are not accumulated. Values are bucketed by muted
    category ("muted", "non_muted", "unknown") and by resolution category,
    each with an extra "_total_" bucket; stall values are additionally
    bucketed by stall reason.

    Args:
        list_: iterable of records; falsy records and records without a
            timestamp are ignored.
        mode: "old" returns only the "_total_" muted bucket with its
            resolutions serialized via ``yson.dumps``; any other value
            returns a dict keyed by muted category.

    Returns:
        ``None`` when no record was kept (or, in "old" mode, when there is
        no "_total_" bucket). Otherwise a dict with the keys ``avglogs``
        (at most the first 2000 values), ``timestamp``, ``resolutions`` and
        ``stalledInfo`` in "old" mode, or a dict of such dicts otherwise.
    """
    # muted category -> resolution category -> Counter holding "watchedTime".
    dd = defaultdict(lambda: defaultdict(Counter))
    # muted category -> resolution category -> stall reason -> Counter
    # holding "stalledTime" and "stalledCount".
    stalled_info = defaultdict(lambda: defaultdict(lambda: defaultdict(Counter)))
    # muted category -> list of avglog values.
    avglogs = defaultdict(list)
    prev = None
    for rec in list_:
        if not rec or not rec.timestamp:
            continue
        # Ad records are not aggregated, but they do become the baseline
        # for the next record.
        # NOTE(review): such a baseline is never None-checked, so a None
        # stalledCount on it would break the subtraction below — confirm
        # ad records always carry the counter.
        if rec.is_ad:
            prev = rec
            continue
        # Incomplete records are dropped without replacing the baseline.
        if (
            rec.watchedTime is None
            or rec.stalledCount is None
            or rec.stalledTime is None
        ):
            continue
        if prev is None:
            wt = wrap_time(rec.watchedTime)
            sc = rec.stalledCount
            st = wrap_time(rec.stalledTime)
        else:
            # Deltas against the previously kept record.
            wt = wrap_time(rec.watchedTime) - wrap_time(prev.watchedTime)
            sc = rec.stalledCount - prev.stalledCount
            st = wrap_time(rec.stalledTime) - wrap_time(prev.stalledTime)
        resolution = try_get_resolution(rec)
        if rec.isMuted == True:
            muted_cat = "muted"
        elif rec.isMuted == False:
            muted_cat = "non_muted"
        else:
            muted_cat = "unknown"
        # Records without a known resolution contribute nothing to the
        # watched-time and stall buckets.
        if resolution:
            res_cat_ = get_resolution_cat(resolution)
            for res_cat in (res_cat_, "_total_"):
                for mc_ in (muted_cat, "_total_"):
                    if wt >= 0:
                        dd[mc_][res_cat]["watchedTime"] += wt
                    # The stall is attributed to the reason stored on the
                    # previous record; the first record counts as "Init".
                    if not prev:
                        reason = "Init"
                    else:
                        reason = (
                            process_reason(
                                (prev.stalledReason or b"Unknown").decode(
                                    "utf8", errors="replace"
                                )
                            )
                            or "Unknown"
                        )
                    if st >= 0:
                        for reason_ in [reason, "_total_"]:
                            stalled_info[mc_][res_cat][reason_]["stalledTime"] += st
                    if sc >= 0:
                        for reason_ in [reason, "_total_"]:
                            stalled_info[mc_][res_cat][reason_]["stalledCount"] += sc
        cap_resolution = try_get_resolution(rec, cap=True)
        if resolution and cap_resolution:
            # 0 when the resolution does not exceed the cap resolution,
            # otherwise the log-difference of the two (each offset by 1).
            if resolution <= cap_resolution:
                avglog = 0
            else:
                avglog = math.log(resolution + 1) - math.log(cap_resolution + 1)
            for mc_ in (muted_cat, "_total_"):
                avglogs[mc_].append(avglog)
        prev = rec
    if not prev or not prev.timestamp:
        return
    result = {}
    keys = set(dd.keys()) | set(avglogs.keys())
    if mode == "old" and "_total_" not in keys:
        return
    for key in keys:
        dd_ = dd[key]
        avglogs_ = avglogs[key]
        dct = {
            "avglogs": avglogs_[:2000],
            # Timestamp of the last kept record, passed through round_ts.
            "timestamp": round_ts(prev.timestamp),
            "resolutions": dd_,
            "stalledInfo": stalled_info[key],
        }
        if mode == "old" and key == "_total_":
            # "old" mode returns as soon as the aggregate bucket is reached.
            dct["resolutions"] = yson.dumps(
                [get_resolution(k, v) for k, v in dd_.items()]
            )
            return dct
        else:
            result[key] = dct
    return result


# Report for muted and resolution

#     muted x resolution x bandwidth_cat
#     count, tvt, avglog*

# Report for stalled

#     ref_from x platform x bandwidth_cat x stalled_reason
#     count, tvt, stalled_time_*, stalled_count_*


def process_list_v2(list_, mode="old"):
    """Aggregate watch and stall deltas over a list of records.

    Each usable record contributes the difference between its
    ``watchedTime`` / ``stalledCount`` / ``stalledTime`` and those of the
    previously kept record (the first one contributes its own values).
    Only strictly positive differences are accumulated.

    Args:
        list_: iterable of records; falsy records and records without a
            timestamp are ignored.
        mode: accepted for signature compatibility; not used here.

    Returns:
        ``None`` when no record was kept, otherwise a tuple of

        * a mapping ``(muted_cat, res_cat) -> {"tvt", "avglogs"}``,
        * a mapping ``stall reason -> {"stalled_count", "stalled_time"}``,
        * ``{"res_cat", "avglog"}`` captured on the first record that has
          a non-zero watched delta while an avglog value is available, or
          ``None`` if there was no such record.
    """
    resolution_info = defaultdict(lambda: {"tvt": 0, "avglogs": []})
    stalled_info = defaultdict(lambda: {"stalled_count": 0, "stalled_time": 0})
    first_avglog_info = None
    # Kept across iterations: a record without resolutions sees the value
    # computed for an earlier record.
    avglog = None
    prev = None
    for rec in list_:
        if not rec or not rec.timestamp:
            continue
        # Ad and invisible records are not aggregated, but they do become
        # the baseline for the next record.
        if rec.is_ad or not rec.isVisible:
            prev = rec
            continue
        # Incomplete records are dropped without replacing the baseline.
        if (
            rec.watchedTime is None
            or rec.stalledCount is None
            or rec.stalledTime is None
        ):
            continue

        watched_delta = wrap_time(rec.watchedTime)
        count_delta = rec.stalledCount
        stalled_delta = wrap_time(rec.stalledTime)
        if prev is not None:
            watched_delta = watched_delta - wrap_time(prev.watchedTime)
            count_delta = count_delta - prev.stalledCount
            stalled_delta = stalled_delta - wrap_time(prev.stalledTime)

        resolution = try_get_resolution(rec)

        # Equality (not identity) comparisons are kept so that any value
        # equal to True/False is classified the same way as before.
        if rec.isMuted == True:
            muted_cat = "muted"
        else:
            muted_cat = "non_muted" if rec.isMuted == False else "unknown"

        res_cat = get_resolution_cat(resolution) if resolution else "unknown"

        # The stall is attributed to the reason stored on the previous
        # record; the first record counts as "Init".
        if not prev:
            reason = "Init"
        else:
            raw_reason = prev.stalledReason or b"Unknown"
            decoded_reason = raw_reason.decode("utf8", errors="replace")
            reason = process_reason(decoded_reason) or "Unknown"

        bucket_key = (muted_cat, res_cat)
        cap_resolution = try_get_resolution(rec, cap=True)
        if resolution and cap_resolution:
            avglog = (
                0
                if resolution <= cap_resolution
                else math.log(resolution + 1) - math.log(cap_resolution + 1)
            )
            resolution_info[bucket_key]["avglogs"].append(avglog)

        if watched_delta > 0:
            resolution_info[bucket_key]["tvt"] += watched_delta
        if stalled_delta > 0:
            stalled_info[reason]["stalled_time"] += stalled_delta
        if count_delta > 0:
            stalled_info[reason]["stalled_count"] += count_delta

        if watched_delta and avglog is not None and not first_avglog_info:
            first_avglog_info = {"res_cat": res_cat, "avglog": avglog}
        prev = rec

    if prev and prev.timestamp:
        return (resolution_info, stalled_info, first_avglog_info)
    return None
