#!/usr/bin/env python

from rtcconf import config
import hh_util as util
from hh_merge import HHMergeMonthTask
from matching.yuid_matching.enrich.puid_yuid_passport import ExpandPuidYuidMatching
from matching.yuid_matching.graph_merge_month import IncrementalDayAndDumpMergeTask
from matching.human_matching.graph_vertices import CopyVerticesToDict
from utils import mr_utils as mr
from lib.luigi import yt_luigi
import luigi
import yt.wrapper as yt
import os
from collections import defaultdict


def map_merged_hh(rec):
    """Flatten one merged household into one row per member.

    Yields ``{'hhid', 'id_type', 'id'}`` for every entry of ``rec['members']``.
    """
    for member in rec['members']:
        row = {'hhid': rec['hhid']}
        row['id_type'] = member['id_type']
        row['id'] = member['id']
        yield row


# vertices
# --> (cid | id_type | id), only for 'y'/'d' vertices with a non-empty cid
def map_vertices(rec):
    """Map one vertices row to a (cid, id_type, id) row.

    The row's ``value`` is parsed with ``util.parse_tskv``; field ``d`` is the
    vertex type and field ``c`` the cid. Rows whose key is 3 chars or shorter,
    whose type is not 'y'/'d', or whose cid is empty produce nothing.
    """
    fields = util.parse_tskv(rec['value'])
    vertex_type = fields['d']
    if vertex_type not in ('y', 'd'):
        return
    if len(rec['key']) <= 3:
        return
    crypta_id = fields['c']
    if not crypta_id:
        return
    yield {'cid': crypta_id, 'id_type': vertex_type, 'id': rec['key']}


# ( yuid(s) | id_type | id_value | id_count )
# --> ( 'y' | yuid | '<id_type>_<id_value>' | id_count )
def prepare_ext_ids(rec):
    """Turn a yuid -> external id row into a joinable ext_id row.

    The external id is encoded as ``id_type + '_' + id_value`` so that ids of
    different kinds cannot clash.
    """
    yuid = rec['yuid']
    ext_id = '_'.join((rec['id_type'], rec['id_value']))
    count = rec['id_count']
    yield dict(id_type='y', id=yuid, ext_id=ext_id, ext_id_count=count)


def join_id_data(key, recs):
    """Join household, cryptaid and external-id data for one (id_type, id).

    Reducer over ``['id_type', 'id']``. Each input row carries one of
    ``hhid``, ``cid`` or ``ext_id`` (+ ``ext_id_count``). External id counts
    are summed and the most frequent one is kept as ``(ext_id, count)``.

    Output tables:
      0 -- id has a cid (hhid may be None);
      1 -- id has a hhid but no cid;
      2 -- id has neither: only its best ext_id string is emitted.
    An id with no hhid, no cid and no external id is dropped: there is nothing
    to attach it to, and the ext_id-based reducer downstream needs an ext_id.
    """
    hhid = None
    cid = None
    ext_ids = defaultdict(int)
    for rec in recs:
        if rec.get('hhid'):
            hhid = rec['hhid']
        elif rec.get('cid'):
            cid = rec['cid']
        elif rec.get('ext_id'):
            ext_ids[rec['ext_id']] += rec['ext_id_count']

    # (ext_id, summed_count) of the most frequent external id, or None
    ext_id_data = max(ext_ids.items(), key=lambda x: x[1]) if ext_ids else None

    if cid is None and hhid is None:
        if ext_id_data is None:
            # previously crashed with TypeError on ext_id_data[0]
            return
        yield {'id_type': key['id_type'], 'id': key['id'], 'ext_id': ext_id_data[0], '@table_index': 2}
    else:
        # ext_id stays an (ext_id, count) pair here: max_ext_id_for_hh
        # compares the counts across household members.
        yield {'id_type': key['id_type'],
               'id': key['id'],
               'hhid': hhid,
               'cid': cid,
               'ext_id': ext_id_data,
               '@table_index': 0 if cid else 1}


