#!/usr/bin/env python

from rtcconf import config
import hh_util as util
from hh_daily import HHDailyTask
from matching.yuid_matching.graph_dict import YamrFormatDicts
from utils import mr_utils as mr
from utils import utils
from lib.luigi import yt_luigi
import luigi
import yt.wrapper as yt
import os
from collections import defaultdict
from itertools import groupby
from functools import partial

"""
Prepare daily HH for merging
"""
@yt.aggregator
def map_daily_hh(recs):
    """
    Prepare daily HH for merging: flatten every daily HH record into one
    record per member, keyed by the member id.
    """
    for hh_rec in recs:
        for person in hh_rec['members']:
            yield dict(
                key=person['id'],
                hhid=hh_rec['hhid'],
                member=person,
                ip=hh_rec['ip'],
                rid=person['rid'],
                date=hh_rec['date'],
            )


max_ip_hhs_for_yuid = 30000


def merge_id_data(key, recs):
    """
    Reducer: combine all daily HH records of one id with its user-agent data
    and keep only the records that belong to the id's home region.

    Input records (order matters):
     - table 0: per-member daily HH records (see map_daily_hh)
     - other tables: records whose 'value' is a tskv string; its 'ua_profile'
       field (if any) decides the device type

    The id is skipped completely when it has more than max_ip_hhs_for_yuid
    daily records, when it has no daily records at all, or when no region
    qualifies as home region.

    Yields {'hhid': '<ip>_<date>_<hhid>', 'member': member}, where member lost
    its 'rid', gained 'device_type' and (for 'd'/'tv') 'merge_weight', the
    total 'events_home' over all daily records of the id.
    """
    device_type = 'm'   # d - desktop, m - mobile, tv - TV
    hh_recs = []
    rid_dates = defaultdict(set)
    all_dates = set()
    total_home_events = 0
    for rec in recs:
        if rec['@table_index'] == 0:
            hh_recs.append(dict(rec))
            if len(hh_recs) > max_ip_hhs_for_yuid:
                return
            rid_dates[rec['rid']].add(rec['date'])
            all_dates.add(rec['date'])
            total_home_events += rec['member']['events_home']
        else:
            if not hh_recs:
                return
            ua_profile = util.parse_tskv(rec['value']).get('ua_profile', '')
            if ua_profile:
                if ua_profile.startswith('d|tv'):
                    device_type = 'tv'
                elif ua_profile.startswith('d'):
                    device_type = 'd'
                else:
                    device_type = 'm'

    # find home region: a rid seen on more than half of all dates.
    # The same date can be seen in several regions, so more than one rid may
    # qualify; take the one with most dates (ties broken by str(rid)) so the
    # choice does not depend on dict iteration order.
    no_home = object()   # sentinel: None may be a legitimate rid value
    home_rid = no_home
    best_rank = None
    for rid, dates in rid_dates.items():
        if len(dates) > 0.5 * len(all_dates):
            rank = (len(dates), str(rid))
            if best_rank is None or rank > best_rank:
                home_rid, best_rank = rid, rank
    if home_rid is no_home:
        # without a home region nothing may pass the filter below
        return

    # filter recs by home region
    for rec in hh_recs:
        if rec['rid'] == home_rid:
            hhid = '_'.join([rec['ip'], rec['date'], str(rec['hhid'])])
            member = rec['member']
            member.pop('rid', None)
            member['device_type'] = device_type
            if device_type in ['d', 'tv']:
                member['merge_weight'] = total_home_events
            yield {'hhid': hhid, 'member': member}


