# coding=utf-8
import itertools
import logging
import time
from collections import defaultdict
from functools import partial

import luigi
import networkx as nx
import yt.wrapper as yt

import rtcconf.config
from lib.contrib import nx_community
from lib.luigi import yt_luigi
from matching.human_matching import crypta_id_generator
from matching.human_matching import graph_vertices_base
from matching.human_matching.stats import graph_quality_metrics
from matching.pairs import graph_pairs
from rtcconf import config
from utils import mr_utils as mr
from utils import utils

# Number of bytes in one gibibyte (2 ** 30).
GB = 1024 * 1024 * 1024


def edges_to_vertices_format(components):
    """
    Convert components described by edges into per-component vertex records.

    Every edge contributes one vertex record per endpoint. Within a component
    records are de-duplicated by (vertex id, pair_source), so the same vertex
    appears once per pair source it was seen with; a later edge with the same
    (vertex id, pair_source) overwrites the earlier record.

    :param components: iterable of objects exposing ``edges`` and ``component_id``;
                       every edge exposes ``orig_rec``, ``id1`` and ``id2``
    :return: dict component_id -> list of vertex dicts
             (components without edges are absent from the result)
    """

    def short_to_long_id_type(short_id_type):
        # 'y' marks a yuid, any other short type is treated as a device id
        return 'yuid' if short_id_type == 'y' else 'deviceid'

    def id_values(rec):
        # a fresh dict per call, so the two vertices of an edge never share state
        if 'id_value' in rec:
            return {rec['id_type']: [rec['id_value']]}
        return dict()

    def make_vertex(rec, vertex_id, short_id_type, prefix):
        # prefix is 'id1' or 'id2' and selects the per-endpoint columns of the edge rec
        return {'key': vertex_id,
                'id_type': short_to_long_id_type(short_id_type),
                'ua_profile': rec.get(prefix + '_ua'),
                'browser': rec.get(prefix + '_browser'),
                'sex': rec.get(prefix + '_sex'),
                'region': rec.get(prefix + '_region'),
                'id_values': id_values(rec)}

    vertices_by_components = defaultdict(dict)
    for component in components:
        component_vertices = None
        for edge in component.edges:
            edge_rec = edge.orig_rec
            edge_source = edge_rec['pair_source']
            id1_type, id2_type = edge_rec['pair_type'].split('_')

            vertex1 = make_vertex(edge_rec, edge.id1, id1_type, 'id1')
            vertex2 = make_vertex(edge_rec, edge.id2, id2_type, 'id2')

            if component_vertices is None:
                component_vertices = vertices_by_components[component.component_id]
            component_vertices[(edge.id1, edge_source)] = vertex1
            component_vertices[(edge.id2, edge_source)] = vertex2

    # dict of vertices -> list; list() keeps the result a real list on both
    # Python 2 and Python 3 (where values() is only a view)
    return {c_id: list(vertices.values())
            for c_id, vertices in vertices_by_components.items()}


def get_sex_prob(sex_value):
    """
    Extract the male probability from a 'male,female' comma-separated string.

    :param sex_value: string with exactly two comma-separated numbers, or a falsy value
    :return: the first number as float; None when the value is empty
             or does not consist of exactly two parts
    :raises ValueError: when the first part is not a valid float
    """
    if not sex_value:
        return None

    parts = sex_value.split(',')
    if len(parts) != 2:
        return None

    return float(parts[0])


def weight_single_source_d_y_edge_rec(edge_rec):
    """
    Weight of a single-source device-yuid (d_y) edge record.

    :param edge_rec: edge record with a 'source_type' field
    :return: base weight configured for a perfect devid pair type, 0.5 otherwise
    """
    perfect_types = config.DEVID_PAIR_TYPES_PERFECT_DICT
    source_type = edge_rec['source_type']

    if source_type not in perfect_types:
        # fuzzy and no match
        return 0.5

    return perfect_types[source_type].base_weight


def weight_single_source_y_y_edge_rec(edge_rec, clustering_config):
    """
    Weight of a single-source yuid-yuid (y_y) edge record.

    The base weight comes from the id type configuration (optionally overridden
    per source type), from the record's 'id_value' for fuzzy pairs, or defaults
    to 0.5 for unknown id types. It is then reduced by optional penalties
    enabled in clustering_config:

    * use_region: divided by 4 when both regions are known and differ;
    * use_sex:    multiplied by sex similarity when the similarity is below 0.7;
    * use_values: multiplied by the share of dates of this edge's id value
                  among the dates of all id values of both yuids.

    :param edge_rec: edge record (dict)
    :param clustering_config: object with use_region / use_sex / use_values flags
    :return: edge weight
    """
    id_type = edge_rec['id_type']
    source_type = edge_rec.get('source_type')

    if id_type in config.YUID_PAIR_TYPES_DICT:
        edge_config = config.YUID_PAIR_TYPES_DICT[id_type]

        if source_type and source_type in edge_config.custom_weights:
            edge_weight = edge_config.custom_weights[source_type]
        else:
            edge_weight = edge_config.base_weight

    elif id_type == rtcconf.config.PAIR_SRC_FUZZY:
        # for fuzzy pairs the id value itself is used as the weight
        edge_weight = float(edge_rec['id_value'])
    else:
        edge_weight = 0.5  # unknown source

    yuid1_region = edge_rec.get('id1_region')
    yuid2_region = edge_rec.get('id2_region')
    if clustering_config.use_region and \
            yuid1_region and yuid2_region and yuid1_region != yuid2_region:
        edge_weight /= 4.0

    yuid1_sex = get_sex_prob(edge_rec.get('id1_sex'))
    yuid2_sex = get_sex_prob(edge_rec.get('id2_sex'))
    # compare with None explicitly: a probability of 0.0 is a valid (and the
    # most decisive) value and must not be treated as "unknown"
    if clustering_config.use_sex and yuid1_sex is not None and yuid2_sex is not None:
        # if eq, should be 1, if max diff 0<->1, should be 0
        sex_similarity = 1 - abs(yuid1_sex - yuid2_sex)
        if sex_similarity < 0.7:
            edge_weight *= sex_similarity

    yuid1_id_dates = edge_rec.get('id1_dates')
    yuid2_id_dates = edge_rec.get('id2_dates')
    if clustering_config.use_values and yuid1_id_dates and yuid2_id_dates:
        id_value = edge_rec['id_value']

        # example:
        # multi-edge contains several edges of specific id type (e.g. emails)
        # each edge belongs to single email, e.g. email1 of (email1, email2)
        # compare dates of this edge email (e.g. email1)
        pair_dates = set()
        pair_dates.update(yuid1_id_dates[id_value], yuid2_id_dates[id_value])

        # with dates of all multi-edge emails (email1, email2)
        all_dates = set()
        all_dates.update(utils.flatten(dates.keys() for dates in yuid1_id_dates.values()))
        all_dates.update(utils.flatten(dates.keys() for dates in yuid2_id_dates.values()))

        # no dates at all means there is nothing to compare: keep the weight
        # instead of failing with ZeroDivisionError
        if all_dates:
            values_ratio = len(pair_dates) / float(len(all_dates))
            edge_weight *= values_ratio

    return edge_weight


