import logging
import os
import yt.wrapper

from collections import Counter, defaultdict
from itertools import chain

from modadvert.bigmod.protos.interface.moderation_units_pb2 import EModerationMood
from modadvert.libs.constants import GiB
from modadvert.libs.utils.mappers import FilterMapper
from modadvert.libs.utils.py3 import initialize_logging
from modadvert.libs.utils.ytutils import yt_connect

import sandbox.projects.bigmod.b2b_helpers.common_comparator as common_comparator

# Branches whose automoderator markups are compared against each other.
BRANCHES = ['prestable', 'production']
# Object-type columns; prefixed to every grouping key used for statistics.
TYPE_FIELDS = ['ObjectType', 'SubType']
# Together with TYPE_FIELDS, the key on which per-branch markups are combined.
ID_FIELDS = ['ObjectId', 'VersionId']
# Numerically largest EModerationMood value. Note: max() over EModerationMood.items() would
# compare (name, number) pairs by *name* and yield a tuple that never equals a mood number.
MAX_MOOD = max(EModerationMood.values())


def _unfold_by_rule_mapper(row):
    """Emit one row per automoderator rule that fired in at least one branch.

    The ``{branch}_Markup`` columns are removed from ``row``. Rules are read from
    ``Markup['Bigmod']['AutoModeratorRules']`` (name -> ``AppliedMood``). For every rule name
    seen in any branch the row is emitted with ``rule`` set and, per branch:

    * ``{branch}_mood`` - the rule's applied mood there, or None if it did not fire;
    * ``{branch}_is_unique_mood`` - whether exactly one rule in that branch applied this mood.

    NOTE: the same dict object is mutated and yielded for every rule.
    """
    moods_by_branch = {}
    mood_counts_by_branch = {}
    for branch in BRANCHES:
        markup_field = '{branch}_Markup'.format(branch=branch)
        applied_moods = {}
        for rule_info in row[markup_field]['Bigmod'].get('AutoModeratorRules', []):
            applied_moods[rule_info['Name']] = rule_info['AppliedMood']
        del row[markup_field]
        moods_by_branch[branch] = applied_moods
        mood_counts_by_branch[branch] = Counter(applied_moods.values())

    all_rule_names = set(chain.from_iterable(moods_by_branch[branch] for branch in BRANCHES))
    for rule_name in all_rule_names:
        for branch in BRANCHES:
            mood = moods_by_branch[branch].get(rule_name)
            row['{branch}_mood'.format(branch=branch)] = mood
            # Counter returns 0 for unseen keys, so the None placeholder of an absent rule
            # is normally not reported as a unique mood.
            row['{branch}_is_unique_mood'.format(branch=branch)] = mood_counts_by_branch[branch][mood] == 1
        row['rule'] = rule_name
        yield row


def _delete_markups_mapper(row):
    """Yield ``row`` unchanged except that every ``{branch}_Markup`` column is dropped.

    Raises KeyError if a branch's markup column is missing.
    """
    markup_fields = ['{branch}_Markup'.format(branch=branch) for branch in BRANCHES]
    for markup_field in markup_fields:
        del row[markup_field]
    yield row


def _unfold_by_flag_mapper(row):
    """Emit one row per flag raised in at least one branch.

    A branch's flags are all items of all lists stored as values of
    ``{branch}_Markup['Flags']``. The ``{branch}_Markup`` columns are removed; for every flag
    seen in any branch the row is emitted with ``flag`` set and a ``{branch}_is_set`` boolean
    per branch.

    NOTE: the same dict object is mutated and yielded for every flag.
    """
    flags_by_branch = {}
    for branch in BRANCHES:
        markup_field = '{branch}_Markup'.format(branch=branch)
        # Sets give O(1) membership checks below. Flags must be hashable anyway, since the
        # union over branches has always been built as a set.
        flags_by_branch[branch] = set(chain.from_iterable(row[markup_field]['Flags'].values()))
        del row[markup_field]

    all_flags = set(chain.from_iterable(flags_by_branch[branch] for branch in BRANCHES))
    for flag in all_flags:
        for branch in BRANCHES:
            row['{branch}_is_set'.format(branch=branch)] = flag in flags_by_branch[branch]
        row['flag'] = flag
        yield row


