from collections import defaultdict

import collections

from crypta.graph.v1.python.utils import mr_utils


class ColumnAggregator(object):
    """Base interface for accumulators that are fed one record at a time.

    Concrete aggregators in this module read a record through ``rec.get``
    and expose their running result via ``get_aggregated_value``. Both
    methods here are placeholders that subclasses are expected to override.
    """

    def accept_rec(self, rec):
        """Consume a single record.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def get_aggregated_value(self):
        """Return whatever has been accumulated from the accepted records.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()


# Resolved once at import time so the check works on both Python 2 and 3:
# ``basestring`` exists only on Python 2, and the ``collections.Sequence``
# alias was deprecated in 3.3 and removed in Python 3.10.
try:
    _STRING_TYPES = basestring  # noqa: F821  (Python 2 only)
except NameError:
    _STRING_TYPES = (str, bytes)

try:
    from collections.abc import Sequence as _Sequence
except ImportError:  # Python 2
    _Sequence = collections.Sequence


def is_collection(obj):
    """Return True if ``obj`` is a non-string sequence (e.g. a list or tuple).

    Text and byte strings are sequences too, but they are deliberately
    treated as scalar values here. Non-sequence containers (sets, dicts)
    and non-containers return False.
    """
    if isinstance(obj, _STRING_TYPES):
        return False
    return isinstance(obj, _Sequence)


class SetUnionAggregator(ColumnAggregator):
    def __init__(self, column, oom_limit=None):
        self.column = column
        self.oom_limit = oom_limit
        self.set_union = set()

    def accept_rec(self, rec):
        rec_value = rec.get(self.column)

        if rec_value:
            if self.oom_limit and len(self.set_union) > self.oom_limit:
                raise mr_utils.OomLimitException(self.oom_limit)

            if is_collection(rec_value):
                self.set_union.update(rec_value)
            else:
                self.set_union.add(rec_value)

    def get_aggregated_value(self):
        return list(self.set_union)


class NewKeyValueMappingAggregator(ColumnAggregator):
    def __init__(self, key_column, value_column, oom_limit=None):
        self.key_column = key_column
        self.value_column = value_column
        self.oom_limit = oom_limit
        self.kv_mapping = dict()

    def accept_rec(self, rec):
        key = rec.get(self.key_column)
        value = rec.get(self.value_column)

        if key and value:
            if self.oom_limit and len(self.kv_mapping) > self.oom_limit:
                raise mr_utils.OomLimitException(self.oom_limit)

            self.kv_mapping[key] = value

    def get_aggregated_value(self):
        return self.kv_mapping


class MergeDictAggregator(ColumnAggregator):
    def __init__(self, dict_column, oom_limit=None):
        self.dict_column = dict_column
        self.oom_limit = oom_limit
        self.merged_dict = dict()

    def accept_rec(self, rec):
        dict_to_add = rec.get(self.dict_column)
        if dict_to_add:
            if self.oom_limit and len(self.merged_dict) > self.oom_limit:
                raise mr_utils.OomLimitException(self.oom_limit)

            self.merged_dict.update(dict_to_add)

    def get_aggregated_value(self):
        return self.merged_dict


class FirstRecAggregator(ColumnAggregator):
    def __init__(self):
        self.first_rec = None

    def accept_rec(self, rec):
        if not self.first_rec:
            self.first_rec = rec

    def get_aggregated_value(self):
        return self.first_rec


class FirstValueAggregator(ColumnAggregator):
    def __init__(self, column):
        self.column = column
        self.first_value = None

    def accept_rec(self, rec):
        rec_value = rec.get(self.column)
        if rec_value and not self.first_value:
            self.first_value = rec[self.column]

    def get_aggregated_value(self):
        return self.first_value


class LastValueAggregator(ColumnAggregator):
    def __init__(self, column):
        self.column = column
        self.last_value = None

    def accept_rec(self, rec):
        rec_value = rec.get(self.column)
        if rec_value:
            self.last_value = rec[self.column]

    def get_aggregated_value(self):
        return self.last_value


class MergeHitsDictsAggregator(ColumnAggregator):
    def __init__(self, column):
        self.column = column
        self.hits_dict = defaultdict(int)

    def accept_rec(self, rec):
        rec_value = rec.get(self.column)
        if rec_value:
            # TODO: counter?
            if isinstance(rec_value, dict):
                for k, hits in rec_value.iteritems():
                    self.hits_dict[str(k)] += hits  # for some stupid reason yson keys should be strings
            else:
                self.hits_dict[str(rec_value)] += 1  # for some stupid reason yson keys should be strings

    def get_aggregated_value(self):
        return self.hits_dict