def collect_cid(key, recs):
    """Unite ids sharing one cryptaid (cid) into a household.

    Reducer over ``['cid']``. For groups of at most 10 rows, ids without a
    household join the first household found in the group, or a new one,
    ``'c' + cid``, when no row has a household. Larger groups are not united:
    only rows that already have a hhid are passed through.
    """
    group = list(recs)

    existing_hhid = None
    for row in group:
        if row.get('hhid'):
            existing_hhid = row['hhid']
            break

    if len(group) > 10:
        # cryptaid too wide to trust: keep only rows already in a household
        for row in group:
            if row.get('hhid'):
                yield row
        return

    if existing_hhid is None:
        target_hhid = 'c' + key['cid']
    else:
        target_hhid = existing_hhid
    for row in group:
        if not row.get('hhid'):
            row['hhid'] = target_hhid
        yield row


# key = [hhid]
def max_ext_id_for_hh(key, recs):
    """Pick one external id for a whole household.

    Each row's ``ext_id`` is either empty or an ``(ext_id, count)`` pair. The
    pair with the largest count wins, and its ext_id string is written into
    every row of the household. Rows go to table 0 when an ext_id was found,
    otherwise to table 1 with ``ext_id`` left untouched.
    """
    members = list(recs)
    candidates = [member['ext_id'] for member in members if member.get('ext_id')]

    best_ext_id = None
    if candidates:
        best_ext_id = max(candidates, key=lambda pair: pair[1])[0]

    found = best_ext_id is not None
    if found:
        for member in members:
            member['ext_id'] = best_ext_id

    out_table = 0 if found else 1
    for member in members:
        member['@table_index'] = out_table
        yield member


# key = [ext_id]
def collect_ext_id(key, recs):
    """Unite all ids sharing one external id into a single household.

    Reducer over ``['ext_id']``. Input rows carry ``id``/``id_type`` and
    optionally ``hhid`` (the id's existing household, if any).

    While ``util.check_hh_size`` accepts the ids collected so far, every id is
    remembered. At the end, all of them, including ids that already had a
    household, are emitted with the new hhid ``'e' + ext_id``.

    Once ``util.check_hh_size`` rejects the group ("overfloated"), the ext_id
    is not used for uniting. Ids that already belong to a household are emitted
    with their own ``hhid``; ids without one are dropped.

    Yields rows with ``id``, ``id_type`` and ``hhid``. This assumes
    ``util.subdict`` returns just the listed keys (TODO confirm).
    """
    hhs = defaultdict(list)
    ids = []
    overfloated = False
    for rec in recs:
        if overfloated:
            # past the size limit: only pass through ids already in a household
            if rec.get('hhid'):
                yield util.subdict(rec, ['id', 'id_type', 'hhid'])
        else:
            id_data = util.subdict(rec, ['id', 'id_type'])
            if rec.get('hhid'):
                hhs[rec['hhid']].append(id_data)
            ids.append(id_data)
            # presumably check_hh_size returns False once `ids` exceeds the
            # allowed household size -- TODO confirm against hh_util
            if not util.check_hh_size(ids):
                overfloated = True
                # flush ids seen so far that belong to a household, keeping
                # their own hhid; household-less ids seen so far are dropped
                for hhid in hhs:
                    for out_rec in hhs[hhid]:
                        yield dict(out_rec, hhid=hhid)
    if not overfloated:
        # group stayed small: everyone joins the ext_id-based household
        new_hhid = 'e' + key['ext_id']
        for out_rec in ids:
            yield dict(out_rec, hhid=new_hhid)


#def collect_members(key, recs):
#    yield {'hhid': key['hhid'], 'members': [(rec['id_type'], rec['id']) for rec in recs]}


# ---------------------- STABLE 32bit HHID --------------------
def split_old_hhid(rec):
    """Split a previously issued 32-bit hhid into its 16-bit halves.

    Emitted with ``join_order`` 0 so that ``generate_hhid`` sees every old
    (hi, low) pair, and keeps it reserved, before handing out new lows.
    """
    old_hhid = rec['hhid']
    yield dict(hi=old_hhid >> 16, low=old_hhid & 0xFFFF, join_order=0)


