from crypta.graph.fuzzy.lib.luiger import BaseTask, DateParameter
from yt.wrapper import with_context as yt_with_context, OperationsTracker, common as yt_common, create_table_switch

import crypta.graph.fuzzy.lib.config as conf
import crypta.graph.fuzzy.lib.tasks.sources.emails as emails
import crypta.graph.fuzzy.lib.tasks.sources.geo as geo
import crypta.graph.fuzzy.lib.tasks.sources.households as households
import crypta.graph.fuzzy.lib.tasks.sources.reqans as reqans
import crypta.graph.fuzzy.lib.tasks.sources.ssid as ssid
import crypta.graph.fuzzy.lib.tasks.sources.visitlog_logins as visitlog_logins
import crypta.graph.fuzzy.lib.tasks.sources.heuristic as heuristic


# Upstream source tasks whose output tables are collected, keyed by source name.
requires_tasks_dict = dict(
    emails=emails.ProcessExtractedLoginsTask,
    geo=geo.ExtractHomeWorkLog,
    households=households.ImportHouseHoldsTask,
    reqans=reqans.ImportReqansLogTask,
    ssid=ssid.ImportSsidMobileMetrikaTask,
    login=visitlog_logins.ExtractIntervalLogins,
    heuristic=heuristic.ExtractExactPairs,
)

# Just the task classes (source names dropped); CollectTask.requires
# instantiates each of them for the requested date.
requires_tasks = requires_tasks_dict.values()


@yt_with_context
class MergeAllPairs(object):
    """YT reducer merging pair records that come from several input tables.

    The reducer is keyed by (``conf.Constants.YUID_LEFT``,
    ``conf.Constants.YUID_RIGHT``). For each key group whose two sides
    differ, it emits:

    * to output table 1, a single row containing only the two key columns;
    * to output table 0, one row per input record. Each row holds the two key
      columns, the record's remaining columns under ``"attributes"``, and the
      label of the input table it came from under ``"type"``.

    Groups where both sides are equal produce no output at all.
    """

    def __init__(self, source_types):
        # source_types[i] is the label emitted as "type" for records read
        # from input table i, so the caller must pass it in input-table order.
        self.source_types = source_types

    def __call__(self, keys, records, context):
        left_column = conf.Constants.YUID_LEFT
        right_column = conf.Constants.YUID_RIGHT
        left, right = keys[left_column], keys[right_column]

        # A yuid paired with itself carries no information; drop the group.
        if left == right:
            return

        # Bare pair goes to the second output table.
        yield create_table_switch(1)
        yield {left_column: left, right_column: right}

        # Every contributing record goes to the first output table.
        yield create_table_switch(0)
        for row in records:
            # context.table_index is read per record because records of one
            # group may come from different input tables.
            origin = self.source_types[context.table_index]
            del row[left_column]
            del row[right_column]
            yield {
                left_column: left,
                right_column: right,
                "attributes": row,
                "type": origin,
            }


class CollectTask(BaseTask):
    """Collect pairs from all source tasks into the combined fuzzy-graph tables.

    Produces two tables, both sorted by (``conf.Constants.YUID_LEFT``,
    ``conf.Constants.YUID_RIGHT``) and stamped with ``@generate_date``:

    * ``destination`` (``conf.Paths.COLLECTED_DATA``) holds one row per source
      record whose two yuids differ. Each row carries the record's other
      columns under ``"attributes"`` and its source label under ``"type"``.
    * ``conf.Paths.COLLECTED_YUID_PAIRS`` holds one row per distinct pair
      whose two yuids differ.
    """

    date = DateParameter()

    def requires(self):
        """Return one instance of every registered source task for ``self.date``."""
        return [task_class(date=self.date) for task_class in requires_tasks]

    @property
    def destination(self):
        """Path of the collected-records output table."""
        return conf.Paths.COLLECTED_DATA

    @property
    def destination_schema(self):
        """Schema used when (re)creating ``destination``."""
        return conf.Paths.COLLECTED_DATA_SCHEMA

    def output(self):
        # NOTE(review): presumably table_is_actual compares the table's
        # @generate_date (written at the end of _run) with this date -- confirm.
        return [
            self.yt.targets.table_is_actual(self.destination, self.date.isoformat()),
            self.yt.targets.table_is_actual(conf.Paths.COLLECTED_YUID_PAIRS, self.date.isoformat()),
        ]

    def _run(self):
        """Rebuild both output tables from the source tasks' tables.

        Steps: recreate both tables, reduce all source tables with
        MergeAllPairs, sort both outputs concurrently, combine chunks of
        ``destination``, and finally stamp both tables with ``@generate_date``.
        """
        generate_date = self.date.isoformat()
        pair_columns = [conf.Constants.YUID_LEFT, conf.Constants.YUID_RIGHT]

        self.yt.create_table_with_schema(
            self.destination, self.destination_schema, strict=True, recreate_if_exists=True
        )
        self.yt.create_table_with_schema(
            conf.Paths.COLLECTED_YUID_PAIRS,
            conf.Paths.COLLECTED_YUID_PAIRS_SCHEMA,
            strict=True,
            recreate_if_exists=True,
        )

        # Only tasks exposing ``source_type`` contribute input tables. The
        # same ``tasks`` list builds both the source-type list and the input
        # table list, so that input table i maps to source_types[i] inside
        # MergeAllPairs.
        tasks = [task for task in self.requires() if hasattr(task, "source_type")]
        self.yt.run_reduce(
            MergeAllPairs([task.source_type for task in tasks]),
            [task.destination for task in tasks],
            # Output index 0 receives typed records; index 1 receives bare pairs.
            [self.destination, conf.Paths.COLLECTED_YUID_PAIRS],
            reduce_by=pair_columns,
            spec={"data_size_per_job": 32 * yt_common.MB},
        )

        # Sort both outputs concurrently. As a context manager, the tracker
        # waits for every operation on normal exit. If an exception escapes
        # (e.g. the second sort fails to start), it aborts the operations that
        # are still running instead of leaving them orphaned on the cluster.
        with OperationsTracker() as sort_operations:
            for table in (self.destination, conf.Paths.COLLECTED_YUID_PAIRS):
                sort_operations.add(self.yt.run_sort(table, sort_by=pair_columns, sync=False))

        # In-place rewrite of destination to combine small chunks.
        # NOTE(review): relies on run_merge's default mode preserving the sort
        # order established above -- confirm.
        self.yt.run_merge(self.destination, self.destination, spec={"combine_chunks": True})

        self.yt.set(self.destination + "/@generate_date", generate_date)
        self.yt.set(conf.Paths.COLLECTED_YUID_PAIRS + "/@generate_date", generate_date)
