from functools import partial

import luigi
import yt.wrapper as yt

from crypta.graph.v1.python.utils import utils
from crypta.graph.v1.python.lib.luigi import yt_luigi
from crypta.graph.v1.python.rtcconf import config
from crypta.graph.v1.python.utils import yt_clients


# Cut-off applied to the classifier's ``is_org_score``: classify_org_emails passes it
# to reduce_classify_org_emails, which flags an email as an organization only when
# its score is strictly greater than this value.
# NOTE(review): the origin of the exact value is not documented in this file.
THRESHOLD = 0.919


def email_to_login(rec):
    """Mapper: re-key a classified row by the login stored in its ``login`` column.

    Emits exactly one row whose ``id_value`` is the login and whose
    ``organization`` flag is carried over unchanged; every other column of
    the input row is dropped.
    """
    login, organization = rec["login"], rec["organization"]
    yield dict(id_value=login, organization=organization)


def set_orgs_for_logins(_, recs):
    """Reducer: copy the ``organization`` flag from table 0 onto the other rows of a key.

    Rows with ``@table_index == 0`` only supply the flag and are not emitted.
    Every other row is updated in place with the most recently seen flag
    (0 when no table-0 row came before it), redirected to output table 0 and
    yielded.

    Args:
        _: reduce key (unused).
        recs: iterable of rows sharing that key.
    """
    org_flag = 0
    for row in recs:
        if row["@table_index"] == 0:
            # Flag-carrying row: remember its value, emit nothing.
            org_flag = row["organization"]
            continue
        row["organization"] = org_flag
        row["@table_index"] = 0
        yield row


def login_to_email(rec):
    """Mapper: turn a login dictionary row into an email-keyed row.

    The row's ``<login type>_<fp source>_orig`` column is indexed by the row's
    ``id_value`` to obtain the original login, which is then passed through
    ``utils.login_to_email``. The emitted row has that result in ``id_value``
    and the row's own ``id_value`` in ``login``.
    """
    orig_logins_column = "_".join([config.ID_TYPE_LOGIN, config.ID_SOURCE_TYPE_FP, "orig"])
    originals_by_login = rec[orig_logins_column]
    login = rec["id_value"]
    yield {"id_value": utils.login_to_email(originals_by_login[login]), "login": login}


def reduce_classify_org_emails(_, recs, threshold):
    """Reducer: stamp the non-classifier rows of one key with an ``organization`` flag.

    Rows with ``@table_index == 0`` are classifier rows: they are consumed, not
    emitted. Once any of them has ``is_org_score`` strictly above ``threshold``
    the flag becomes 1 for all rows that follow. Every other row is updated in
    place with the current flag, redirected to output table 0 and yielded.

    Args:
        _: reduce key (unused).
        recs: iterable of rows sharing that key; the flag only reflects
            classifier rows that were iterated before the row being emitted.
        threshold: exclusive lower bound on ``is_org_score`` for a positive verdict.

    Yields:
        The non-classifier rows with ``organization`` set to 0 or 1.
    """
    is_organization = 0
    for rec in recs:
        if rec["@table_index"] == 0:
            prediction = rec["is_org_score"]
            # A null score carries no verdict. Comparing None with a number raises
            # TypeError on Python 3 and would fail the whole job, so treat it the
            # same way as a key that has no classifier row at all.
            if prediction is not None and prediction > threshold:
                is_organization = 1
        else:
            rec["organization"] = is_organization
            rec["@table_index"] = 0
            yield rec