# [id | id_type | hhid | ti = 0] - old
# [id | id_type | hhid | ti = 1] - new
# --> (new_hhid | id | id_type | old_hhid), old_hhid may be None
def find_old_hhid(key, recs):
    """Pair an id's new hhid with the hhid it had in the previous run.

    Ids that only exist in the old table (table 0) produce nothing.
    """
    hhid_by_table = {rec['@table_index']: rec['hhid'] for rec in recs}
    if 1 not in hhid_by_table:
        return
    yield {
        'id_type': key['id_type'],
        'id': key['id'],
        'new_hhid': hhid_by_table[1],
        'old_hhid': hhid_by_table.get(0),
    }


# (new_hhid | id | id_type | old_hhid), old_hhid may be None
# --> (old_hhid(best intersection) | new_hhid | members)    , index=0
# --> (hash(new_hhid) | 2 | members)               , index=1, rejected, no intersection with old hh
def best_old_hhid(key, recs):
    """Choose which old hhid a new household should inherit.

    Reducer over ``['new_hhid']``. The old hhid shared by the most members
    wins (table 0). A household with no member in any old household is
    rejected (table 1) and gets a 16-bit bucket derived from ``hash`` of its
    new hhid.

    ``old_hhid`` is compared with ``None`` explicitly: ``0`` is a valid
    generated hhid (hi=0, low=0) and must not be treated as missing.
    """
    old_counts = defaultdict(int)
    members = []
    for rec in recs:
        old_hhid = rec.get('old_hhid')
        if old_hhid is not None:
            old_counts[old_hhid] += 1
        members.append((rec['id_type'], rec['id']))
    if old_counts:
        max_old_hhid = max(old_counts.items(), key=lambda x: x[1])[0]
        yield {'hhid': max_old_hhid, 'new_hhid': key['new_hhid'], 'members': members, '@table_index': 0}
    else:
        yield {'hi': hash(key['new_hhid']) & 0xFFFF, 'join_order': 2, 'members': members, '@table_index': 1}


# (old_hhid(best intersection) | new_hhid | members)
# ( old_hhid | | members )
# --> ( old_hhid_hi | old_hhid_lo | 1 | uid,...)    , index=0
# --> ( hash(new_hhid) | 2 | members)               , index=1, rejected
def best_new_hhid(key, recs):
    """Among new households claiming one old hhid, keep the best one.

    Reducer over ``['hhid']`` (the old hhid). Rows with ``new_hhid`` are
    candidate new households; rows without it are members of the old
    household. The candidate sharing the most members with the old household
    inherits the old (hi, low) pair (table 0, join_order 1). All others are
    rejected (table 1, join_order 2) into a ``hash``-derived 16-bit bucket.
    Nothing is emitted if no candidate claims this old hhid.

    Candidates are detected with ``is not None`` so that a falsy but valid
    new hhid (e.g. ``0``) is not mistaken for an old-member row.
    """
    old_hh_members = set()
    new_hhs_members = {}
    for rec in recs:
        new_hhid = rec.get('new_hhid')
        if new_hhid is not None:
            # members may arrive as lists from YT; normalise to hashable tuples
            new_hhs_members[new_hhid] = set((id_type, member_id) for id_type, member_id in rec['members'])
        else:
            old_hh_members.add((rec['id_type'], rec['id']))
    if not new_hhs_members:
        return

    old_hhid = int(key['hhid'])
    best_new_hh, best_new_hh_members = max(new_hhs_members.items(),
                                           key=lambda item: len(item[1] & old_hh_members))
    yield {
        'hi': old_hhid >> 16,
        'low': old_hhid & 0xFFFF,
        'join_order': 1,
        'members': list(best_new_hh_members),
        '@table_index': 0
    }
    del new_hhs_members[best_new_hh]
    for hhid, members in new_hhs_members.items():
        yield {'hi': hash(hhid) & 0xFFFF, 'join_order': 2, 'members': list(members), '@table_index': 1}


def _hh_output_records(hhid, members):
    """Yield the household row (table 0), then one reversed row per member (table 1)."""
    yield {'hhid': hhid, 'members': members, '@table_index': 0}
    for id_type, member_id in members:
        yield {'hhid': hhid, 'id_type': id_type, 'id': member_id, '@table_index': 1}


