from crypta.graph.fuzzy.lib.luiger import BaseTask, DateParameter
from yt.wrapper import with_context as yt_with_context
from crypta.graph.fuzzy.lib.common import cached_property
import crypta.graph.fuzzy.lib.config as conf
from crypta.graph.fuzzy.lib.tasks.classifier.counters import ComputeCounters
import math
from crypta.graph.fuzzy.lib.tasks.classifier.collect import requires_tasks_dict
from collections import defaultdict, namedtuple
import logging

logger = logging.getLogger(__name__)


@yt_with_context
class FilterYuidPairsReducer(object):
    """Reducer that scores one yuid pair from its evidence records.

    The incoming records are grouped by their ``"type"`` field.  Every group
    contributes its largest ``coefficient`` plus a tenth of the group size to
    the summary score; the group whose type equals ``socdem_source_type`` is
    additionally adjusted by gender/age agreement.  The pair is emitted only
    when the summary score reaches ``threshold``.
    """

    socdem_source_type = "ENRICHED_DATA"

    def __init__(self, threshold):
        """
        :param threshold: minimal summary score required to emit the pair.
        """
        self.threshold = threshold

    @staticmethod
    def _top_age_segment(segments):
        """Return the key of ``segments`` with the largest value.

        Returns ``None`` when ``segments`` is missing or empty.  On ties the
        first maximal entry in iteration order wins.
        """
        if not segments:
            return None
        # Rank by the segment weight (the dict value), not by the segment name.
        return max(segments.items(), key=lambda item: float(item[1]))[0]

    def _adjust_for_socdem(self, coefficient, entry):
        """Adjust a socdem group's coefficient using gender and age agreement.

        :param coefficient: largest coefficient found in the socdem group.
        :param entry: first record (attributes dict) of the socdem group.
        :returns: the adjusted coefficient, or ``None`` when both sides carry
            the same gender value and the group must be skipped entirely.
        """
        coefficient -= 1
        # The defaults (2.0 vs -1.0) make the difference large when a side has
        # no gender data, so missing data lands in the penalty branch below.
        gender_difference = abs(
            float(entry.get("socdem_left_gender", {"f": 2.0})["f"])
            - float(entry.get("socdem_right_gender", {"f": -1.0})["f"])
        )
        if gender_difference < 0.000001:  # same socdem
            return None
        # Clamp so that the bonus 0.01 / difference never exceeds 0.1.
        gender_difference = max(0.1, gender_difference)
        if gender_difference < 0.5:
            coefficient += 0.01 / gender_difference
        else:
            coefficient -= gender_difference * 2

        # Age data may be absent just like gender data; then no bonus is given.
        top_left_age = self._top_age_segment(entry.get("socdem_left_age_segments"))
        top_right_age = self._top_age_segment(entry.get("socdem_right_age_segments"))
        if top_left_age is not None and top_left_age == top_right_age:
            coefficient += 0.25
        return coefficient

    def __call__(self, keys, records, context):
        """Reduce all records of one yuid pair into at most one scored row.

        :param keys: reduce key; must contain ``conf.Constants.YUID_LEFT`` and
            ``conf.Constants.YUID_RIGHT``.
        :param records: iterable of rows with a ``"type"`` field; the payload
            is taken from ``"attributes"`` when present, otherwise the row
            itself is used.
        :param context: job context passed in because of ``yt_with_context``;
            not used here.
        :returns: generator yielding one ``{left, right, "score"}`` row when
            the summary score is at least ``self.threshold``, nothing otherwise.
        """
        entries_by_source = defaultdict(list)
        for record in records:
            entries_by_source[record["type"]].append(record.get("attributes", record))

        summary_coefficient = 0.0
        for source_type, entries in entries_by_source.items():
            # A missing coefficient counts as 1.0; the group maximum never
            # drops below 0.0.
            coefficient = max(
                [0.0] + [float(entry.get("coefficient", 1.0)) for entry in entries]
            )
            if source_type == self.socdem_source_type:
                coefficient = self._adjust_for_socdem(coefficient, entries[0])
                if coefficient is None:
                    continue
            summary_coefficient += coefficient + (len(entries) / 10.0)

        if summary_coefficient >= self.threshold:
            yield {
                conf.Constants.YUID_LEFT: keys[conf.Constants.YUID_LEFT],
                conf.Constants.YUID_RIGHT: keys[conf.Constants.YUID_RIGHT],
                "score": float(summary_coefficient),
            }


class ProbEstimator(object):
    """Keeps source destinations and the heuristic path; offers count helpers.

    NOTE(review): the helpers read ``self.yt``, which is not assigned anywhere
    in this class -- presumably it is expected to be attached from outside;
    confirm before using.
    """

    def __init__(self, reqdest, heuristic_path):
        """Store both arguments unchanged in private attributes."""
        self.__reqdest = reqdest
        self.__heuristic_path = heuristic_path

    def __get_row_count(self, path):
        """Return the ``@row_count`` attribute of the node at ``path``."""
        attribute_path = path + "/@row_count"
        return self.yt.get(attribute_path)

    def __get_pairs_count(self, path):
        """Return ``n * (n - 1)`` where ``n`` is the row count at ``path``."""
        rows = self.__get_row_count(path)
        return rows * (rows - 1)


# Names of all entries in requires_tasks_dict; each one becomes a field of
# KeyType, in the same order.
sources = list(requires_tasks_dict.keys())
KeyType = namedtuple("KeyType", sources)


def make_key(row):
    """Build a ``KeyType`` from ``row``.

    Every field listed in ``sources`` is copied from ``row``; fields that the
    row does not contain are set to ``None``.
    """
    values = dict.fromkeys(sources)
    for name in sources:
        if name in row:
            values[name] = row[name]
    return KeyType(**values)