def weight_multi_edge_old(multi_edge, _, clustering_config):
    """
    Legacy weighting of a multi-edge (all edges between the same pair of vertices).

    Each edge is expanded into single-source records; the weight of every record
    is divided by the number of records of the same source group in the
    multi-edge, so duplicated sources do not add up. Sets ``edge.weight``
    on every edge as a side effect.

    :param multi_edge: list of edges (iterated twice)
    :param _: unused (pair type, kept for signature parity with the new weighting)
    :param clustering_config: forwarded to the y_y single-source weighting
    :return: tuple (multi_edge, total weight of the multi-edge)
    """
    # TODO: remove when new weighting is enabled in prod
    def get_source_group(rec):
        if rec['pair_type'] == 'y_y':
            source_key = (rec['id_type'], rec['source_type'])
            known_groups = config.yuid_sources_groups
        else:
            source_key = rec['source_type']
            known_groups = config.d_y_source_groups
        # sources without a configured group form a group of their own
        return known_groups.get(source_key) or source_key

    # first pass: how many times every source group repeats in the multi-edge
    # (edges with multiple sources are counted once per source)
    repeats_per_group = defaultdict(int)
    for edge in multi_edge:
        for source_rec in graph_pairs.expands_source_types(edge.orig_rec):
            repeats_per_group[get_source_group(source_rec)] += 1

    # second pass: accumulate weights reduced by the repeats of the source group
    total_weight = 0.0
    for edge in multi_edge:
        edge.weight = 0
        for source_rec in graph_pairs.expands_source_types(edge.orig_rec):
            if edge.pair_type == 'y_y':
                reduced_weight = weight_single_source_y_y_edge_rec(source_rec, clustering_config)
            else:
                reduced_weight = weight_single_source_d_y_edge_rec(source_rec)

            reduced_weight /= float(repeats_per_group[get_source_group(source_rec)])

            edge.weight += reduced_weight
            total_weight += reduced_weight

    return multi_edge, total_weight


def weight_multi_edge_new(multi_edge, pair_type, clustering_config):
    """
    New weighting of a multi-edge (all edges between the same pair of vertices).
    See CRYPTAIS-1336.

    For 'y_y' multi-edges the left ('yuid1_sources') and right ('yuid2_sources')
    sides are weighted independently: on each side the single-source weight of
    an edge is divided by the number of occurrences of the source group on that
    side, and the smaller of the two side totals is the multi-edge weight.
    For any other pair type the edges are treated as direct and weighted as a
    whole by 'source_type' groups.

    Sets ``edge.weight`` on every edge as a side effect.

    :param multi_edge: list of edges (iterated twice)
    :param pair_type: 'y_y' or anything else (handled as d_y)
    :param clustering_config: forwarded to the y_y single-source weighting
    :return: tuple (multi_edge, multi-edge weight)
    """

    def source_group(known_groups, source_key):
        # sources without a configured group form a group of their own
        return known_groups.get(source_key) or source_key

    if pair_type == 'y_y':
        # y_y edges are in fact indirect y_someid_y edges,
        # thus we calculate left and right sources separately
        left, right = 'yuid1_sources', 'yuid2_sources'
        sides = (left, right)
        yuid_groups = config.yuid_sources_groups

        # firstly, need to count source repeats in each source group (per side)
        group_counts = {left: defaultdict(int), right: defaultdict(int)}
        for edge in multi_edge:
            rec = edge.orig_rec
            # only yuid-yuid can be cross source now
            for side in sides:
                for source in set(rec[side]):
                    group = source_group(yuid_groups, (rec['id_type'], source))
                    group_counts[side][group] += 1

        multi_weights = {left: 0, right: 0}
        for edge in multi_edge:
            rec = edge.orig_rec
            # The single-source weight does not depend on the source being
            # iterated, so it is computed at most once per edge (and only when
            # the edge has at least one source, as before).
            # NOTE(review): custom per-source weights are therefore looked up by
            # rec['source_type'], not by the iterated source - confirm intended.
            base_weight = None
            edge_weights = {left: 0, right: 0}
            for side in sides:
                for source in set(rec[side]):
                    if base_weight is None:
                        base_weight = weight_single_source_y_y_edge_rec(rec, clustering_config)
                    group = source_group(yuid_groups, (rec['id_type'], source))
                    # reduce edge weight by the count of source duplicates
                    reduced_weight = base_weight / float(group_counts[side][group])
                    edge_weights[side] += reduced_weight
                    multi_weights[side] += reduced_weight

            edge.weight = min(edge_weights[left], edge_weights[right])  # for debug only

        return multi_edge, min(multi_weights[left], multi_weights[right])

    # d_y edges are direct edges, thus they are counted as a whole
    d_y_groups = config.d_y_source_groups
    d_y_group_counts = defaultdict(int)
    for edge in multi_edge:
        d_y_group_counts[source_group(d_y_groups, edge.orig_rec['source_type'])] += 1

    d_y_multi_weight = 0
    for edge in multi_edge:
        rec = edge.orig_rec
        edge.weight = weight_single_source_d_y_edge_rec(rec)
        # reduce edge weight by the count of source duplicates
        edge.weight /= float(d_y_group_counts[source_group(d_y_groups, rec['source_type'])])
        d_y_multi_weight += edge.weight

    return multi_edge, d_y_multi_weight