# (hhid_hi | hhid_lo | 0 )         list of old_hhid
# (hhid_hi | hhid_lo | 1 | members)    inheritors of old_hhid
# (hash_16b | 2 | members)          new hhs
# --> (32b_hhid | members), ti=0
# --> (32b_hhid | id_type | id), ti=1
# reduce_by: hi, sort_by: [hi, join_order]
def generate_hhid(key, recs):
    """Assign final 32-bit hhids within one 16-bit ``hi`` bucket.

    hhid = (hi << 16) | low. Rows must arrive ordered by ``join_order``. That
    way every low already taken (0: old hhids, 1: inheritors reusing an old
    hhid) is known before new households (2) get the lowest free low.

    Raises:
        RuntimeError: if the bucket has no free 16-bit low left. A low above
            0xFFFF would spill into the ``hi`` bits and collide with hhids of
            another bucket, so this is treated as an error.
    """
    hi = key['hi'] << 16
    taken_lows = set()
    next_low = 0
    for rec in recs:
        join_order = rec['join_order']
        if join_order == 0:
            taken_lows.add(rec['low'])
        elif join_order == 1:
            taken_lows.add(rec['low'])
            for out in _hh_output_records(hi | rec['low'], rec['members']):
                yield out
        elif join_order == 2:
            while next_low in taken_lows:
                next_low += 1
            if next_low > 0xFFFF:
                # TODO: postprocess overfloated buckets instead of failing
                raise RuntimeError(
                    'hhid bucket hi=%s has no free 16-bit low slot left' % key['hi'])
            taken_lows.add(next_low)
            for out in _hh_output_records(hi | next_low, rec['members']):
                yield out
            next_low += 1


def prepare_new_hhs(key, recs):
    """Collapse one household's reversed rows into a join_order-2 record.

    Used on the first run, when there are no old hhids to inherit. The
    household is placed into a 16-bit bucket derived from ``hash`` of its hhid.
    """
    members = []
    for rec in recs:
        members.append((rec['id_type'], rec['id']))
    bucket = hash(key['hhid']) & 0xFFFF
    yield dict(hi=bucket, join_order=2, members=members)