def _combine_markups_reducer(reduce_key, rows):
    """Join the markups that the branches produced for one reduce key into a single row.

    Every input row carries ``branch`` and ``Markup``; ``Markup['ModerationMood']`` is copied
    into the row as ``mood``. If a branch occurs more than once its last row wins, and a
    warning is logged when its mood differs from the previously seen one.

    A row is emitted only when the number of distinct branches seen equals ``len(BRANCHES)``.
    It holds the reduce key, ``{branch}_Markup`` and ``{branch}_mood`` for every branch, and
    ``ClientId`` / ``CampaignId``, which must agree across branches.

    Raises:
        ValueError: if the branches disagree on ``ClientId`` or ``CampaignId``.
    """
    rows_by_branch = {}
    for row in rows:
        row['mood'] = row['Markup']['ModerationMood']
        branch = row['branch']
        previous_row = rows_by_branch.get(branch)
        if previous_row is not None and previous_row['mood'] != row['mood']:
            logging.warning('Different ModerationMoods in "{branch}" for key {reduce_key}'.format(branch=branch, reduce_key=reduce_key))
        rows_by_branch[branch] = row

    # Keys that are missing from some branch cannot be compared, so they produce no output.
    if len(rows_by_branch) != len(BRANCHES):
        return

    result_row = dict(reduce_key)
    for branch in BRANCHES:
        row = rows_by_branch[branch]
        for field in ['Markup', 'mood']:
            result_row['{branch}_{field}'.format(branch=branch, field=field)] = row[field]
        for field in ['ClientId', 'CampaignId']:
            value = row[field]
            # The first branch seeds the value; later branches must match it.
            if value != result_row.get(field, value):
                raise ValueError('Inconsistent values in field {field} for key {reduce_key}: {first_value} and {second_value}'.format(
                    field=field,
                    reduce_key=reduce_key,
                    first_value=value,
                    second_value=result_row[field]
                ))
            result_row[field] = value

    yield result_row

def _max_mood_reducer(reduce_key, rows):
    """Collapse all rows of one reduce key into a single row, choosing a row per branch.

    For every branch in ``BRANCHES`` independently, the row with the highest rank is chosen.
    The rank of a row is the pair:

    * ``EModerationMood`` value of ``{branch}_mood``, where an empty or missing mood counts
      as ``EAM_GOOD``;
    * ``{branch}_is_unique_mood``, where a missing flag counts as 0.

    On equal ranks the earliest row wins. The output holds the reduce key plus every
    ``{branch}_*`` field of the row chosen for that branch.

    Reading stops as soon as every branch holds a row that no later row could outrank, so fat
    reduce keys need not be scanned to the end. The result is the same as for a full scan.
    """
    # Derived from the enum values: max() over EModerationMood.items() would compare
    # (name, number) pairs by name and never equal a mood number, disabling the early stop.
    max_mood = max(EModerationMood.values())
    mood_fields = {branch: '{branch}_mood'.format(branch=branch) for branch in BRANCHES}
    unique_fields = {branch: '{branch}_is_unique_mood'.format(branch=branch) for branch in BRANCHES}

    def rank(row, branch):
        return (
            EModerationMood.Value(row[mood_fields[branch]] or 'EAM_GOOD'),
            row.get(unique_fields[branch], 0),
        )

    def cannot_be_outranked(row, branch):
        mood_value, is_unique = rank(row, branch)
        if mood_value != max_mood:
            return False
        # With the uniqueness flag present the top rank is (max_mood, True); without it every
        # row ranks 0 on that component, so reaching max_mood is already final.
        # NOTE(review): assumes rows of one reduce key either all carry the flag or none do,
        # which holds for the mappers in this file (only _unfold_by_rule_mapper sets it).
        return bool(is_unique) or unique_fields[branch] not in row

    first_row = next(rows)
    max_rows = {branch: first_row for branch in BRANCHES}

    for row in rows:
        for branch in BRANCHES:
            # Strict '>' keeps the earlier row on ties, exactly like max([current, row], key=...).
            if rank(row, branch) > rank(max_rows[branch], branch):
                max_rows[branch] = row

        # Trying to reduce the amount of consumed memory/time for fat reduce keys.
        if all(cannot_be_outranked(max_rows[branch], branch) for branch in BRANCHES):
            break

    result_row = dict(reduce_key)
    for branch in BRANCHES:
        prefix = '{branch}_'.format(branch=branch)
        for field, value in max_rows[branch].items():
            if field.startswith(prefix):
                result_row[field] = value

    yield result_row