# Columns requested from the edges table by cluster_vertices_locally;
# only these are read to reduce the memory footprint.
USED_COLUMNS = ['crypta_id', 'crypta_id_size',
                'id1', 'id2', 'key', 'ua_profile', 'browser',
                'id1_sex', 'id2_sex',
                'id1_region', 'id2_region',
                'id1_dates', 'id2_dates',
                'pair_source', 'pair_type', 'source_type',  'id_type', 'id_value',
                'yuid1_sources', 'yuid2_sources']

class Edge(object):
    """
    Single pair record (edge) between two identifiers.

    Identity of an edge is (id1, id2, pair_source): equality, hashing and
    ordering are all based on this key. ``weight`` is filled in later by the
    weighting functions.
    """

    def __init__(self, pair_rec):
        """
        :type pair_rec: dict
        """
        self.orig_rec = pair_rec
        self.id1 = pair_rec['id1']
        self.id2 = pair_rec['id2']
        self.pair_type = pair_rec['pair_type']
        self.pair_source = pair_rec['pair_source']
        self.weight = None

    def __key(self):
        return self.id1, self.id2, self.pair_source

    def __eq__(self, other):
        # anything that is not an Edge (including None) is simply not equal
        if not isinstance(other, Edge):
            return False
        return self.__key() == other.__key()

    def __ne__(self, other):
        # Python 2 does not derive != from ==, define it explicitly
        return not self.__eq__(other)

    def __lt__(self, other):
        # makes sorted(edges) deterministic and keeps edges of the same
        # (id1, id2) pair adjacent, which itertools.groupby relies on
        if not isinstance(other, Edge):
            return NotImplemented
        return self.__key() < other.__key()

    def __hash__(self):
        return hash(self.__key())

    def __str__(self):
        return str(self.orig_rec)

    def __repr__(self):
        return str(self.orig_rec)

    def get_all_yuids(self):
        """
        :return: list of yuids of the edge: only id2 for 'd_y', both ids for 'y_y'
        :raises Exception: for any other pair type
        """
        if self.pair_type == 'd_y':
            return [self.id2]
        elif self.pair_type == 'y_y':
            return [self.id1, self.id2]
        else:
            raise Exception('Unsupported pair_type %s' % self.pair_type)


class Component(object):
    """
    Group of edges treated as one candidate crypta id.

    Note: ``vertices_count`` is fixed at construction time; ``add_edges`` and
    ``remove_vertex`` change ``edges`` only and do not update it.
    """

    def __init__(self, edges, vertices_count, component_id=None):
        """
        :type edges: list of Edge
        :type vertices_count: int
        """
        self.edges = edges
        self.vertices_count = vertices_count
        self.component_id = component_id  # can't determine the final count of components before full hierarchical split

    def can_split(self, clustering_config):
        """
        :return: True if the component is larger than max_size_component or
                 (when cluster_multi_values_over_human_limit is enabled) holds
                 more distinct id values of some exact id type than its human limit
        """
        # definitely need to split large components
        if self.vertices_count > clustering_config.max_size_component:
            return True

        if not clustering_config.cluster_multi_values_over_human_limit:
            return False

        # can split if there are too many identifiers of certain type for a single component
        id_values_count_stats = self.get_edge_id_values_stats()
        for edge_type in config.YUID_PAIR_TYPES_EXACT:
            id_values_count = id_values_count_stats.get(edge_type.id_type)
            if id_values_count and edge_type.human_limit and id_values_count > edge_type.human_limit:
                return True

        return False

    def split(self, vertices_to_component_map):
        """
        Splits components edges to several components according to map
        :param vertices_to_component_map: vertex_id -> component_id mapping
        :return: tuple (list of new components, list of removed edges)
        """
        edges_per_component = defaultdict(list)
        vertices_per_component = defaultdict(set)  # to count component sizes
        removed_edges = list()
        for edge in self.edges:
            c1 = vertices_to_component_map.get(edge.id1)
            c2 = vertices_to_component_map.get(edge.id2)
            if c1 == c2:
                edges_per_component[c1].append(edge)
                vertices_per_component[c1].update((edge.id1, edge.id2))
            else:
                # vertices of the edge ended up in different components
                removed_edges.append(edge)

        new_components = [Component(component_edges, len(vertices_per_component[c_id]))
                          for c_id, component_edges in edges_per_component.items()]
        return new_components, removed_edges

    def add_edges(self, edges):
        self.edges.extend(edges)

    def get_multi_edges(self):
        """
        Groups edges connecting the same pair of vertices.
        :return: generator of ((id1, id2, pair_type), list of edges)
        """
        def multi_edge_key(edge):
            return edge.id1, edge.id2, edge.pair_type

        # groupby merges only adjacent items, thus edges must be sorted by the
        # very same key (sorting Edge objects without a key doesn't guarantee it)
        for key, multi_edge in itertools.groupby(sorted(self.edges, key=multi_edge_key),
                                                 key=multi_edge_key):
            yield key, list(multi_edge)

    def remove_vertex(self, id_key, cut_limit):
        """
        Removes single vertex from component by removing all belonging edges
        :param id_key: vertex id
        :param cut_limit: remove only if remaining number of edges >= cut_limit
        :return: number of removed edges
        """
        edges_after_cut = [e for e in self.edges if e.id1 != id_key and e.id2 != id_key]
        if len(edges_after_cut) < cut_limit:
            return 0

        removed_edges_count = len(self.edges) - len(edges_after_cut)
        self.edges = edges_after_cut
        return removed_edges_count

    def elect_new_crypta_id(self):
        """
        Elects the yuid with the smallest timestamp (last 10 chars of the yuid
        parsed as int) and converts it to a crypta id.
        :raises Exception: when no yuid of the component has a parsable timestamp
        """
        yuids_to_ts = dict()
        for edge in self.edges:
            for yuid in edge.get_all_yuids():
                try:
                    yuids_to_ts[yuid] = int(yuid[-10:])
                except ValueError:
                    # yuid without a numeric timestamp suffix can't be elected
                    pass

        if not yuids_to_ts:
            raise Exception('%s' % self.edges)

        min_ts_yuid = min(yuids_to_ts, key=yuids_to_ts.get)
        return crypta_id_generator.yuid_to_crypta_id(min_ts_yuid)

    def get_all_vertices_ids(self):
        """
        :return: set of all vertex ids (both ends of every edge)
        """
        yuids_and_devids = set()
        for edge in self.edges:
            yuids_and_devids.add(edge.id1)
            yuids_and_devids.add(edge.id2)

        return yuids_and_devids

    def get_edge_id_values_stats(self):
        """
        :return: dict id_type -> number of distinct id values among the edges
                 (only id types known to config.YUID_PAIR_TYPES_DICT are counted)
        """
        id_values_per_type = defaultdict(set)
        for edge in self.edges:
            id_type = edge.orig_rec.get('id_type')
            id_value = edge.orig_rec.get('id_value')
            if id_value and id_type in config.YUID_PAIR_TYPES_DICT:
                id_values_per_type[id_type].add(id_value)

        return {id_type: len(id_values) for id_type, id_values in id_values_per_type.items()}


