# -*- coding: utf-8 -*-

from __future__ import print_function, absolute_import, division

import re
import operator

from nile.api.v1 import Record


class AtomDistributionAlertsCleaner(object):
    """Assemble the stream pipeline that filters "then" vs. "now" count records.

    The class keeps no state of its own: ``clean`` is the public entry point
    and every other method appends exactly one stage to the joined stream.
    """

    def __init__(self):
        super(AtomDistributionAlertsCleaner, self).__init__()

    def clean(self, then, now, config):
        """Join both snapshots by key and run every configured stage over them.

        :param then: stream with ``key``, ``signature``, ``count`` and ``date``
            fields; ``count``/``date`` are renamed to ``count_then``/``date_then``.
        :param now: stream with the same fields; ``count``/``date`` are renamed
            to ``count_now``/``date_now``.
        :param config: mapping whose ``'stages'`` entry is an iterable of stage
            configs, applied in the order given.
        :return: the filtered stream mapped through ``PrettifyMapper``.
        """
        snapshot_then = then.project(
            'key', 'signature', count_then='count', date_then='date'
        )
        snapshot_now = now.project(
            'key', 'signature', count_now='count', date_now='date'
        )
        # Inner join: only keys present in both snapshots are compared.
        stream = snapshot_then.join(snapshot_now, by='key', type='inner')
        for stage in config['stages']:
            stream = self.__clean_internal(stream, stage)
        return self.__prettify(stream)

    def __prettify(self, abreast):
        """Map the stream through ``PrettifyMapper``."""
        prettifier = PrettifyMapper()
        return abreast.map(prettifier)

    def __clean_internal(self, abreast, config):
        """Dispatch one stage config to its handler.

        ``config['type']`` selects the handler and ``config['params']`` is
        passed to it; an unknown type raises ``KeyError``.
        """
        handlers = dict(
            drop=self.__clean_drop,
            shift=self.__clean_shift,
            white_lists=self.__clean_white_lists,
            black_lists=self.__clean_black_list,
            hierarchical=self.__clean_hierarchical,
        )
        handler = handlers[config['type']]
        return handler(abreast, config['params'])

    def __clean_drop(self, abreast, config):
        """Filter the stream with ``DropCheckMapper``."""
        mapper = DropCheckMapper(config)
        return abreast.map(mapper)

    def __clean_shift(self, abreast, config):
        """Filter the stream with ``ShiftCheckMapper``."""
        mapper = ShiftCheckMapper(config)
        return abreast.map(mapper)

    def __clean_white_lists(self, abreast, config):
        """Filter the stream with ``RegExpMapper`` in include mode."""
        mapper = RegExpMapper(config, include=True)
        return abreast.map(mapper)

    def __clean_black_list(self, abreast, config):
        """Filter the stream with ``RegExpMapper`` in exclude mode."""
        mapper = RegExpMapper(config, include=False)
        return abreast.map(mapper)

    def __clean_hierarchical(self, abreast, config):
        """Keep only the records whose key survives the hierarchical reducer.

        The keys of the stream are joined with themselves (``by=()``,
        ``type='full'`` -- presumably yielding every pair of keys; confirm
        against the nile join semantics), each pair is flagged by
        ``HierarchicalCompareMapper``, the flags are reduced per ``key_lhs``
        by ``HierarchicalCompareReducer`` and the surviving keys are
        left-joined back onto the original stream.
        """
        left_keys = abreast.project(key_lhs='key')
        right_keys = abreast.project(key_rhs='key')
        key_pairs = left_keys.join(right_keys, by=(), type='full')
        flagged_pairs = key_pairs.map(HierarchicalCompareMapper(config))
        survivors = flagged_pairs.groupby('key_lhs').reduce(
            HierarchicalCompareReducer(config)
        )
        surviving_keys = survivors.project(key='key_lhs')
        return surviving_keys.join(abreast, type='left', by='key')


class DropCheckMapper(object):
    """Mapper that lets through only the records whose count has dropped.

    A record passes when ``count_then`` is at least ``config['lower_bound']``
    and ``config['factor'] * count_now`` does not exceed ``count_then``.
    """

    def __init__(self, config):
        super(DropCheckMapper, self).__init__()
        self.config = config

    def __call__(self, records):
        """Yield, in input order, every record that satisfies the drop rule."""
        for candidate in records:
            if not self.__predicate(candidate):
                continue
            yield candidate

    def __predicate(self, record):
        """Tell whether ``record`` satisfies the drop rule."""
        before, after = record['count_then'], record['count_now']
        # Too small a baseline: the factor is not even consulted.
        if not (before >= self.config['lower_bound']):
            return False
        return self.config['factor'] * after <= before


class ShiftCheckMapper(object):
    """Mapper that lets through records whose count dropped or rose sharply.

    Drop rule: ``count_then >= drop_lower_bound`` and
    ``drop_factor * count_now <= count_then``.
    Raise rule: ``count_then >= raise_lower_bound`` and
    ``raise_factor * count_then <= count_now``.
    A record passes when either rule holds; all thresholds come from ``config``.
    """

    def __init__(self, config):
        super(ShiftCheckMapper, self).__init__()
        self.config = config

    def __call__(self, records):
        """Yield, in input order, every record that satisfies either rule."""
        for candidate in records:
            if not self.__predicate(candidate):
                continue
            yield candidate

    def __predicate(self, record):
        """Tell whether ``record`` satisfies the drop rule or the raise rule."""
        before, after = record['count_then'], record['count_now']
        settings = self.config
        # The drop rule is tried first; the raise rule is only a fallback.
        if (before >= settings['drop_lower_bound'] and
                settings['drop_factor'] * after <= before):
            return True
        return (
            before >= settings['raise_lower_bound'] and
            settings['raise_factor'] * before <= after
        )