def make_pairs(key, recs):
    """
    Reducer over one daily HH: split its members into desktop-like ('d', 'tv')
    and mobile ones, then emit
     - table 1: one rec per (desktop, mobile) combination with weight 1
     - table 0: one rec per desktop member carrying the whole desktop group,
       only when the group has more home events than work events
    """
    desktops, mobiles = [], []
    for rec in recs:
        person = rec['member']
        bucket = desktops if person['device_type'] in ['d', 'tv'] else mobiles
        bucket.append(person)

    # desktop-mobile pairs
    for desk in desktops:
        for mob in mobiles:
            yield {
                'mobile_member': mob,
                'desktop_id': desk['id'],
                'desktop_id_type': desk['id_type'],
                'weight': 1,
                '@table_index': 1,
            }

    # desktop-only groups are kept for home-dominated households only
    home_events = 0
    work_events = 0
    for desk in desktops:
        home_events += desk['events_home']
        work_events += desk['events_work']
    if not home_events > work_events:
        return
    for desk in desktops:
        yield {
            'id': desk['id'],
            'id_type': desk['id_type'],
            'self_member': desk,
            'members': desktops,
            '@table_index': 0,
        }


# -- UNITE DESKTOP HH --
def unite_find_max(key, recs):
    """
    Start a new HH around every id whose weight is not exceeded by any of its
    neighbours.

    Each input rec is one daily HH group of this id ('self_member' plus the
    'members' list). As soon as one member outweighs self_member the id is not
    a center and nothing is emitted. Otherwise yields:
     - table 0: already-in-hh rec for this id
     - table 1: invitation recs for every neighbour seen
     - table 2: rec for collecting merged HHs (gen 0, no invitations)
    """
    center = dict()
    neighbour_keys = set()
    for rec in recs:
        own = rec['self_member']
        for other in rec['members']:
            if util.get_weight(other) > util.get_weight(own):
                # a heavier neighbour exists: this id is not a center
                return
            neighbour_keys.add((other['id'], other['id_type']))
        center = util.merge_member(center, own)

    hhid = key['id_type'] + key['id']
    # already_in_hh rec
    yield {'id': key['id'], 'id_type': key['id_type'], 'in_hh': True, 'hhid': hhid, '@table_index': 0}
    # invitation recs
    for neighbour_id, neighbour_type in neighbour_keys:
        yield {
            'hhid': hhid,
            'id': neighbour_id,
            'id_type': neighbour_type,
            'center_member': center,
            '@table_index': 1,
        }
    # for collecting merged HHs
    yield {'hhid': hhid, 'gen': 0, 'num_of_inv': 0, 'member': center, '@table_index': 2}


def unite_iteration(key, recs, generation):
    """
    Check invitations, glue self_member to best inviting HH and send further invitations.

    Input records (order matters):
     - in_hh recs (in_hh == True)
     - invitation recs (center_member != None)
     - neighbour recs (each rec - daily HH group with neighbours)

    Nothing is emitted when the id is already in a HH or has no invitation.
    Otherwise yields:
     - table 0: already-in-hh rec for this id
     - table 1: invitation recs for every neighbour, on behalf of the chosen HH
     - table 2: rec for collecting merged HHs with 'gen' = generation and
       'num_of_inv' = number of invitations received from the chosen HH
    """
    best_center_member = None
    hhid = None
    neighbours = set()
    merged_member = dict()
    num_of_invitations = defaultdict(int)
    for rec in recs:
        if rec.get('in_hh', False):
            return
        elif rec.get('center_member', None) is not None:
            # find best invitation
            if (best_center_member is None) or\
               (util.get_weight(rec['center_member']) > util.get_weight(best_center_member)):
                best_center_member = rec['center_member']
                hhid = rec['hhid']
            # count per inviting HH; counting under the current best hhid would
            # credit foreign invitations to it and make the result depend on
            # the arrival order of the invitation recs
            num_of_invitations[rec['hhid']] += 1
        else:
            if best_center_member is None:
                return
            merged_member = util.merge_member(merged_member, rec['self_member'])
            for member in rec['members']:
                neighbours.add((member['id'], member['id_type']))
    # already_in_hh rec
    yield {'id': key['id'], 'id_type': key['id_type'], 'in_hh': True, 'hhid': hhid, '@table_index': 0}
    # invitation recs
    for n_id, n_id_type in neighbours:
        yield {'hhid': hhid, 'id': n_id, 'id_type': n_id_type, 'center_member': best_center_member, '@table_index': 1}
    # for collecting merged HHs
    yield {'hhid': hhid, 'gen': generation, 'num_of_inv': num_of_invitations[hhid], 'member': merged_member,
           '@table_index': 2}