def join_components_to_vertices(key, recs):
    """
    Reducer joining vertex records with the crypta ids assigned by clustering.

    Records are separated by mr.split_left_right into vertex records and
    new-crypta-id records:

    * both present: the first vertex record gets the new crypta id, size and
      component (the previous crypta id is saved in
      crypta_id_history['before_clustering']); if the old crypta ids of the two
      records disagree, an error record goes to output table 1 instead;
    * only vertex records: they are passed through unchanged;
    * only new-crypta-id records: the first one goes to output table 2.
    """
    vertices_recs, new_crypta_id_recs = mr.split_left_right(recs, oom_check=False)

    if not (new_crypta_id_recs and vertices_recs):
        if vertices_recs:
            # vertex was not touched by clustering
            for untouched_rec in vertices_recs:
                yield untouched_rec
        else:
            # non-existing vertex found after clustering: report instead of failing
            orphan_rec = new_crypta_id_recs[0]
            orphan_rec['@table_index'] = 2
            yield orphan_rec
        return

    vertex = vertices_recs[0]
    assignment = new_crypta_id_recs[0]

    old_crypta_id = vertex['crypta_id']
    if old_crypta_id == assignment['old_crypta_id']:
        if 'crypta_id_history' not in vertex:
            vertex['crypta_id_history'] = dict()
        vertex['crypta_id_history']['before_clustering'] = vertex['crypta_id']
        vertex['crypta_id'] = assignment['new_crypta_id']
        vertex['crypta_id_size'] = assignment['new_crypta_id_size']
        vertex['component'] = assignment['component']
        yield vertex
    else:
        mismatch = 'Crypta id %s != %s' % (old_crypta_id, assignment['old_crypta_id'])
        yield {'error': mismatch, '@table_index': 1}


