import luigi

from crypta.graph.v1.python.lib.luigi import yt_luigi
from crypta.graph.v1.python.rtcconf import config
from crypta.graph.v1.python.utils import mr_utils as mr
from crypta.graph.v1.python.utils import utils


def map_pairs(rec):
    """Mapper: unpack the two yuids packed into ``rec["key"]``.

    ``rec["key"]`` must consist of exactly two parts joined by a single
    underscore; any other shape makes the tuple unpacking raise ``ValueError``.

    Yields a single row carrying the original ``id_value`` together with the
    two parts as ``yuid1`` and ``yuid2`` (both left as strings).
    """
    left_yuid, right_yuid = rec["key"].split("_")
    out_row = {"id_value": rec["id_value"]}
    out_row["yuid1"] = left_yuid
    out_row["yuid2"] = right_yuid
    yield out_row


def reduce_short_sessions_pairs(login_key, recs):
    """Reducer: tag the pair records of one key group by short-session membership.

    ``recs`` is handed to ``mr.split_left_right`` with ``oom_limit=1000``; the
    first returned collection is treated as short-session records and the
    second as pair records (presumably split by input table - confirm against
    ``mr_utils``).

    Yields:
        * If the split raises ``mr.OomLimitException``: one row with
          ``recs_count`` taken from the exception, ``@table_index`` 2 and the
          fields of ``login_key``; nothing else is emitted for the group.
        * Otherwise, when at least one short-session record exists: every pair
          record, extended with ``short_session_yuids`` (all short-session
          yuids of the group) and ``@table_index`` 0 when ``yuid1`` or
          ``yuid2`` is among them, 1 when neither is.

    Groups without short-session records produce no output.
    """
    try:
        session_rows, pair_rows = mr.split_left_right(recs, oom_limit=1000)
    except mr.OomLimitException as oom:
        overflow_row = {"recs_count": oom.recs_count, "@table_index": 2}
        overflow_row.update(login_key)
        yield overflow_row
        return

    if not session_rows:
        return

    session_yuids = {row["yuid"] for row in session_rows}
    for pair_row in pair_rows:
        matched = pair_row["yuid1"] in session_yuids or pair_row["yuid2"] in session_yuids
        # A fresh list per record so emitted rows never share a mutable value.
        pair_row["short_session_yuids"] = list(session_yuids)
        pair_row["@table_index"] = 0 if matched else 1
        yield pair_row


class LoginShortSessionPairsStat(yt_luigi.BaseYtTask):
    """Task that classifies login yuid pairs by passport short-session yuids.

    For the run ``date`` it reads the login pairs table and the per-day
    ``passport/short_session_raw`` tables for ``config.STORE_DAYS`` days, then
    reduces them together by ``id_value`` with ``reduce_short_sessions_pairs``.
    Results are written under the dated ``passport/`` output folder.
    """

    date = luigi.Parameter()
    tags = ["v1"]

    def input_folders(self):
        """Return the input folders: the output root and the dated pairs folder."""
        root = config.YT_OUTPUT_FOLDER
        return {"graph": root, "pairs": root + self.date + "/pairs/"}

    def output_folders(self):
        """Return the single output folder: the dated passport folder."""
        return {"passport": config.YT_OUTPUT_FOLDER + self.date + "/passport/"}

    def requires(self):
        """Depend on the graph pairs task plus one passport log import per stored day."""
        from crypta.graph.v1.python.matching.pairs import graph_pairs
        from crypta.graph.v1.python.data_imports.import_logs.graph_passport import ImportPassportLogDayTask

        passport_imports = []
        for day in utils.get_dates_before(self.date, int(config.STORE_DAYS)):
            passport_imports.append(ImportPassportLogDayTask(date=day, run_date=self.date))

        pairs_task = graph_pairs.GraphPairs(self.date)
        return [pairs_task] + passport_imports

    def before_run(self):
        """Create the passport output folder."""
        mr.mkdir(self.out_f("passport"))

    def run(self):
        """Build the intermediate tables, run the reduce and drop the temporary pairs table."""
        passport_dir = self.out_f("passport")
        login_pairs_table = passport_dir + "login_pairs"
        short_session_table = passport_dir + "short_session_month"

        # Inputs: login pairs for the run date and the daily short-session tables.
        pairs_source = self.in_f("pairs") + "yuid_pairs_" + config.ID_TYPE_LOGIN
        daily_short_sessions = mr.get_date_tables(
            self.in_f("graph"), "passport/short_session_raw", int(config.STORE_DAYS), before_date=self.date
        )

        # Both preparation steps are started with sync=False and awaited together.
        map_op = self.yt.run_map(map_pairs, pairs_source, login_pairs_table, sync=False)
        distinct_op = mr.distinct_by(
            ["yuid", "puid", "id_value"],
            daily_short_sessions,
            short_session_table,
            sync=False,
        )
        utils.wait_all([map_op, distinct_op])

        mr.sort_all([login_pairs_table, short_session_table], sort_by="id_value")

        # Short-session table goes first, pairs second. The reducer emits
        # @table_index 0, 1 and 2, presumably matching the order of the outputs.
        reduce_outputs = [
            passport_dir + "pairs_short_session",
            passport_dir + "pairs_short_session_not_matched",
            passport_dir + "oom",
        ]
        self.yt.run_reduce(
            reduce_short_sessions_pairs,
            [short_session_table, login_pairs_table],
            reduce_outputs,
            reduce_by="id_value",
        )

        # The split pairs table is only an intermediate result.
        mr.drop(login_pairs_table)

    def output(self):
        """Return targets for the matched and not-matched pair tables."""
        table_names = ("pairs_short_session", "pairs_short_session_not_matched")
        return [yt_luigi.YtTarget(self.out_f("passport") + name) for name in table_names]