def collect_households(key, recs):
    """
    Collect the members that joined one HH on the different unite iterations.
    Members arrive in portions sharing (generation, num_of_invitations); recs
    are sorted by that pair. Portions are accepted while util.check_hh_size
    approves the grown HH; the first rejected portion and everything after it
    are dropped, which keeps the HH small enough if necessary.

    Yields:
     - table 1: a lone smart-tv member, to be glued to a neighbour HH later
     - table 0: the HH with all collected members
     - table 2: one desktop id -> hhid reference per collected member
    """
    household = []
    overflow = False
    for _, portion in groupby(recs, lambda r: (r['gen'], r['num_of_inv'])):
        if overflow:
            # keep walking the groups so the input is still fully consumed
            continue
        candidates = [dict(rec['member']) for rec in portion]
        if util.check_hh_size(household + candidates):
            household += candidates
        else:
            overflow = True

    if not household:
        return

    # single smart-tv case: try to glue it to any neighbour HH
    if len(household) == 1 and household[0]['device_type'] == 'tv':
        lone_tv = household[0]
        yield {'id': lone_tv['id'], 'id_type': lone_tv['id_type'], 'tv_member': lone_tv, '@table_index': 1}
        return

    yield {'hhid': key['hhid'], 'members': household, '@table_index': 0}
    # for matching with mobile members
    for person in household:
        yield {
            'desktop_id': person['id'],
            'desktop_id_type': person['id_type'],
            'hhid': key['hhid'],
            '@table_index': 2,
        }


# -- SINGLE TV OPERATIONS --
def find_single_tv_neighbour(key, recs):
    """
    For a lone smart-tv find the heaviest neighbour it could be glued to.

    Input records (order matters):
     - single-tv rec (has 'tv_member'), at most first in the group
     - neighbour recs (daily HH groups of this id, each with 'members')

    The 'members' lists also contain the tv itself, so it is skipped while
    looking for the best neighbour: choosing the tv as its own neighbour can
    never lead to a HH. Nothing is emitted when the id is not a single tv or
    when it has no other neighbour.

    Yields {'desktop_id', 'desktop_id_type', 'tv_member'} for the best neighbour.
    """
    first_rec = next(recs)
    if 'tv_member' not in first_rec:
        # not a single-tv id: just consume the rest of the group
        for rec in recs:
            pass
        return

    tv_member = dict(first_rec['tv_member'])
    tv_key = (tv_member['id'], tv_member['id_type'])
    best_neighbour = None
    for rec in recs:
        for member in rec.get('members', ()):
            if (member['id'], member['id_type']) == tv_key:
                continue
            if best_neighbour is None or util.get_weight(member) > util.get_weight(best_neighbour):
                best_neighbour = dict(member)

    if best_neighbour is None:
        # no neighbour other than the tv itself
        return
    yield {'desktop_id': best_neighbour['id'],
           'desktop_id_type': best_neighbour['id_type'],
           'tv_member': tv_member}


def find_hh_for_single_tv(key, recs):
    """
    Attach lone smart-tvs to the HH of their chosen desktop neighbour.

    Input records for one desktop id:
     - single-tv recs (have 'tv_member'); several tvs may have picked the
       same neighbour, so all of them are handled
     - desktop id -> hhid reference recs (have 'hhid')

    Nothing is emitted when there is no tv or the desktop id is in no HH.
    For every tv yields:
     - table 0: {'hhid', 'members': [tv_member]}
     - table 1: desktop id -> hhid reference for the tv, for matching with
       mobile members
    """
    tv_members = []
    hhids = []
    for rec in recs:
        if 'tv_member' in rec:
            tv_members.append(dict(rec['tv_member']))
        else:
            hhids.append(rec['hhid'])

    if not tv_members or not hhids:
        return

    hhid = hhids[0]
    for tv_member in tv_members:
        yield {'hhid': hhid, 'members': [tv_member], '@table_index': 0}
        # for matching with mobile members
        yield {'desktop_id': tv_member['id'], 'desktop_id_type': tv_member['id_type'], 'hhid': hhid,
               '@table_index': 1}