def reduce_split_crypta_id_louvain(crypta_key, edge_recs,
                                   clustering_config,
                                   local_mode=False):
    """
    Reducer trying to split a single crypta id into several components
    using hierarchical Louvain clustering (up to 4 iterations).

    Output table indices (the order matches the out tables of cluster_vertices):
      0 - remaining edges (out edges table)
      1 - new crypta id per vertex ('new_crypta_ids')
      2 - removed edges ('removed_pairs')
      3 - quality metrics before split ('quality_metrics/origin')
      4 - quality metrics after split ('quality_metrics/split')
      5 - corner value reassign stats ('corner_reassigns')
      6 - stats of splits rejected by the overmatching penalty ('overspliced_too_much')

    :param crypta_key: reduce key with 'crypta_id' and 'crypta_id_size'
    :param edge_recs: iterable of edge records of this crypta id
    :param clustering_config: clustering settings
    :param local_mode: when True, the table index is written to 'tmp_table_index'
                       instead of '@table_index' and debug messages are logged
    """

    def with_table_index(rec, table_index):
        """
        When running in local mode, yt doesn't support @table_index notation.
        Thus need to spread among tables manually
        """
        if local_mode:
            rec['tmp_table_index'] = table_index
        else:
            rec['@table_index'] = table_index
        return rec

    def yield_no_split(edges):
        # keep the crypta id as is: a single component '0'
        for edge in edges:
            if edge.weight:
                edge.orig_rec['weight'] = edge.weight
            edge.orig_rec['component'] = '0'
            edge.orig_rec['components_count'] = 1
            edge.orig_rec['old_crypta_id'] = edge.orig_rec['crypta_id']
            edge.orig_rec['old_crypta_id_size'] = edge.orig_rec['crypta_id_size']
            yield edge.orig_rec

    def debug_log(msg):
        # can't log at yt cluster
        if local_mode:
            logging.info(msg)

    def cluster_component(main_component):
        debug_log('Calculating edge weights...')
        # Firstly calculate all edges weights for later performance
        multi_edges_weights = defaultdict(int)
        multi_edges_is_strong = dict()
        pair_types = [devid_pair_type.source_type for devid_pair_type in config.DEVID_PAIR_TYPES_PERFECT]
        for (id1, id2, pair_type), multi_edge in main_component.get_multi_edges():
            if clustering_config.new_multi_source_approach:
                _, multi_edge_weight = weight_multi_edge_new(multi_edge, pair_type, clustering_config)
            else:
                _, multi_edge_weight = weight_multi_edge_old(multi_edge, pair_type, clustering_config)

            # here all single weights are calculated in edge.weight for every edge
            multi_edges_weights[(id1, id2)] = multi_edge_weight

            # set 'strong' (i.e. unbreakable) field; makes sense only for perfect 'd_y' indevice pairs
            multi_edges_is_strong[(id1, id2)] = (pair_type == 'd_y' and
                                        any((edge.orig_rec['source_type'] in pair_types and
                                            config.DEVID_PAIR_TYPES_PERFECT[pair_types.index(edge.orig_rec['source_type'])].strong)
                                                for edge in multi_edge))

        debug_log('Calculated edge weights for %d multi-edges...' % len(multi_edges_weights))
        components = [main_component]
        removed_edges = set()
        # hierarchically split large components with max 4 iteration in hierarchy
        for iter_idx in range(4):
            debug_log('Iter %s, components to split: %s' % (iter_idx, len(components)))
            next_iter_components = list()
            for c_idx, component in enumerate(components):
                if component.can_split(clustering_config):
                    # split component if it's large enough...
                    G = nx.Graph()

                    debug_log('Iter %s: splitting component %s with size %d (%d edges)' %
                              (iter_idx, c_idx, component.vertices_count, len(component.edges)))

                    if clustering_config.collapse_multi_edges:
                        # get_multi_edges() keys are (id1, id2, pair_type) triples;
                        # pair type is not needed to build the graph
                        for (id1, id2, _), multi_edge in component.get_multi_edges():
                            G.add_edge(id1, id2, weight=multi_edges_weights[(id1, id2)],
                                       strong=multi_edges_is_strong[(id1, id2)])
                    else:
                        for edge in component.edges:
                            G.add_edge(edge.id1, edge.id2, weight=edge.weight,
                                       strong=multi_edges_is_strong[(edge.id1, edge.id2)])

                    start = time.time()
                    component_partition = nx_community.best_partition(G)

                    component_split, iter_removed_edges = component.split(component_partition)
                    next_iter_components.extend(component_split)
                    removed_edges.update(iter_removed_edges)

                    end = time.time()
                    debug_log('Clustered into %d components (removed %d edges) [%ss]' %
                              (len(component_split),
                               len(iter_removed_edges),
                               end - start))

                else:
                    # ... or keep it as is every next iteration
                    next_iter_components.append(component)  # keep as is

            components = next_iter_components
            debug_log('Current components count = %d' % len(components))

        # can't determine the final count of components before full hierarchical split
        for c_idx, component in enumerate(components):
            component.component_id = c_idx

        return components, removed_edges

    old_crypta_id = crypta_key['crypta_id']
    old_crypta_id_size = crypta_key['crypta_id_size']

    # small crypta ids are never split
    if old_crypta_id_size <= clustering_config.min_split_threshold:
        for r in yield_no_split(Edge(edge_rec) for edge_rec in edge_recs):
            yield with_table_index(r, 0)
        return

    # original crypta_id itself
    top_component = Component([Edge(edge) for edge in edge_recs], old_crypta_id_size)
    debug_log('Created top component with %d edges for %s' % (len(top_component.edges), old_crypta_id))
    top_component.component_id = 0

    vertices_by_components = edges_to_vertices_format([top_component])
    qm_origin = graph_quality_metrics.clustering_quality_metrics(vertices_by_components)
    qm_origin['crypta_id'] = old_crypta_id
    qm_origin['edges_count'] = len(top_component.edges)
    yield with_table_index(qm_origin, 3)

    debug_log('Calculated penalties before split for %s' % old_crypta_id)

    split_components, removed_edges = cluster_component(top_component)
    components_count = len(split_components)
    debug_log('crypta id %s split into %d components' % (old_crypta_id, len(split_components)))

    if clustering_config.reassign_corner_values:
        split_components, reassign_stats = reassign_corner_values(split_components, removed_edges)
        for id_key, old_c, new_c, removed_count in reassign_stats:
            stat_rec = dict(crypta_key)
            stat_rec.update({'id': id_key,
                             'old_c': old_c.component_id,
                             'new_c': new_c.component_id,
                             'removed_count': removed_count})
            yield with_table_index(stat_rec, 5)

    vertices_by_components = edges_to_vertices_format(split_components)
    qm_split = graph_quality_metrics.clustering_quality_metrics(vertices_by_components)
    qm_split['crypta_id'] = old_crypta_id
    qm_split['removed_pairs_count'] = len(removed_edges)
    qm_split['pairs_count'] = len(top_component.edges)
    qm_split['removed_percent'] = len(removed_edges) / float(len(top_component.edges))
    yield with_table_index(qm_split, 4)

    debug_log('Calculated penalties for split for %s' % old_crypta_id)

    # if the split doesn't reduce overmatching, let's keep the whole crypta_id
    if qm_split['overmatching_penalty'] > qm_origin['overmatching_penalty']:
        for r in yield_no_split(top_component.edges):
            yield with_table_index(r, 0)
            # NOTE(review): the stat record below is emitted once per edge, not
            # once per crypta id - looks unintended, confirm before changing
            stat_rec = dict(crypta_key)
            stat_rec.update({
                'orig_overmatching_penalty': qm_origin['overmatching_penalty'],
                'split_overmatching_penalty': qm_split['overmatching_penalty'],
                'oversplicing_penalty': qm_split['oversplicing_penalty']
            })
            yield with_table_index(stat_rec, 6)
        return

    # for every new component yield vertices and edges
    for component in split_components:
        new_crypta_id = component.elect_new_crypta_id()
        component_id = component.component_id
        debug_log('yielding component [%s, %s]' % (old_crypta_id, component_id))

        for edge in component.edges:
            edge_rec = edge.orig_rec

            edge_rec['weight'] = edge.weight
            edge_rec['component'] = str(component_id)
            edge_rec['components_count'] = components_count
            edge_rec['crypta_id'] = new_crypta_id
            edge_rec['crypta_id_size'] = component.vertices_count
            edge_rec['old_crypta_id'] = old_crypta_id
            edge_rec['old_crypta_id_size'] = old_crypta_id_size
            yield with_table_index(edge_rec, 0)  # remaining pairs

        for vertex_id in component.get_all_vertices_ids():
            # need this to map old-style vertices to new crypta_ids
            new_crypta_id_rec = {'key': vertex_id,
                                 'old_crypta_id': old_crypta_id,
                                 'old_crypta_id_size': old_crypta_id_size,
                                 'new_crypta_id': new_crypta_id,
                                 'new_crypta_id_size': component.vertices_count,
                                 'component': str(component_id),
                                 'components_count': components_count}
            yield with_table_index(new_crypta_id_rec, 1)

    debug_log('yielding removed edges for %s' % old_crypta_id)
    for edge in removed_edges:
        edge_rec = edge.orig_rec
        edge_rec['weight'] = edge.weight
        yield with_table_index(edge_rec, 2)  # removed pairs


