import itertools
from functools import partial

import luigi
import yt.wrapper as yt

import graph_pair_utils
from lib.luigi import yt_luigi
from matching.yuid_matching import graph_dict
from matching.yuid_matching.enrich import org_emails_classify
from rtcconf import config
from utils import mr_utils as mr
from utils import utils

# Cap used by reduce_pairs_by_limit when an id exceeds the strict yuid limit:
# enumeration of that id's yuids stops once the index goes above this value.
# Presumably guards the reducer/output against huge ids, as the name suggests.
OOM_CHECK = 1000


def expands_source_types(rec):
    """Expand a pair record into one record per source type.

    For a ``y_y`` pair the sources of both yuids are intersected:

    * if they share sources, one record is yielded per common source, with
      ``source_type`` set to that source (in sorted order);
    * otherwise a single record is yielded with ``source_type`` set to
      ``config.PAIR_TYPE_CROSS_SOURCE``.

    A ``d_y`` pair is yielded unchanged.

    Every expanded record is an independent shallow copy of ``rec``; the input
    is not mutated, so it is safe to collect the generator's output in a list.

    :param rec: dict with ``pair_type`` and, for ``y_y`` pairs,
        ``yuid1_sources`` / ``yuid2_sources`` iterables
    :raises Exception: when ``pair_type`` is neither ``y_y`` nor ``d_y``
        (raised lazily, on first iteration)
    """
    pair_type = rec['pair_type']

    if pair_type == 'd_y':
        # only yuid-yuid can be cross source now
        yield rec
        return

    if pair_type != 'y_y':
        raise Exception('Unsupported pair_type %s' % pair_type)

    common_sources = set(rec['yuid1_sources']).intersection(rec['yuid2_sources'])
    if not common_sources:
        yield dict(rec, source_type=config.PAIR_TYPE_CROSS_SOURCE)
        return

    # A fresh dict per source: re-yielding one mutated dict would make all
    # collected records alias the same object (and the same source_type).
    for source_type in sorted(common_sources):
        yield dict(rec, source_type=source_type)


def get_n_most_active(yuids_dates, n):
    """Return up to ``n`` yuids with the largest activity values.

    :param yuids_dates: dict mapping yuid -> comparable activity value
        (callers in this file store a number of days there)
    :param n: maximum number of yuids to return
    :return: list of yuids ordered by descending activity value
    """
    # items() instead of iteritems(): works on both Python 2 and Python 3
    top_activity = sorted(yuids_dates.items(), key=lambda kv: kv[1], reverse=True)
    return [yuid for yuid, _ in top_activity[:n]]