# -- ADD MOBILE TO HH OPERATIONS --
def filter_mobile_pairs(key, recs):
    """
    Reducer over one desktop id: keep its mobile pairs only when the desktop
    id belongs to some HH.

    Recs carrying 'hhid' name the HH of the desktop id (the last one seen
    wins); all other recs are desktop-mobile pairs. Pairs of the same mobile
    id are merged into one member and their weights are summed up.

    Yields one rec per distinct mobile id with the HH id and the summed weight.
    """
    desktop_hhid = None
    weight_by_mobile = defaultdict(int)
    member_by_mobile = defaultdict(dict)
    for rec in recs:
        if 'hhid' in rec:
            desktop_hhid = rec['hhid']
            continue
        mobile_key = util.id_with_type(rec['mobile_member'])
        weight_by_mobile[mobile_key] += rec['weight']
        member_by_mobile[mobile_key] = util.merge_member(member_by_mobile[mobile_key], rec['mobile_member'])

    if not desktop_hhid:
        # desktop id is not in any HH: its mobile pairs are useless
        return

    for mobile_key in member_by_mobile:
        mob_id, mob_id_type = mobile_key
        yield {
            'id': mob_id,
            'id_type': mob_id_type,
            'mobile_member': member_by_mobile[mobile_key],
            'hhid': desktop_hhid,
            'weight': weight_by_mobile[(mob_id, mob_id_type)],
        }


def best_hhid_for_mobile(key, recs):
    """
    Reducer over one mobile id: pick the HH with the largest positive weight
    (the first one wins on ties) and emit the mobile member as a member of
    that HH. Nothing is emitted when no such HH exists.
    """
    # (weight, hhid, mobile member) of the best candidate seen so far
    best = (0, None, None)
    for rec in recs:
        if rec['weight'] > best[0]:
            best = (rec['weight'], rec['hhid'], rec['mobile_member'])

    _, chosen_hhid, chosen_member = best
    if not chosen_hhid:
        return
    yield {'hhid': chosen_hhid, 'members': [chosen_member]}


def finalize_hh(key, recs):
    """
    Reducer over one HH: build the final member list and the HH geo position.

    Every member is reduced to 'id', 'id_type' and 'device_type'; the
    'geo_pts' of all members are pooled and passed to util.find_center.

    Yields:
     - table 0: {'hhid', 'members'}
     - table 1: one {'hhid', 'id_type', 'id', 'lat', 'lon'} rec per member,
       only when both coordinates are known
    """
    hh_members = []
    points = []
    for rec in recs:
        for person in rec['members']:
            if 'geo_pts' in person:
                points += person['geo_pts']
            hh_members.append(util.subdict(person, ['id', 'id_type', 'device_type']))

    if points:
        lat, lon = util.find_center(points)
    else:
        lat, lon = None, None

    yield {'hhid': key['hhid'], 'members': hh_members, '@table_index': 0}

    if lat is None or lon is None:
        return
    for person in hh_members:
        yield {
            'hhid': key['hhid'],
            'id_type': person['id_type'],
            'id': person['id'],
            'lat': lat,
            'lon': lon,
            '@table_index': 1,
        }