def index_components(pairs):
    """
    Index connected components of the graph built from pairs.

    :param pairs: iterable of records with 'id1' and 'id2'
    :return: dict node id -> index of the connected component containing it
    """
    import networkx as nx

    graph = nx.Graph()
    graph.add_edges_from((pair['id1'], pair['id2']) for pair in pairs)

    return {node: component_idx
            for component_idx, component_nodes in enumerate(nx.connected_components(graph))
            for node in component_nodes}


def find_best_component(values, components_of_values):
    """
    Best component is one which contains more occurrences of values
    """
    sum_per_component = defaultdict(int)
    for v in values:
        for component, value_counts in components_of_values[v].iteritems():
            sum_per_component[component] += value_counts
    if len(sum_per_component) > 1:
        sum_per_component_items = sorted(sum_per_component.iteritems(), key=lambda x: x[1], reverse=True)
        best_component_0, max_total_value_counts_0 = sum_per_component_items[0]
        best_component_1, max_total_value_counts_1 = sum_per_component_items[1]
        if max_total_value_counts_0 > max_total_value_counts_1:
            return best_component_0
        else:
            return None
    else:
        return None


def reassign_corner_values(components, removed_edges):
    """
    Clustering splits crypta id based on edges density.
    Sometimes we create all-to-all yuid pairs which lead to unreasonably high density in this yuid group.
    It may cause incorrect split where one of split yuids goes to the wrong component.
    Here we are trying to reassign these yuids back to the component which they fit best.

    Components are modified in place (edges are added back / vertices removed).

    :param components: list of components after split
    :param removed_edges: iterable of edges removed by the split
    :return: tuple (components, list of (vertex id, old component, new component,
             number of removed edges))
    """
    value_component_freq = defaultdict(lambda: defaultdict(int))
    vertices_to_component = dict()

    for component in components:
        for edge in component.edges:
            value = edge.orig_rec.get('id_value')
            if value:
                value_component_freq[value][component] += 1

                # NOTE(review): vertices are mapped only for edges having an
                # id_value - confirm this is intended
                vertices_to_component[edge.id1] = component
                vertices_to_component[edge.id2] = component

    reassign_stats = []

    def vertices_pair(edge):
        return edge.id1, edge.id2

    # groupby merges only adjacent items, thus removed edges must be sorted
    # by the very same key they are grouped by
    removed_edges = sorted(removed_edges, key=vertices_pair)
    # trying to find best component for removed pairs
    for (id1, id2), multi_edges in itertools.groupby(removed_edges, key=vertices_pair):
        multi_edges = list(multi_edges)
        edge_values = set()
        for edge in multi_edges:
            value = edge.orig_rec.get('id_value')
            if value:
                edge_values.add(value)

        best_c = find_best_component(edge_values, value_component_freq)
        if not best_c:
            continue

        c1 = vertices_to_component.get(id1)
        c2 = vertices_to_component.get(id2)

        # if this pair was taken off from its best component, we can put it back
        # and cut the vertex of the other end from its (large enough) component
        try:
            for keep_c, cut_c, cut_id in ((c1, c2, id2), (c2, c1, id1)):
                if keep_c == best_c and cut_c and cut_c.vertices_count > 2:
                    # put it back
                    keep_c.add_edges(multi_edges)
                    # make split in another place
                    removed_count = cut_c.remove_vertex(cut_id, cut_limit=1)
                    if removed_count:
                        reassign_stats.append((cut_id, cut_c, keep_c, removed_count))
                    break
        except Exception as e:
            # add context but keep the original error visible
            raise Exception('%s, %s, %s: %r' % (id1, id2, vertices_to_component, e))

    return components, reassign_stats


def map_old_crypta_id(rec):
    """
    Mapper restoring the pre-clustering crypta id as the main one.

    The current 'crypta_id' / 'crypta_id_size' move to 'new_crypta_id' /
    'new_crypta_id_size', while the 'old_*' fields take their place and are
    removed from the record. The record is modified in place and yielded.
    """
    rec['new_crypta_id'] = rec['crypta_id']
    rec['new_crypta_id_size'] = rec['crypta_id_size']
    rec['crypta_id'] = rec.pop('old_crypta_id')
    rec['crypta_id_size'] = rec.pop('old_crypta_id_size')
    yield rec


def split_by_table_index(rec):
    """
    Mapper turning the manual 'tmp_table_index' marker into yt '@table_index'.

    Records without the marker go to table 0. The marker is removed from the
    record; the record is modified in place and yielded.
    """
    rec['@table_index'] = rec.pop('tmp_table_index', 0)
    yield rec