class HHEnrichTask(yt_luigi.BaseYtTask):
    """
    This task makes some more households and enlarges some of the existing ones
    by uniting over cryptaid, yandex logins, vk id and mail.ru logins.

    Motivation for all this - to increase coverage for ads.
    Result table (enriched_hh) shouldn't be used for matching purposes
    (because it already contains information about edges of matching graph).

    Output tables (in the 'hh' output folder):
      enriched_hh          - (hhid | members); hhid is a 32-bit id that is
                             reused from the previous run's output where possible
      enriched_hh_reversed - (hhid | id_type | id), sorted by (id, id_type)
    """
    date = luigi.Parameter()

    def input_folders(self):
        # Logical name -> YT folder; run() resolves paths via self.in_f(<name>).
        return {
            'hh': config.HH_FOLDER2,
            'dict': config.GRAPH_YT_DICTS_FOLDER,
            'yuid_raw_month': config.GRAPH_YT_DICTS_FOLDER + 'yuid_raw/',
        }

    def output_folders(self):
        # Same folder as the 'hh' input: results land next to merged_hh.
        return {
            'hh': config.HH_FOLDER2,
        }

    def requires(self):
        return [
            HHMergeMonthTask(date=self.date),
            ExpandPuidYuidMatching(date=self.date),
            IncrementalDayAndDumpMergeTask(date=self.date),
        ]

    def run(self):
        # NOTE(review): `tmp` is created here but not used anywhere below in run().
        tmp = os.path.join(self.out_f('hh'), 'tmp') + '/'
        mr.mkdir(tmp)

        merged_hh = os.path.join(self.in_f('hh'), 'merged_hh')
        vertices = os.path.join(self.in_f('dict'), 'exact_vertices')
        # External-id sources (logins, vk ids, mail.ru emails); prepare_ext_ids
        # maps each row to (yuid, '<id_type>_<id_value>', count).
        # FIXME: check that all these tables are ok to use
        ext_ids = [os.path.join(self.in_f('yuid_raw_month'), t) for t in [
            'yuid_with_login_passport_server',
            'yuid_with_vk_fp',
            'yuid_with_vk_watch_log',
            'yuid_with_vk_barlog',
            'yuid_with_email_mailru',
            'yuid_with_email_wl_mailru',
        ]]

        # Stage 1: build tmp_new_hh_reversed -- one row per (id_type, id) with
        # its new, not yet stable, hhid.
        with yt.TempTable() as tmp_new_hh_reversed:
            with yt.TempTable() as tmp_hh_members, \
                    yt.TempTable() as tmp_vertices, \
                    yt.TempTable() as tmp_ext_ids, \
                    yt.TempTable() as tmp_with_cid, \
                    yt.TempTable() as tmp_without_cid, \
                    yt.TempTable() as tmp_cid_united, \
                    yt.TempTable() as tmp_with_max_ext_id, \
                    yt.TempTable() as tmp_without_ext_id, \
                    yt.TempTable() as tmp_without_cid_and_hhid, \
                    yt.TempTable() as tmp_ext_id_hh:

                # Explode merged households into (hhid, id_type, id) rows.
                yt.run_map(map_merged_hh, merged_hh, tmp_hh_members)

                join_id_data_src = [tmp_hh_members]

                # Cryptaid vertices are an optional input.
                if yt.exists(vertices):
                    yt.run_map(map_vertices, vertices, tmp_vertices)
                    join_id_data_src.append(tmp_vertices)

                # External ids are used only if every source table exists.
                if all(yt.exists(t) for t in ext_ids):
                    yt.run_map(prepare_ext_ids, ext_ids, tmp_ext_ids)
                    join_id_data_src.append(tmp_ext_ids)

                # Join all data per id and split it into: has a cid / has a
                # hhid but no cid / has neither (external id only).
                mr.sort_all(join_id_data_src, sort_by=['id_type', 'id'])
                yt.run_reduce(join_id_data,
                              join_id_data_src,
                              [tmp_with_cid, tmp_without_cid, tmp_without_cid_and_hhid],
                              reduce_by=['id_type', 'id'])

                # Unite ids sharing a cryptaid into households.
                yt.run_sort(tmp_with_cid, sort_by=['cid'])
                yt.run_reduce(collect_cid, tmp_with_cid, tmp_cid_united, reduce_by=['cid'])

                # Pick one external id per household; households without any
                # external id go to tmp_without_ext_id unchanged.
                mr.sort_all([tmp_cid_united, tmp_without_cid], sort_by=['hhid'])
                yt.run_reduce(max_ext_id_for_hh,
                              [tmp_cid_united, tmp_without_cid],
                              [tmp_with_max_ext_id, tmp_without_ext_id],
                              reduce_by=['hhid'])

                # Unite households and household-less ids over a shared external id.
                mr.sort_all([tmp_with_max_ext_id, tmp_without_cid_and_hhid], sort_by=['ext_id'])
                yt.run_reduce(collect_ext_id,
                              [tmp_with_max_ext_id, tmp_without_cid_and_hhid],
                              tmp_ext_id_hh,
                              reduce_by=['ext_id'])

                yt.run_merge([tmp_without_ext_id, tmp_ext_id_hh],
                             tmp_new_hh_reversed,
                             spec={'combine_chunks': True})

            # generate stable 32-bit hhid
            # Stage 2: when a previous result exists, new households inherit
            # old hhids where they overlap with old households.
            # NOTE(review): only enriched_hh is checked, but enriched_hh_reversed
            # is read below as well -- confirm both always exist together.
            if yt.exists(self.out_f('hh')+'enriched_hh'):
                with yt.TempTable() as tmp_old_hhid, \
                        yt.TempTable() as tmp_stable_id_1, \
                        yt.TempTable() as tmp_stable_id_2, \
                        yt.TempTable() as tmp_stable_id_3, \
                        yt.TempTable() as tmp_rejected1, \
                        yt.TempTable() as tmp_rejected2, \
                        yt.TempTable() as tmp_enriched_hh_reversed:

                    # Old hhids split into (hi, low) so generate_hhid keeps them reserved.
                    yt.run_map(split_old_hhid, self.out_f('hh')+'enriched_hh', tmp_old_hhid)

                    # Pair each id's new hhid with its old hhid (if any).
                    yt.run_sort(tmp_new_hh_reversed, sort_by=['id', 'id_type'])
                    yt.run_reduce(find_old_hhid,
                                  [self.out_f('hh') + 'enriched_hh_reversed', tmp_new_hh_reversed],
                                  tmp_stable_id_1,
                                  reduce_by=['id', 'id_type'])

                    # For each new household pick the old hhid shared by most members.
                    yt.run_sort(tmp_stable_id_1, sort_by=['new_hhid'])
                    yt.run_reduce(best_old_hhid, tmp_stable_id_1, [tmp_stable_id_2, tmp_rejected1], reduce_by=['new_hhid'])

                    # For each old hhid keep the claiming household that overlaps
                    # it most. A sorted copy of enriched_hh_reversed is used
                    # because the final reduce below writes into that table.
                    yt.run_sort(tmp_stable_id_2, sort_by=['hhid'])
                    yt.run_sort(self.out_f('hh')+'enriched_hh_reversed', tmp_enriched_hh_reversed, sort_by=['hhid'])
                    yt.run_reduce(best_new_hhid,
                                  [tmp_stable_id_2, tmp_enriched_hh_reversed],
                                  [tmp_stable_id_3, tmp_rejected2],
                                  reduce_by=['hhid'])

                    # Assign final 32-bit hhids and write both output tables.
                    mr.sort_all([tmp_old_hhid, tmp_stable_id_3, tmp_rejected1, tmp_rejected2],
                                sort_by=['hi', 'join_order'])
                    yt.run_reduce(generate_hhid,
                                  [tmp_old_hhid, tmp_stable_id_3, tmp_rejected1, tmp_rejected2],
                                  [self.out_f('hh') + 'enriched_hh', self.out_f('hh') + 'enriched_hh_reversed'],
                                  reduce_by=['hi'],
                                  sort_by=['hi', 'join_order'])

            else:
                # first time
                # No previous result: every household gets a fresh hhid.
                with yt.TempTable() as tmp_new_hhs:
                    yt.run_sort(tmp_new_hh_reversed, sort_by=['hhid'])
                    yt.run_reduce(prepare_new_hhs,
                                  [tmp_new_hh_reversed],
                                  tmp_new_hhs,
                                  reduce_by=['hhid'])

                    yt.run_sort(tmp_new_hhs, sort_by=['hi', 'join_order'])
                    yt.run_reduce(generate_hhid,
                                  [tmp_new_hhs],
                                  [self.out_f('hh') + 'enriched_hh', self.out_f('hh') + 'enriched_hh_reversed'],
                                  reduce_by=['hi'],
                                  sort_by=['hi', 'join_order'])

        # Keep enriched_hh_reversed sorted by (id, id_type): the next run
        # reduces it by these columns in find_old_hhid.
        yt.run_sort(self.out_f('hh') + 'enriched_hh_reversed', sort_by=['id', 'id_type'])

        mr.set_generate_date(self.out_f('hh') + 'enriched_hh', self.date)
        mr.set_generate_date(self.out_f('hh') + 'enriched_hh_reversed', self.date)

    def output(self):
        return [yt_luigi.YtDateTarget(self.out_f('hh') + 'enriched_hh', self.date),
                yt_luigi.YtDateTarget(self.out_f('hh') + 'enriched_hh_reversed', self.date)]


if __name__ == '__main__':
    # Manual run: python <this file> <date>
    import sys

    run_date = sys.argv[1]

    yt.config.set_proxy(config.MR_SERVER)
    yt.config["tabular_data_format"] = yt.YsonFormat(process_table_index=True)

    # Debug overrides: write into a personal test folder.
    # NOTE(review): STORE_DAYS presumably limits how many days are kept -- confirm.
    config.HH_FOLDER2 = '//home/crypta/team/shiryaev/test_hh/'
    config.STORE_DAYS = 2

    enrich_task = HHEnrichTask(run_date)

    print('Starting enriching HH...')

    enrich_task.run()

    print('Done.')