class HHMergeMonthTask(yt_luigi.BaseYtTask):
    """
    Luigi task that unites the daily HH tables of the last STORE_DAYS dates
    into the 'merged_hh' and 'hh_geo' tables.

    The work is a fixed chain of YT sort/reduce operations over temporary
    tables; every step relies on the tables and sort orders produced by the
    previous ones.
    """
    # date the merge is built for
    date = luigi.Parameter()

    def input_folders(self):
        """Input folders: the daily HH tables and the graph dictionaries."""
        return {
            'hh_daily': config.HH_FOLDER2 + 'daily_hh/',
            'dict': config.GRAPH_YT_DICTS_FOLDER,
        }

    def output_folders(self):
        """Output folder where 'merged_hh' and 'hh_geo' are written."""
        return {
            'hh': config.HH_FOLDER2,
        }

    def requires(self):
        """
        One HHDailyTask per date returned by utils.get_dates_before for
        (self.date, STORE_DAYS), plus the dictionaries for self.date.
        """
        return [HHDailyTask(date=d)
                for d in utils.get_dates_before(self.date, int(config.STORE_DAYS))] +\
            [YamrFormatDicts(self.date)]

    def run(self):
        """Run the whole merge pipeline and stamp the outputs with self.date."""
        # NOTE(review): this folder is created but not referenced again in this method
        tmp = os.path.join(self.out_f('hh'), 'tmp') + '/'
        mr.mkdir(tmp)
        # the unite loop below runs for generations 1 .. ngenerations - 1
        ngenerations = 5

        # all intermediate results live in temporary tables
        with yt.TempTable() as tmp_member_data, \
            yt.TempTable() as tmp_members_merged, \
            yt.TempTable() as tmp_neighbours, \
            yt.TempTable() as tmp_mobile_pairs, \
            yt.TempTable() as tmp_in_hh, \
            yt.TempTable() as tmp_invitations, \
            yt.TempTable() as tmp_found_hhs, \
            yt.TempTable() as tmp_desktop_hh, \
            yt.TempTable() as tmp_single_tv, \
            yt.TempTable() as tmp_desk_mob_refs, \
            yt.TempTable() as tmp_single_tv_with_neighbour, \
            yt.TempTable() as tmp_single_tv_hh, \
            yt.TempTable() as tmp_single_tv_mob_refs, \
            yt.TempTable() as tmp_uniq_mob_pairs, \
            yt.TempTable() as tmp_mobile_hh:

            # prepare HH for uniting
            daily_hh_tables = [os.path.join(self.in_f('hh_daily'), d)
                               for d in utils.get_dates_before(self.date, int(config.STORE_DAYS))]
            # NOTE(review): map_daily_hh does not read the attached geodata file - confirm it is still needed
            yt.run_map(map_daily_hh, daily_hh_tables, tmp_member_data,
                       yt_files=['//statbox/statbox-dict-last/geodata4.bin'])

            # per id: join with the yuid_ua dictionary, keep home-region records only
            yt.run_sort(tmp_member_data, sort_by='key')
            yt.run_reduce(merge_id_data, [tmp_member_data, self.in_f('dict')+'yuid_ua'],
                          tmp_members_merged, reduce_by='key')

            # per daily HH: desktop groups (table 0) and desktop-mobile pairs (table 1)
            yt.run_sort(tmp_members_merged, sort_by='hhid')
            yt.run_reduce(make_pairs, tmp_members_merged,
                          [tmp_neighbours, tmp_mobile_pairs], reduce_by='hhid')

            # unite desktop HH
            # generation 0: ids not outweighed by any neighbour become HH centers
            yt.run_sort(tmp_neighbours, sort_by=['id_type', 'id'])
            yt.run_reduce(unite_find_max, tmp_neighbours,
                          [tmp_in_hh, tmp_invitations, tmp_found_hhs],
                          reduce_by=['id_type', 'id'])

            # next generations: invited ids join a HH and invite their own neighbours;
            # tmp_in_hh and tmp_found_hhs are appended to, tmp_invitations is replaced
            for gen in xrange(1, ngenerations):
                mr.sort_all([tmp_in_hh, tmp_invitations], sort_by=['id_type', 'id'])
                yt.run_reduce(partial(unite_iteration, generation=gen),
                              [tmp_in_hh, tmp_invitations, tmp_neighbours],
                              [
                                  yt.TablePath(tmp_in_hh, append=True),
                                  tmp_invitations,
                                  yt.TablePath(tmp_found_hhs, append=True)
                              ],
                              reduce_by=['id_type', 'id'])

            # per HH: collect the members in (gen, num_of_inv) order
            yt.run_sort(tmp_found_hhs, sort_by=['hhid', 'gen', 'num_of_inv'])
            yt.run_reduce(collect_households, tmp_found_hhs, [tmp_desktop_hh, tmp_single_tv, tmp_desk_mob_refs],
                          reduce_by=['hhid'], sort_by=['hhid', 'gen', 'num_of_inv'])

            # add single tv to some neighbour HH
            yt.run_sort(tmp_single_tv, sort_by=['id_type', 'id'])
            yt.run_reduce(find_single_tv_neighbour, [tmp_single_tv, tmp_neighbours], tmp_single_tv_with_neighbour,
                          reduce_by=['id_type', 'id'])

            # resolve the chosen neighbour to its HH id
            mr.sort_all([tmp_single_tv_with_neighbour, tmp_desk_mob_refs],
                        sort_by=['desktop_id_type', 'desktop_id'])
            yt.run_reduce(find_hh_for_single_tv, [tmp_single_tv_with_neighbour, tmp_desk_mob_refs],
                          [tmp_single_tv_hh, tmp_single_tv_mob_refs], reduce_by=['desktop_id_type', 'desktop_id'])

            # add mobile members to desktop HH
            # (tmp_desk_mob_refs has already been sorted by the same keys above)
            mr.sort_all([tmp_mobile_pairs, tmp_single_tv_mob_refs],
                        sort_by=['desktop_id_type', 'desktop_id'])
            yt.run_reduce(filter_mobile_pairs, [tmp_mobile_pairs, tmp_desk_mob_refs, tmp_single_tv_mob_refs],
                          tmp_uniq_mob_pairs, reduce_by=['desktop_id_type', 'desktop_id'])

            # per mobile id: keep the HH with the best weight
            yt.run_sort(tmp_uniq_mob_pairs, sort_by=['id', 'id_type'])
            yt.run_reduce(best_hhid_for_mobile, tmp_uniq_mob_pairs, tmp_mobile_hh, reduce_by=['id', 'id_type'])

            # per HH: final member list and geo position
            mr.sort_all([tmp_desktop_hh, tmp_single_tv_hh, tmp_mobile_hh], sort_by='hhid')
            yt.run_reduce(finalize_hh,
                          [tmp_desktop_hh, tmp_single_tv_hh, tmp_mobile_hh],
                          [self.out_f('hh') + 'merged_hh', self.out_f('hh') + 'hh_geo'],
                          reduce_by='hhid')

        # mark both outputs with the date they were generated for (checked by output())
        mr.set_generate_date(self.out_f('hh') + 'merged_hh', self.date)
        mr.set_generate_date(self.out_f('hh') + 'hh_geo', self.date)


    def output(self):
        """Targets for both output tables, bound to self.date."""
        return [
            yt_luigi.YtDateTarget(self.out_f('hh') + 'merged_hh', self.date),
            yt_luigi.YtDateTarget(self.out_f('hh') + 'hh_geo', self.date),
        ]


if __name__ == '__main__':
    # Manual test run: python <this file> <date>
    import sys

    run_date = sys.argv[1]

    # talk to the configured cluster and expose table indices to the reducers
    yt.config.set_proxy(config.MR_SERVER)
    yt.config["tabular_data_format"] = yt.YsonFormat(process_table_index=True)

    # test overrides: private output folder and a short history
    config.HH_FOLDER2 = '//home/crypta/team/shiryaev/test_hh/'
    config.STORE_DAYS = 2

    merge_task = HHMergeMonthTask(run_date)

    print('Starting merge...')
    merge_task.run()
    print('Done.')