class BayesClassifyTask(BaseTask):
    """Score rows by the log-ratio of conditional to global key frequency.

    A row is turned into a key with ``make_key``.  Two tables of
    log-frequencies are built from the counters task (global and
    conditional); the score of a row is ``conditional - global`` for its key,
    or ``0.0`` when the key is missing from either table.
    """

    date = DateParameter()
    # Counter rows whose "cnt" does not exceed this value are left out of the
    # log-frequency tables.
    MIN_SAMPLE = 1000

    @cached_property
    def __counters_task(self):
        """Counters task for ``self.date``."""
        return ComputeCounters(date=self.date)

    def requires(self):
        """Yield the only upstream dependency: the counters task."""
        yield self.__counters_task

    @property
    def destination(self):
        """Path of the result table."""
        return conf.Paths.FUZZY_RESULT

    @property
    def destination_schema(self):
        """Schema used when the result table is created."""
        return conf.Paths.FUZZY_RESULT_SCHEMA

    def output(self):
        """Yield the target built from the result table and the ISO date."""
        iso_date = self.date.isoformat()
        yield self.yt.targets.table_is_actual(self.destination, iso_date)

    @cached_property
    def __reqdest(self):
        """Map every non-"heuristic" source name to its task's destination."""
        destinations = {}
        for name, task_class in requires_tasks_dict.iteritems():
            if name == "heuristic":
                continue
            destinations[name] = task_class(date=self.date).destination
        return destinations

    @cached_property
    def __heuristic_path(self):
        """Destination of the "heuristic" task for ``self.date``."""
        return requires_tasks_dict["heuristic"](date=self.date).destination

    def __get_row_count(self, path):
        """Return the ``@row_count`` attribute of the node at ``path``."""
        attribute_path = path + "/@row_count"
        return self.yt.get(attribute_path)

    def __get_pairs_count(self, path):
        """Return ``n * (n - 1)`` where ``n`` is the row count at ``path``."""
        rows = self.__get_row_count(path)
        return rows * (rows - 1)

    def __log_frequencies(self, counters_path, total):
        """Read a counters table and return ``{key: log(cnt) - log(total)}``.

        Only rows with ``cnt`` strictly greater than ``MIN_SAMPLE`` are kept.
        """
        table = dict()
        for row in self.yt.read_table(counters_path):
            count = row["cnt"]
            if count > self.MIN_SAMPLE:
                table[make_key(row)] = math.log(count) - math.log(total)
        return table

    @cached_property
    def __global_pcts(self):
        """Log-frequencies of keys relative to the pair count of YUID_WITH_ALL."""
        pairs_total = self.__get_pairs_count(conf.Paths.YUID_WITH_ALL)
        return self.__log_frequencies(self.__counters_task.global_counters, pairs_total)

    @cached_property
    def __conditional_pcts(self):
        """Log-frequencies of keys relative to the heuristic table row count."""
        heuristic_total = self.__get_row_count(self.__heuristic_path)
        return self.__log_frequencies(self.__counters_task.conditional_counters, heuristic_total)

    @staticmethod
    def compute_score(row, conditional_pcts, global_pcts):
        """Return the conditional minus global log-frequency of the row's key.

        Returns ``0.0`` when the key is absent from either table.
        """
        key = make_key(row)
        if key not in conditional_pcts or key not in global_pcts:
            return 0.0
        return conditional_pcts[key] - global_pcts[key]

    def ___pair_mapper(self):
        """Build a mapper that emits ``{left, right, score}`` for each row.

        Everything the mapper needs is bound to local names, so the closure
        does not reference ``self``.
        """
        left_field = conf.Constants.YUID_LEFT
        right_field = conf.Constants.YUID_RIGHT
        conditional = self.__conditional_pcts
        overall = self.__global_pcts
        score_of = self.compute_score

        def mapper(row):
            yield {
                left_field: row[left_field],
                right_field: row[right_field],
                "score": score_of(row, conditional, overall),
            }

        return mapper

    def __condition_mapper(self):
        """Build a mapper that copies each row and adds its ``"score"``.

        As in the pair mapper, the closure captures only local names.
        """
        conditional = self.__conditional_pcts
        overall = self.__global_pcts
        score_of = self.compute_score

        def mapper(row):
            scored = dict(row)
            scored["score"] = score_of(row, conditional, overall)
            yield scored

        return mapper

    def _run(self):
        """Recreate the result table and fill it with scored pairs.

        Steps, in order: drop and recreate the destination table, log both
        frequency tables, rewrite DEBUG_SCORES in place with scores added,
        map FUZZY_CANDIDATES into the destination, merge its chunks and
        stamp it with the ``@generate_date`` attribute.
        """
        result_table = self.destination
        if self.yt.exists(result_table):
            self.yt.remove(result_table)
        self.yt.create_table_with_schema(
            result_table, self.destination_schema, strict=True, recreate_if_exists=True
        )
        logger.info("Table created")

        logger.info(self.__conditional_pcts)
        logger.info(self.__global_pcts)

        logger.info("Calculate condition scores")
        debug_table = conf.Paths.DEBUG_SCORES
        self.yt.run_map(self.__condition_mapper(), debug_table, debug_table)

        logger.info("Run user scoring")
        self.yt.run_map(self.___pair_mapper(), conf.Paths.FUZZY_CANDIDATES, result_table)
        self.yt.run_merge(result_table, result_table, spec={"combine_chunks": True})
        self.yt.set(result_table + "/@generate_date", self.date.isoformat())


# Alias: ClassifyTask is the same class object as BayesClassifyTask.
ClassifyTask = BayesClassifyTask