class RegExpMapper(object):
    """Mapper that filters records by regular expressions on their segment values.

    ``config`` maps a dimension name to a regular expression. With
    ``include=True`` a configured dimension must match its expression; with
    ``include=False`` it must not match. Matching uses ``match``, i.e. it is
    anchored at the start of the value only. The value ``'_total_'`` is never
    tested, and dimensions absent from ``config`` are ignored.
    """

    def __init__(self, config, include=True):
        super(RegExpMapper, self).__init__()
        self.config = config
        self.include = include
        predicates = {}
        for dimension, regexp in config.items():
            predicates[dimension] = self.__build_dimension_predicate(
                regexp, self.include
            )
        self.dimensions_predicates = predicates

    def __call__(self, records):
        """Yield, in input order, every record accepted by all predicates."""
        for candidate in records:
            if not self.__predicate(candidate):
                continue
            yield candidate

    def __predicate(self, record):
        """Check each configured dimension of ``record`` against its predicate.

        Dimension names are the first item of every ``record['signature']``
        entry; they are paired positionally with the tab-separated parts of
        ``record['key']`` that follow the first part.
        """
        names = (entry[0] for entry in record['signature'])
        values = record['key'].split('\t')[1:]
        for name, value in zip(names, values):
            check = self.dimensions_predicates.get(name)
            if check is None or value == '_total_':
                continue
            if not check(value):
                return False
        return True

    def __build_dimension_predicate(self, regexp, include):
        """Compile ``regexp`` and return the include or exclude test for it."""
        re_parser = re.compile(regexp)

        if include:
            def include_predicate(text):
                if text is None:
                    return False
                return re_parser.match(text) is not None

            return include_predicate

        def exclude_predicate(text):
            if text is None:
                return False
            return re_parser.match(text) is None

        return exclude_predicate


class HierarchicalCompareMapper(object):
    """Mapper that flags whether the left key of a pair is covered by the right key.

    Every input record carries ``key_lhs`` and ``key_rhs``; the output is the
    same record extended with an integer ``dominated`` field (1 or 0).
    """

    def __init__(self, config):
        super(HierarchicalCompareMapper, self).__init__()
        self.config = config

    def __call__(self, records):
        """Yield each record with the ``dominated`` flag attached.

        Only the tab-separated parts after the first one take part in the
        comparison.
        NOTE(review): the first part of the key (read as the event name by
        ``PrettifyMapper``) is ignored here, so keys of different events can
        dominate one another -- confirm this is intended.
        """
        for record in records:
            segment_lhs = record['key_lhs'].split('\t')[1:]
            segment_rhs = record['key_rhs'].split('\t')[1:]
            yield Record(record, dominated=self.__is_dominated(segment_lhs, segment_rhs))

    def __is_dominated(self, segment_lhs, segment_rhs):
        """Return 1 when ``segment_rhs`` covers ``segment_lhs``, otherwise 0.

        The right segment covers the left one when, at every position, it
        either holds the same value or the wildcard ``'_total_'``. A segment
        therefore always covers itself.

        Segments of different length are not comparable and yield 0. Without
        this guard a shorter right segment raised ``IndexError`` and a longer
        one was silently compared by prefix only.
        """
        if len(segment_lhs) != len(segment_rhs):
            return 0
        for value_lhs, value_rhs in zip(segment_lhs, segment_rhs):
            if value_lhs != value_rhs and value_rhs != '_total_':
                return 0
        return 1


class HierarchicalCompareReducer(object):
    """Reducer that keeps the group keys with at most one domination.

    For every group the ``dominated`` values of its records are added up; the
    group key is emitted as a ``Record`` only when that total does not exceed
    one. Presumably the single allowed domination is the key paired with
    itself -- confirm against ``HierarchicalCompareMapper``.
    """

    def __init__(self, config):
        super(HierarchicalCompareReducer, self).__init__()
        self.config = config

    def __call__(self, groups):
        """Yield ``Record(key)`` for each group whose total is at most one."""
        for group_key, group_records in groups:
            # The whole group is always consumed, even once the total is
            # already too high.
            domination_total = 0
            for row in group_records:
                domination_total += row['dominated']
            if domination_total > 1:
                continue
            yield Record(group_key)


class PrettifyMapper(object):
    """Mapper that reshapes joined records into the final output layout."""

    def __init__(self):
        # Kept in the constructor so it runs once per instance rather than
        # at the start of every __call__ generator.
        super(PrettifyMapper, self).__init__()

    def __call__(self, records):
        """Yield one output ``Record`` per input record.

        Output fields:

        * ``date_then``, ``date_now``, ``count_then``, ``count_now`` -- copied
          from the input record;
        * ``event`` -- the part of ``key`` before the first tab;
        * ``dimensions`` -- names from ``signature`` whose flag is truthy;
        * ``slice`` -- tuple of ``(dimension name, value)`` pairs built by
          zipping all ``signature`` names with the tab-separated remainder of
          ``key`` (``zip`` stops at the shorter of the two).

        A ``key`` without any tab raises ``ValueError`` on unpacking.
        """
        for record in records:
            event, segment = record['key'].split('\t', 1)
            signature = record['signature']
            all_dimensions = [dimension for dimension, flag in signature]
            dimensions = [dimension for dimension, flag in signature if flag]
            yield Record(
                date_then=record['date_then'],
                date_now=record['date_now'],
                count_then=record['count_then'],
                count_now=record['count_now'],
                event=event,
                dimensions=dimensions,
                slice=tuple(zip(all_dimensions, segment.split('\t')))
            )