def cluster_vertices_locally(in_edges_table, tmp_yt_folder, clustering_config):
    """
    Cluster crypta ids of the edges table on the local machine.

    Edge records are streamed from YT (only USED_COLUMNS), grouped by
    consecutive (crypta_id, crypta_id_size) and passed to
    reduce_split_crypta_id_louvain in local mode; the result of every crypta id
    is written to the table ``tmp_yt_folder + crypta_id``.

    :param in_edges_table: table with edges; grouping is done over consecutive
                           records, so records of one crypta id must be adjacent
    :param tmp_yt_folder: path prefix for the per crypta id output tables
    :param clustering_config: clustering settings
    :return: list of processed crypta ids
    """
    logging.info('Reading large crypta_id to memory to split locally')
    recs_count = yt.row_count(in_edges_table)
    logging.info('Reading %d recs...' % recs_count)

    # reading only required columns to reduce memory footprint
    edges_path = yt.TablePath(in_edges_table, columns=USED_COLUMNS)
    recs = yt.read_table(edges_path, raw=False)
    logging.info('Fetching YT iterator...')

    def crypta_id_and_size(rec):
        return rec['crypta_id'], rec['crypta_id_size']

    large_crypta_ids = list()
    for (crypta_id, crypta_id_size), crypta_id_recs in itertools.groupby(recs, key=crypta_id_and_size):
        logging.info('Trying to cluster crypta id %s' % crypta_id)
        large_crypta_ids.append(crypta_id)

        out_recs = reduce_split_crypta_id_louvain(
            {'crypta_id': crypta_id, 'crypta_id_size': crypta_id_size},
            crypta_id_recs,
            clustering_config=clustering_config,
            local_mode=True)

        yt.write_table(tmp_yt_folder + crypta_id, out_recs, raw=False)
        logging.info('Uploaded components to tmp_table')

        # drop the reference as soon as the crypta id is uploaded
        del out_recs

    del recs
    logging.info('Done local clustering')

    return large_crypta_ids


def cluster_vertices(in_vertices_table, in_edges_table,
                     cluster_vertices_folder,
                     out_vertices_table, out_edges_table,
                     clustering_config):
    """
    Run the whole clustering pipeline for one vertices/edges pair.

    Steps, in order:
      1. reduce ``in_edges_table`` by crypta id with
         ``reduce_split_crypta_id_louvain`` into the output tables;
      2. if ``clustering_config.local_clustering_enabled``, cluster
         ``in_edges_table + '_large'`` on the client and append the result
         to the same output tables;
      3. compute average values of every penalty column of the
         ``quality_metrics/origin`` and ``quality_metrics/split`` tables;
      4. join ``new_crypta_ids`` with ``in_vertices_table`` into
         ``out_vertices_table``;
      5. sort the quality metrics, ``out_edges_table`` and ``out_vertices_table``.

    :param in_vertices_table: original vertices table (assumed sorted by key)
    :param in_edges_table: original edges table
    :param cluster_vertices_folder: folder (with trailing slash) for all produced tables
    :param out_vertices_table: destination of the clustered vertices
    :param out_edges_table: destination of the clustered edges
    :param clustering_config: clustering params, see ClusteringConfig
    """
    metrics_origin = cluster_vertices_folder + 'quality_metrics/origin'
    metrics_split = cluster_vertices_folder + 'quality_metrics/split'
    new_crypta_ids_table = cluster_vertices_folder + 'new_crypta_ids'
    removed_pairs_table = cluster_vertices_folder + 'removed_pairs'

    mr.mkdir(cluster_vertices_folder)
    mr.mkdir(cluster_vertices_folder + 'quality_metrics')

    # this list is handed to YT as the destination tables of the reduce and
    # of the append-map below; keep the order unchanged
    out_tables = [
        out_edges_table,
        new_crypta_ids_table,
        removed_pairs_table,
        metrics_origin,
        metrics_split,
        cluster_vertices_folder + 'corner_reassigns',
        cluster_vertices_folder + 'overspliced_too_much',
    ]

    # cluster small vertices at YT
    reducer = partial(reduce_split_crypta_id_louvain, clustering_config=clustering_config)
    yt.run_reduce(reducer,
                  in_edges_table,
                  out_tables,
                  reduce_by=['crypta_id', 'crypta_id_size'],
                  # this sorting is required for groupby inside the reduce to work correctly
                  sort_by=['crypta_id', 'crypta_id_size', 'id1', 'id2'])

    if clustering_config.local_clustering_enabled:
        large_folder = cluster_vertices_folder + 'large'
        for suffix in ('', '/quality_metrics', '/local_splits'):
            mr.mkdir(large_folder + suffix)

        # cluster large vertices locally
        # assume sorted by ['crypta_id', 'crypta_id_size', 'id1', 'id2']
        local_splits_folder = cluster_vertices_folder + 'large/local_splits/'
        clustered_crypta_ids = cluster_vertices_locally(in_edges_table + '_large',
                                                        local_splits_folder,
                                                        clustering_config)
        if clustered_crypta_ids:
            # yt write op doesn't support writes to several tables, so split it manually
            local_split_tables = [local_splits_folder + crypta_id
                                  for crypta_id in clustered_crypta_ids]
            append_targets = [yt.TablePath(t, append=True) for t in out_tables]
            yt.run_map(split_by_table_index, local_split_tables, append_targets)

    # start all averaging operations without waiting, then wait for the whole batch
    penalty_metrics = ('penalty', 'overmatching_penalty', 'oversplicing_penalty',
                       'split_id_values_penalty', 'split_ratio_penalty',
                       'browsers_penalty', 'devices_penalty', 'regions_penalty',
                       'sex_penalty', 'mobile_only_penalty')
    ops = []
    for penalty_metric in penalty_metrics:
        for metrics_table in (metrics_origin, metrics_split):
            avg_table = metrics_table + '_' + penalty_metric + '_avg'
            ops.append(mr.avg_column(metrics_table, avg_table,
                                     column=penalty_metric, sync=False))
    utils.wait_all(ops)

    # assign new clustered crypta_ids to old-style vertices
    # assume orig vertices are sorted
    mr.sort_all([new_crypta_ids_table, removed_pairs_table], sort_by='key')
    join_outputs = [
        out_vertices_table,
        # exist in pairs, missing in vertices, usually because of OOM crypta ids
        cluster_vertices_folder + 'seems_missing_in_vertices',
        cluster_vertices_folder + 'very_strange',
    ]
    yt.run_reduce(join_components_to_vertices,
                  [in_vertices_table, new_crypta_ids_table],
                  join_outputs,
                  reduce_by='key')

    final_sorts = [
        (metrics_origin, 'overmatching_penalty'),
        (metrics_split, 'oversplicing_penalty'),
        (out_edges_table, ['key']),  # join vertices id later
        (out_vertices_table, 'key'),
    ]
    utils.wait_all([yt.run_sort(table, sort_by=columns, sync=False)
                    for table, columns in final_sorts])