def reduce_pairs_by_limit(id_value_key, recs, pair_type, store_days):
    """Reducer: turn the yuids that share one id value into yuid-yuid pairs.

    Every yielded record carries an ``@table_index`` that routes it to one of
    the output tables, in the order they are passed to ``yt.run_reduce`` by
    ``GraphPairs.run``:

        0 - pairs built from the yuids kept under the soft limit
        1 - yuids of an id dropped by the strict limit (enumeration is capped
            by ``OOM_CHECK``)
        2 - yuids dropped by the soft limit
        3 - the id itself when it is dropped by the strict limit
        4 - the yuid of an id that has exactly one yuid
        5 - yuids of an id that has at least one organization record
        6 - yuids skipped by the private mode filter

    :param id_value_key: reduce key; only ``id_value`` is read from it
    :param recs: records of this key. Treated as a one-shot iterator: records
        consumed by the main loop are not re-read from it later.
    :param pair_type: pair type descriptor; ``id_type``, ``filter_private_mode``,
        ``is_aggregate()``, ``source_types`` and ``name()`` are used
    :param store_days: forwarded to ``graph_dict.calculate_freq_weights``
    """
    pairs_table = 0
    strict_limit_yuids_table = 1
    soft_limit_yuids_table = 2
    strict_limit_ids_table = 3
    single_yuids_table = 4
    orgs_yuids_table = 5
    private_yuids_table = 6

    id_type = pair_type.id_type
    id_value = id_value_key['id_value']
    id_type_config = config.YUID_PAIR_TYPES_DICT[id_type]
    yuids_per_id_strict_limit = id_type_config.yuids_per_id_strict_limit
    yuids_per_id_soft_limit = id_type_config.yuids_per_id_soft_limit

    active_yuids = 0
    active_yuids_limit_exceed = False
    yuids_dates = dict()
    yuids_sources = dict()
    id_sources = set()
    organization = 0
    yuid_activities = list()
    private_mode = list()

    for rec in recs:
        yuid = rec['yuid']

        if pair_type.filter_private_mode and rec['ip_activity_type'] != 'active':
            private_mode.append(yuid)
            # TODO: think of how it would affect limits below
            continue

        organization += int(rec.get('organization', 0))

        yuid_hit_dates = len(rec['all_ip_dates'])
        yuids_dates[yuid] = yuid_hit_dates
        # hardcode: assume active yuid has more than 1 day of activity;
        # webview yuids are not counted as active
        if not rec.get('webview', False) and yuid_hit_dates > 1:
            active_yuids += 1

        if pair_type.is_aggregate():
            yuid_sources = rec[id_type + '_sources'][id_value]  # one of possible source types
        else:
            yuid_sources = pair_type.source_types  # only possible single source type

        yuids_sources[yuid] = yuid_sources
        id_sources.update(yuid_sources)

        yuid_id_dates = rec[pair_type.name() + '_dates'][id_value]
        yuid_activities.append(graph_dict.IdActivity(yuid, [], dates_activity=yuid_id_dates))

        if active_yuids > yuids_per_id_strict_limit:
            # stop reading early: the rest of recs stays unconsumed
            active_yuids_limit_exceed = True
            break

    def yuid_rec(yuid, table_index):
        return {'key': yuid, 'id_value': id_value,
                'id_type': id_type, 'source_types': yuids_sources.get(yuid),
                '@table_index': table_index}

    def id_value_rec(yuids_count, table_index):
        return {'id_value': id_value, 'yuids_count': yuids_count,
                'id_type': id_type, 'source_types': list(id_sources), '@table_index': table_index}

    # TODO: unify with graph_pretty
    def pair_rec(yuid1, yuid2, table_index):
        return graph_pair_utils.yuid_yuid_pair_rec(
            yuid1=yuid1,
            yuid2=yuid2,
            id_value=id_value,
            id_type=id_type,
            pair_type=pair_type,
            yuid1_sources=yuids_sources.get(yuid1),
            yuid2_sources=yuids_sources.get(yuid2),
            yuid1_dates=yuids_dates[yuid1],
            yuid2_dates=yuids_dates[yuid2],
            table_index=table_index
        )

    def seen_and_rest_yuids():
        # recs has already been advanced by the main loop, so the yuids read
        # so far are only available from yuids_dates; whatever is left in recs
        # (after an early break) is streamed lazily behind them.
        return itertools.chain(yuids_dates, (r['yuid'] for r in recs))

    for yuid in private_mode:
        yield yuid_rec(yuid, private_yuids_table)

    if organization > 0:
        # yuids with organization id.
        # Iterating recs alone here would skip every record the main loop has
        # already consumed, so the seen yuids are emitted first.
        for yuid in seen_and_rest_yuids():
            yield yuid_rec(yuid, orgs_yuids_table)

    elif active_yuids_limit_exceed:
        yuids_count = 0
        for idx, yuid in enumerate(seen_and_rest_yuids()):
            if idx > OOM_CHECK:
                break
            yuids_count += 1
            # to estimate the number of yuids thrown by strict limit
            yield yuid_rec(yuid, strict_limit_yuids_table)

        # to estimate the number of ids thrown by strict limit
        yield id_value_rec(yuids_count, strict_limit_ids_table)

    elif len(yuids_dates) == 1:
        # to estimate the number of single-yuid ids
        yuid = next(iter(yuids_dates))
        yield yuid_rec(yuid, single_yuids_table)

    elif len(yuids_dates) > 1:
        # get top most active below the soft limit
        # NOTE(review): relies on calculate_freq_weights returning a sliceable
        # sequence of (activity, weight) pairs, presumably ordered from the
        # most to the least relevant - confirm against graph_dict.
        activities_weights = graph_dict.calculate_freq_weights(yuid_activities, store_days)

        top_yuids = sorted(
            yuid_a.id_value for yuid_a, weight in activities_weights[:yuids_per_id_soft_limit])
        soft_yuids_thrown = sorted(
            yuid_a.id_value for yuid_a, weight in activities_weights[yuids_per_id_soft_limit:])

        for yuid in soft_yuids_thrown:
            # to estimate the number of yuids thrown by soft limit
            yield yuid_rec(yuid, soft_limit_yuids_table)

        # and make pairs of it
        for yuid1, yuid2 in itertools.combinations(top_yuids, 2):
            yield pair_rec(yuid1, yuid2, pairs_table)


