from collections import defaultdict
from yt.yson.yson_types import YsonEntity
from yt import yson


def wrap_yson_obj(obj):
    """Recursively turn a decoded YSON value into plain Python objects.

    Dicts (both keys and values) and lists are rebuilt element by element,
    a ``YsonEntity`` becomes ``None``, and ``bytes`` are decoded as UTF-8
    with undecodable sequences replaced. Anything else is returned as is.
    """
    if isinstance(obj, dict):
        converted = {}
        for raw_key, raw_value in obj.items():
            converted[wrap_yson_obj(raw_key)] = wrap_yson_obj(raw_value)
        return converted
    if isinstance(obj, list):
        return list(map(wrap_yson_obj, obj))
    if isinstance(obj, YsonEntity):
        # YSON entity (null-like value) maps to Python None.
        return None
    if isinstance(obj, bytes):
        return obj.decode("utf8", errors="replace")
    return obj


class AdLoadCounter:
    """Accumulate ad-load statistics for the records of one grouping key.

    Call the instance once per record (``counter(rec)``), then call
    ``calc_metrics()`` to obtain the metric rows. Statistics are kept per
    "suffix" in ``by_suffix``; currently only the ``"dummy"`` suffix is
    filled in and reported.
    """

    # awaps events carry a numeric ``actionid``; this maps the ids of
    # interest to the action names that chtracking events use directly.
    dict_awaps = {
        0: "show",
        1: "click",
        52: "start",
        53: "midpoint",
        54: "first-quartile",
        55: "third-quartile",
        56: "complete",
        57: "mute",
        58: "unmute",
        62: "close",
        63: "skip",
    }

    def __init__(self, key):
        # Grouping key; calc_metrics copies elements 1-4 of it into each row.
        self.key = key
        # suffix -> running counters for that suffix.
        self.by_suffix = defaultdict(
            lambda: dict(
                adv_start=0.0,
                adv_skip=0.0,
                total_adv_time=0.0,
                leaves=0,
                view_time_non_muted=0.0,
                is_video=False,
            )
        )

    @staticmethod
    def count_tat_increment(roll):
        """Return the ad time credited to one roll, capped at 25.

        ``roll`` maps ``"start"``, ``"skip"`` and ``"complete"`` to the
        earliest timestamp seen for that action, with 0 meaning "not seen".
        A roll that started but was neither skipped nor completed gets the
        full cap; otherwise the credit is the time from start to the later
        of skip/complete, clamped to ``[0, 25]``.
        """
        start = roll["start"]
        skip = roll["skip"]
        complete = roll["complete"]
        if start and not skip and not complete:
            return 25
        # NOTE(review): with no start (0) but a skip/complete timestamp the
        # difference is measured from 0 and therefore always hits the cap.
        diff = max(max(skip, complete) - start, 0)
        return min(diff, 25)

    @staticmethod
    def _roll_id(event):
        """Return the key of the ad roll that ``event`` belongs to."""
        if event.get("source") == "awaps":
            return str(event.get("global_request_id"))
        return str(event.get("campaignid")) + str(event.get("bidid"))

    def _action_name(self, event):
        """Return the normalized action name of ``event``, or None to skip it."""
        source = event.get("source")
        if source == "awaps" and event.get("actionid") in self.dict_awaps:
            return self.dict_awaps[event.get("actionid")]
        if source == "chtracking":
            return event.get("action")
        return None

    @staticmethod
    def _record_earliest(cur_roll, action, timestamp):
        """Keep the earliest ``timestamp`` of ``action`` in ``cur_roll``.

        Returns True only the first time the action is recorded for this
        roll, so callers count each roll once no matter in which order its
        events arrive.
        """
        first_seen = cur_roll[action] == 0
        if first_seen or cur_roll[action] > timestamp:
            cur_roll[action] = timestamp
        return first_seen

    def __call__(self, rec):
        """Add one record's contribution to the accumulated statistics.

        Records with neither a price nor non-muted view time are ignored.
        Otherwise the record marks the suffix as video and adds its
        non-muted view time (a missing value counts as 0). If
        ``rec.ad_tracking_events`` is present it is YSON-decoded and its
        events are grouped into rolls. Each roll contributes at most one
        start and one skip to the counters, plus the ad time computed by
        ``count_tat_increment``. Finally the "leaves" counter is updated.
        """
        if not (rec.price or 0) and not (rec.view_time_non_muted or 0):
            return
        suffices = ["dummy"]
        for suffix in suffices:
            stats = self.by_suffix[suffix]
            stats["is_video"] = True
            # The guard above tolerates a missing value; do the same here
            # instead of failing on ``float += None``.
            stats["view_time_non_muted"] += rec.view_time_non_muted or 0
        if rec.ad_tracking_events is None:
            return
        # Presumably a serialized YSON list of event maps -- confirm.
        events = wrap_yson_obj(yson.loads(rec.ad_tracking_events))
        rolls = {}
        last_adv_event = 0.0
        ts = rec.timestamp
        for event in events or []:
            cur_roll = rolls.setdefault(
                self._roll_id(event),
                {"start": 0.0, "skip": 0.0, "complete": 0.0},
            )
            action = self._action_name(event)
            if action is None:
                continue
            timestamp = event.get("timestamp")
            if action == "start":
                if self._record_earliest(cur_roll, "start", timestamp):
                    for suffix in suffices:
                        self.by_suffix[suffix]["adv_start"] += 1
            elif action == "skip":
                if self._record_earliest(cur_roll, "skip", timestamp):
                    for suffix in suffices:
                        self.by_suffix[suffix]["adv_skip"] += 1
            elif action == "complete":
                self._record_earliest(cur_roll, "complete", timestamp)
            last_adv_event = max(last_adv_event, timestamp)
        total_adv_time = 0
        for cur_roll in rolls.values():
            tat_increment = self.count_tat_increment(cur_roll)
            total_adv_time += tat_increment
            for suffix in suffices:
                self.by_suffix[suffix]["total_adv_time"] += tat_increment
        # Counted as a leave when timestamp + view time + ad time ends less
        # than 60 after the last ad event (presumably: the viewer left soon
        # after an ad). With no ad events last_adv_event stays 0.
        leaves_increment = int(
            ts + rec.view_time + total_adv_time - last_adv_event < 60
        )
        for suffix in suffices:
            self.by_suffix[suffix]["leaves"] += leaves_increment

    def calc_metrics(self):
        """Return the metric rows accumulated so far as a list of dicts.

        Only the ``"dummy"`` suffix is reported, and only if some counted
        record marked it as video. The ad load is
        ``(1 + adv_skip) / adv_start * total_adv_time`` plus 10 per leave
        (both are 0 when no start was seen). The ``*_per_vtnm`` values divide
        by the accumulated non-muted view time and are None when it is 0.
        """
        result = []
        for suffix, stats in self.by_suffix.items():
            if suffix != "dummy" or not stats["is_video"]:
                continue
            if not stats["adv_start"]:
                metric = 0
                metric_wo_leaves = 0
            else:
                metric_wo_leaves = (
                    (1 + stats["adv_skip"])
                    / stats["adv_start"]
                    * stats["total_adv_time"]
                )
                metric = metric_wo_leaves + 10 * stats["leaves"]
            view_time_non_muted = stats["view_time_non_muted"]
            metric_nm = None
            metric_wo_leaves_nm = None
            if view_time_non_muted:
                metric_nm = metric / view_time_non_muted
                metric_wo_leaves_nm = metric_wo_leaves / view_time_non_muted
            result.append(
                {
                    "fielddate": self.key[1],
                    "ref_from": self.key[2],
                    "platform": self.key[3],
                    "channel": self.key[4],
                    "ad_load": metric,
                    "ad_load_per_vtnm": metric_nm,
                    "ad_load_wo_leaves": metric_wo_leaves,
                    "ad_load_wo_leaves_per_vtnm": metric_wo_leaves_nm,
                    "total_adv_time": stats["total_adv_time"],
                }
            )
        return result


def ad_load_counter(key, recs):
    """Reduce all records of one group into ad-load metric rows.

    Lazily (as a generator) feeds every record in ``recs`` to a fresh
    ``AdLoadCounter`` built for ``key`` and then yields each row returned by
    its ``calc_metrics()``.
    """
    counter = AdLoadCounter(key)
    for record in recs:
        counter(record)
    rows = counter.calc_metrics()
    for row in rows:
        yield row