class ClusteringConfig(object):
    """
    Plain container for every tunable of the clustering algorithm.

    The string form of an instance (both ``str`` and ``repr``) is the string
    form of its attribute dict.
    """
    def __init__(self, local_clustering_enabled=True,
                 min_split_threshold=5,
                 max_size_component=15,
                 collapse_multi_edges=False,
                 new_multi_source_approach=True,
                 reassign_corner_values=False,
                 cluster_multi_values_over_human_limit=True,
                 use_region=True, use_sex=True, use_values=False):
        """
        :param local_clustering_enabled: huge vertices are split at client side
        :param min_split_threshold: smallest crypta id size for which clustering is applied
        :param max_size_component: components bigger than this are clustered hierarchically further
        :param collapse_multi_edges: several edges between two ids are represented by one merged edge
        :param new_multi_source_approach: source repeats like login and email from login are not counted
        :param reassign_corner_values: in case of controversial split try to move vertex to correct component
        :param cluster_multi_values_over_human_limit: cluster one more time when human limit of ids
                                                      per crypta_id is exceeded
        :param use_region: edge weight is reduced when two vertices are in different regions
        :param use_sex: edge weight is reduced when two vertices are of different gender
        :param use_values: edge weight is reduced based on multi-login (not only) dates
        """
        # assignment order is kept stable because it defines the attribute dict
        settings = (
            ('local_clustering_enabled', local_clustering_enabled),
            ('min_split_threshold', min_split_threshold),
            ('max_size_component', max_size_component),
            ('collapse_multi_edges', collapse_multi_edges),
            ('new_multi_source_approach', new_multi_source_approach),
            ('reassign_corner_values', reassign_corner_values),
            ('cluster_multi_values_over_human_limit', cluster_multi_values_over_human_limit),
            ('use_region', use_region),
            ('use_sex', use_sex),
            ('use_values', use_values),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)

    def __repr__(self):
        return str(vars(self))

    # str() and repr() intentionally give the same text
    __str__ = __repr__


class ClusterVertices(graph_vertices_base.BaseVerticesTask):
    """
    Vertices task that clusters the vertices produced by another vertices task.

    The source vertices are described by ``orig_vertices_config``; the results
    are stored in a sub-folder of the source relative path named after ``name``.
    """
    orig_vertices_config = luigi.Parameter(description='vertices config of vertices that need to be clustered')
    clustering_config = luigi.Parameter()
    name = luigi.Parameter(default=config.VERTICES_TYPE_CLUSTER)

    def __init__(self, orig_vertices_config, clustering_config, name=config.VERTICES_TYPE_CLUSTER):
        # the config of clustered vertices is derived from the original one
        clustered_relative_path = orig_vertices_config.relative_path + name + '/'
        clustered_vertices_type = orig_vertices_config.vertices_type + '_' + name
        clustered_config = graph_vertices_base.VerticesConfig(
            clustered_relative_path,
            clustered_vertices_type,
            orig_vertices_config.date,
            producing_task=self,
            base_path=orig_vertices_config.base_path)

        super(ClusterVertices, self).__init__(clustered_config,
                                              orig_vertices_config=orig_vertices_config,
                                              clustering_config=clustering_config,
                                              name=name)

    def requires(self):
        # base requirements plus the task that produces the original vertices
        base_requirements = super(ClusterVertices, self).requires()
        return base_requirements + [self.orig_vertices_config.producing_task]

    def create_vertices(self, in_pairs_table, out_vertices_table, out_edges_table):
        """
        Cluster original vertices and edges, then build the ``edges_orig`` table.

        ``in_pairs_table`` is not used: input is taken from the vertices folder
        of ``orig_vertices_config``.
        """
        source_folder = self.orig_vertices_config.get_vertices_folder()
        target_folder = self.vertices_config.get_vertices_folder()
        edges_orig_table = target_folder + 'edges_orig'

        cluster_vertices(in_vertices_table=source_folder + 'vertices',
                         in_edges_table=source_folder + 'edges',
                         cluster_vertices_folder=target_folder,
                         out_vertices_table=out_vertices_table,
                         out_edges_table=out_edges_table,
                         clustering_config=self.clustering_config)

        # to show clustering in crypta viewer in different way
        yt.run_map(map_old_crypta_id, out_edges_table, edges_orig_table)
        yt.run_sort(edges_orig_table,
                    sort_by=['crypta_id', 'crypta_id_size', 'id1', 'id2'])

    def run(self):
        vertices_folder = self.out_f('vertices_folder')
        mr.mkdir(vertices_folder)
        # cluster vertices are not created from usual pairs, so no pairs table is passed
        self.vertices_pipeline(None, vertices_folder)

    def output(self):
        vertices_folder = self.out_f('vertices_folder')
        produced_tables = ('vertices', 'edges', 'edges_orig', 'quality_metrics/split')
        return [yt_luigi.YtTarget(vertices_folder + table) for table in produced_tables]

class AggregateTask(luigi.WrapperTask):
    """
    Wrapper task whose only requirement is the value of the ``tasks`` parameter.
    """
    tasks = luigi.Parameter()

    def requires(self):
        # the parameter value is wrapped into a one-element list as is
        wrapped = [self.tasks]
        return wrapped


if __name__ == '__main__':
    # Script mode: point the YT client to the configured server and switch
    # tabular data to YSON with table index processing enabled.
    yt.config.set_proxy(config.MR_SERVER)
    yson_with_table_index = yt.YsonFormat(process_table_index=True)
    yt.config["tabular_data_format"] = yson_with_table_index