class GraphPairs(yt_luigi.BaseYtTask):
    """Luigi task that runs reduce_pairs_by_limit over every configured pair type.

    For each pair type from ``config.YUID_PAIR_TYPES_EXACT`` and
    ``config.YUID_PAIR_TYPES_FOR_YUID_WITH_ALL`` it reduces the
    ``yuid_with_id_<pair name>`` dict table by ``id_value`` into seven output
    tables under the ``pairs`` folder of the given date.
    """

    date = luigi.Parameter()

    def input_folders(self):
        return {'dict': config.GRAPH_YT_DICTS_FOLDER}

    def output_folders(self):
        return {'pairs': config.YT_OUTPUT_FOLDER + self.date + '/pairs/'}

    def requires(self):
        dicts_task = graph_dict.YuidAllIdDictsTask(self.date)
        org_emails_task = org_emails_classify.OrgEmailsClassifyTask(self.date)
        return [dicts_task, org_emails_task]

    def run(self):
        pairs_folder = self.out_f('pairs')
        mr.mkdir(pairs_folder)
        # 'smart' and 'active' are created here although this task writes no
        # tables into them.
        for subfolder in ('limit', 'single', 'smart', 'active', 'orgs', 'private'):
            mr.mkdir(pairs_folder + subfolder)

        # Position in this tuple is the '@table_index' the reducer emits.
        table_prefixes = (
            'yuid_pairs_',
            'limit/yuids_',
            'limit/yuids_soft_',
            'limit/ids_',
            'single/yuids_',
            'orgs/yuids_',
            'private/yuids_',
        )

        all_tables = []
        ops = []
        pair_types = config.YUID_PAIR_TYPES_EXACT + config.YUID_PAIR_TYPES_FOR_YUID_WITH_ALL
        for pair in pair_types:
            dict_table = self.in_f('dict') + 'yuid_with_id_' + pair.name()
            out_pair_tables = [pairs_folder + prefix + pair.id_type for prefix in table_prefixes]
            all_tables.extend(out_pair_tables)

            reducer = partial(reduce_pairs_by_limit, pair_type=pair, store_days=int(config.STORE_DAYS))
            # sync=False: start all reduces first, then wait for them together
            operation = yt.run_reduce(
                reducer,
                dict_table,
                out_pair_tables,
                reduce_by='id_value', sync=False)
            ops.append(operation)
        utils.wait_all(ops)

        mr.merge_chunks_all(all_tables)

    def output(self):
        # Only the main pairs table of each pair type is declared as a target.
        pair_types = config.YUID_PAIR_TYPES_EXACT + config.YUID_PAIR_TYPES_FOR_YUID_WITH_ALL
        return [
            yt_luigi.YtTarget(
                self.out_f('pairs') + 'yuid_pairs_' + pair.id_type,
                allow_empty=not pair.required
            )
            for pair in pair_types
        ]


if __name__ == '__main__':
    # Manual launch of GraphPairs through smart_runner.run_isolated, using a
    # personal working directory on the cluster from config.MR_SERVER.
    yt.config.set_proxy(config.MR_SERVER)

    # The run date is hard-coded; use sys.argv[1] to take it from the command line.
    run_date = '2016-08-10'

    work_dir = '//home/crypta/team/artembelov/smart_limit_paris/'
    mr.mkdir(work_dir)

    import smart_runner
    import graph_pairs

    # The task class is taken from the imported module rather than from this
    # script's own namespace - presumably so it is not bound to __main__; confirm.
    task_cls = graph_pairs.GraphPairs
    smart_runner.run_isolated(work_dir, run_date,
                              task_cls,
                              task_cls, date=run_date)