def classify_org_emails(source_folder, classification_result, target_folder, from_login):
    """Write the ``organization`` flag into one ``yuid_with_id_*`` dictionary table.

    When ``from_login`` is false the email dictionary is reduced against
    ``classification_result`` directly. When it is true the login (fp source)
    dictionary is first mapped to emails with ``login_to_email``, classified,
    mapped back with ``email_to_login`` and finally joined onto the login
    dictionary with ``set_orgs_for_logins``.

    Args:
        source_folder: path prefix of the input dictionary; the table name is
            appended to it as-is, without adding a separator.
        classification_result: path of the score table; its ``id`` column is
            renamed to ``id_value`` when it is read.
        target_folder: path prefix for the output table and the temporary tables.
        from_login: selects the login flow instead of the email flow.
    """
    client = yt_clients.get_yt_client()

    if from_login:
        source = config.ID_TYPE_LOGIN + "_" + config.ID_SOURCE_TYPE_FP
    else:
        source = config.ID_TYPE_EMAIL
    dict_table_name = "yuid_with_id_" + source
    input_dict_table = source_folder + dict_table_name
    output_dict_table = target_folder + dict_table_name
    # can't map directly to yuid_with_id_X, it breaks sorting in schema
    tmp_table = output_dict_table + "_tmp"

    login_emails_table = None
    if from_login:
        # Email-keyed view of the login dictionary; built before the transaction is opened.
        login_emails_table = target_folder + "tmp_login_to_email"
        client.run_map(login_to_email, input_dict_table, login_emails_table)
        client.run_sort(login_emails_table, sort_by="id_value")
        rows_to_classify = login_emails_table
    else:
        rows_to_classify = input_dict_table

    with client.Transaction():
        scores_table = "<rename_columns={id=id_value}>" + classification_result
        client.run_reduce(
            partial(reduce_classify_org_emails, threshold=THRESHOLD),
            [scores_table, rows_to_classify],
            tmp_table,
            reduce_by="id_value",
        )
        if from_login:
            # Re-key the verdicts by login, then spread them over the login dictionary rows.
            client.run_map(email_to_login, tmp_table, tmp_table)
            client.run_sort(tmp_table, sort_by="id_value")
            client.run_reduce(
                set_orgs_for_logins,
                [tmp_table, input_dict_table],
                tmp_table,
                reduce_by="id_value",
            )
            client.remove(login_emails_table)
        client.run_sort(tmp_table, output_dict_table, sort_by="id_value")
        client.remove(tmp_table)


class OrgEmailsClassifyTask(yt_luigi.BaseYtTask):
    """Luigi task that adds the ``organization`` column to the email and login (fp) dictionaries.

    It reads and writes the same dictionaries folder, runs after
    ``YuidAllIdBySourceDictsTask`` for the same date, and is considered done
    when both output tables pass the ``YtDateColumnTarget`` check for the
    ``organization`` column and the task date.
    """

    # TODO: make this task transactional so that it doesn't break yuid_with_all dict sorting
    date = luigi.Parameter()
    tags = ["v1"]

    def input_folders(self):
        dicts_folder = config.GRAPH_YT_DICTS_FOLDER
        score_table = yt.ypath_join(config.CRYPTA_IDS_STORAGE, "email", "email_organization_score")
        return {"dict_f": dicts_folder, "classification_result": score_table}

    def output_folders(self):
        # Output goes to the same dictionaries folder the input is read from.
        return {"dict_f": config.GRAPH_YT_DICTS_FOLDER}

    def requires(self):
        # Imported at call time rather than at module import time, as in the original code.
        from crypta.graph.v1.python.matching.yuid_matching import graph_dict

        upstream = graph_dict.YuidAllIdBySourceDictsTask(self.date)
        return [upstream]

    def run(self):
        # Emails are processed first, then logins.
        for from_login in (False, True):
            classify_org_emails(
                self.in_f("dict_f"),
                self.in_f("classification_result"),
                self.out_f("dict_f"),
                from_login=from_login,
            )

    def output(self):
        out_folder = self.out_f("dict_f")
        table_names = (
            "yuid_with_id_" + config.ID_TYPE_EMAIL,
            "yuid_with_id_" + config.ID_TYPE_LOGIN + "_" + config.ID_SOURCE_TYPE_FP,
        )
        return [
            yt_luigi.YtDateColumnTarget(out_folder + name, "organization", self.date)
            for name in table_names
        ]