class AutomoderatorComparison:
    """B2B comparison of automoderator markups of ``BRANCHES`` stored on a YT cluster.

    Connects to YT on construction. It also sets the client's default mapper/reducer
    ``memory_limit`` and ``max_failed_job_count`` for the operations it launches.
    """

    def __init__(self, yt_cluster, yt_token, yt_src_dir, yt_memory_limit=6*GiB, yt_max_failed_job_count=5, logging_level='INFO'):
        """Store the settings, configure logging and connect to YT.

        Args:
            yt_cluster: YT proxy passed to ``yt_connect``.
            yt_token: token passed to ``yt_connect``.
            yt_src_dir: YT directory handed to ``B2BComparison``; the diff table is written
                to ``<yt_src_dir>/comparison/diff_markups``.
            yt_memory_limit: default ``memory_limit`` for mapper and reducer jobs.
            yt_max_failed_job_count: default ``max_failed_job_count`` for operations.
            logging_level: level passed to ``initialize_logging``.
        """
        self.yt_cluster = yt_cluster
        self.yt_token = yt_token
        self.yt_src_dir = yt_src_dir
        self.yt_memory_limit = yt_memory_limit
        self.yt_max_failed_job_count = yt_max_failed_job_count
        self.logging_level = logging_level

        initialize_logging(logging_level=logging_level)
        self.yt_client = yt_connect(proxy=yt_cluster, token=yt_token)

        # Default spec values for operations launched through this client.
        spec_defaults = {
            'mapper': {'memory_limit': yt_memory_limit},
            'reducer': {'memory_limit': yt_memory_limit},
            'max_failed_job_count': yt_max_failed_job_count,
        }
        self.yt_client.config['spec_defaults'].update(spec_defaults)

    def _select_diff_markups(self, src_table, dst_dir):
        """Copy the rows of ``src_table`` whose ``{branch}_mood`` values differ into ``<dst_dir>/diff_markups``.

        Waits for the map operation to finish before returning.

        Returns:
            Path of the destination table.
        """
        diff_table = os.path.join(dst_dir, 'diff_markups')
        logging.info('Selecting differing markups from {src_table} to {diff_table}'.format(src_table=src_table, diff_table=diff_table))

        def moods_differ(row):
            return len({row['{branch}_mood'.format(branch=branch)] for branch in BRANCHES}) > 1

        with yt.wrapper.OperationsTracker(print_progress=True) as yt_tracker:
            operation = self.yt_client.run_map(FilterMapper(moods_differ), src_table, diff_table, sync=False)
            yt_tracker.add(operation)
        return diff_table

    def run(self):
        """Run the B2B comparison, then extract the rows whose branch moods differ.

        The comparison is keyed by TYPE_FIELDS + ID_FIELDS and combines branches with
        ``_combine_markups_reducer``. Statistics are counted for:

        * flags;
        * whole versions (``global``) and individual automoderator rules (``rule``), each
          per version, per campaign and per client.

        The per-campaign and per-client counts keep the highest-mood row via ``_max_mood_reducer``.
        """
        mood_fields = ['{branch}_mood'.format(branch=branch) for branch in BRANCHES]
        is_set_fields = ['{branch}_is_set'.format(branch=branch) for branch in BRANCHES]

        statistics_counters = [
            common_comparator.StatisticsCounter(
                name='flags',
                unfold_statistics_mapper=_unfold_by_flag_mapper,
                unfold_statistics_reducer=None,
                unfold_reduce_fields=None,
                count_reduce_fields=TYPE_FIELDS + ['flag'] + is_set_fields
            ),
        ]
        for scope, mapper, scope_fields in [
            ('global', _delete_markups_mapper, []),
            ('rule', _unfold_by_rule_mapper, ['rule']),
        ]:
            statistics_counters.append(common_comparator.StatisticsCounter(
                name=scope + '/versions',
                unfold_statistics_mapper=mapper,
                unfold_statistics_reducer=None,
                unfold_reduce_fields=None,
                count_reduce_fields=TYPE_FIELDS + scope_fields + mood_fields
            ))
            for level, id_field in [('campaign', 'CampaignId'), ('client', 'ClientId')]:
                statistics_counters.append(common_comparator.StatisticsCounter(
                    name=scope + '/' + level,
                    unfold_statistics_mapper=mapper,
                    unfold_statistics_reducer=_max_mood_reducer,
                    unfold_reduce_fields=TYPE_FIELDS + scope_fields + [id_field],
                    count_reduce_fields=TYPE_FIELDS + scope_fields + mood_fields
                ))

        comparison = common_comparator.B2BComparison(BRANCHES, self.yt_src_dir, self.yt_client, TYPE_FIELDS + ID_FIELDS, _combine_markups_reducer, statistics_counters)
        comparison.run()

        comparison_dir = os.path.join(self.yt_src_dir, 'comparison')
        self._select_diff_markups(os.path.join(comparison_dir, 'statistics'), comparison_dir)
