'''Luigi tasks to collect and send metrics
about indevice gluing, relying on the answers to special Crypta Toloka assignments.
'''
import json
import logging
import os
import sys
from collections import defaultdict
from collections import deque
from datetime import datetime, timedelta
from functools import partial

import luigi
import yt.wrapper as yt

import radius.radius_metrics_calc as calc
import radius.radius_metrics_mr as radius_mr
from infra.toloka.toloka_task_checker import TolokaDesktopSpliceCollector, TolokaHouseholdCollector
from lib import graphite_sender
from lib.luigi import yt_luigi
from rtcconf import config
from utils import mr_utils as mr
from utils import uat_utils

# Logger for the Toloka indevice metrics tasks in this module.
logger = logging.getLogger('toloka_metrics_indevice')


# TODO: refactor the two graphite helpers below (send_stats and to_graphite).
def send_stats(stats_table, date, metric_name='toloka_coverage'):
    """Push every stats column of ``stats_table`` to graphite.

    Each row must have ``device_type`` and ``splice_type`` columns; all other
    columns are sent as separate metrics named
    ``<device_type>.<splice_type>.<column>`` under ``metric_name``.
    """
    stat_data = []
    for raw_row in yt.read_table(stats_table, format='json', raw=True):
        row = json.loads(raw_row)
        device_type = row.pop('device_type')
        splice_type = row.pop('splice_type')
        stat_data.extend(
            dict(device_type=device_type,
                 splice_type=splice_type,
                 stat_name=column,
                 value=column_value)
            for column, column_value in row.items())
    to_graphite(date, stat_data, metric_name)


def to_graphite(date, stat_data, metric_name):
    """Send ``stat_data`` rows to graphite for ``date``.

    Every row becomes one ``(metric_name, '<device_type>.<splice_type>.<stat_name>', value)``
    tuple handed to ``graphite_sender.to_graphite_sender``.
    """
    metrics = [
        (metric_name,
         '.'.join((row['device_type'], row['splice_type'], row['stat_name'])),
         row['value'])
        for row in stat_data
    ]
    graphite_sender.to_graphite_sender(metrics, date)


def get_reference_and_cid_splices(refid_yuid_info, crypta_yuid_info,
                                  record_filter=lambda record: True,
                                  id_key='yuid',
                                  reference_id='tuid',
                                  attributes=None):
    """Group id rows into a reference graph and a crypta graph.

    :param refid_yuid_info: iterable of rows with ``id_key`` and ``reference_id`` columns.
    :param crypta_yuid_info: iterable of rows with ``id_key`` and ``crypta_id`` columns.
    :param record_filter: predicate; rows for which it is falsy are skipped (both inputs).
    :param id_key: name of the vertex id column.
    :param reference_id: name of the reference (e.g. Toloka user) id column.
    :param attributes: extra columns copied into every output row.
    :return: ``(reference_graph, crypta_graph)`` - two ``defaultdict(list)``
        mapping reference id / crypta_id to the list of trimmed rows in input order.
    """
    extra_columns = attributes or []

    def group_by(records, group_column):
        groups = defaultdict(list)
        for record in records:
            if not record_filter(record):
                continue
            row = {id_key: record[id_key], group_column: record[group_column]}
            row.update((column, record[column]) for column in extra_columns)
            groups[record[group_column]].append(row)
        return groups

    reference_graph = group_by(refid_yuid_info, reference_id)
    crypta_graph = group_by(crypta_yuid_info, 'crypta_id')
    return reference_graph, crypta_graph


def map_toloka_pairs(rec):
    """YT mapper turning a Toloka splice answer row into a (yuid, tuid) row.

    Rows without a yuid are dropped before the user agent is parsed, so they
    cost nothing and need no ``user_agent`` column.

    :param rec: Toloka row with ``yuid``, ``user_id``, ``user_agent`` and an
        ISO-like ``created`` timestamp (only the part before 'T' is kept).
    :yield: ``{yuid, tuid, ua_profile, device_type, event_dt}``.
    """
    yuid = rec.get('yuid', '')
    if not yuid:
        return

    ua_profile = uat_utils.Ua(rec['user_agent']).to_ua_profile()
    yield {
        'yuid': yuid,
        'tuid': rec['user_id'],
        'ua_profile': ua_profile,
        # ua_profile is '|'-separated; its second field is the device type.
        'device_type': ua_profile.split('|')[1],
        'event_dt': rec['created'].split('T', 1)[0],
    }


def compute_hit_rate(activity_dates):
    """Average number of hits per calendar day over the active period.

    :param activity_dates: mapping ``{<any key>: {'YYYY-MM-DD': hits}}``; only
        the inner dicts are used (the outer keys are ignored). May be empty or None.
    :return: total hits divided by the number of days between the first and
        the last active day, both inclusive; 0.0 when there is no activity at all
        (instead of failing on ``max()`` of an empty sequence).
    """
    days = []
    total_hits = 0
    for day_activity in (activity_dates or {}).values():
        for day, hits in day_activity.items():
            days.append(day)
            total_hits += hits
    if not days:
        return 0.0
    date_format = "%Y-%m-%d"
    # ISO dates compare correctly as strings, so min/max are taken before parsing.
    max_day = datetime.strptime(max(days), date_format)
    min_day = datetime.strptime(min(days), date_format)
    return 1.0 * total_hits / (1 + (max_day - min_day).days)


def filter_unseen_yuids(key, recs, event_dt_field):
    """YT reducer (by yuid) keeping Toloka rows of yuids seen before the event.

    Inputs (see the caller): table 0 is the yuid dict, table 1 the Toloka rows.
    The first-seen date is the smallest date in the dict row's ``ip_fp_dates``.
    Toloka rows whose ``event_dt_field`` is strictly later are emitted to output
    table 0 with a ``hit_rate`` computed from the dict row's ``reg_fp_dates``.

    NOTE(review): relies on the table-0 row arriving before the table-1 rows of
    the same key - confirm YT reduce ordering guarantees this.
    """
    first_seen = None
    reg_activity = None
    for record in recs:
        if record['@table_index'] != 0:
            if first_seen and record[event_dt_field] > first_seen:
                record['@table_index'] = 0
                record['hit_rate'] = compute_hit_rate(reg_activity)
                yield record
            continue

        ip_fp_dates = record.get('ip_fp_dates')
        if ip_fp_dates:
            first_seen = min(set(dt for per_id_dates in ip_fp_dates.values()
                                 for dt in per_id_dates.keys()))
            reg_activity = record['reg_fp_dates']


def filter_by_yuids_with_columns(rec, yuids, columns):
    """YT mapper keeping rows whose yuid is in ``yuids`` and which have all ``columns``.

    :param rec: input row with a ``yuid`` column.
    :param yuids: set of yuids to keep.
    :param columns: columns to copy; every one must be present and truthy,
        otherwise the row is dropped. Falsy/None means "only the yuid".
    :yield: ``{yuid, <columns>...}``.
    """
    yuid = rec['yuid']
    # Cheap membership test first: this mapper runs over the whole yuid dict,
    # and most rows are rejected here.
    if yuid not in yuids:
        return
    out_rec = {}
    for column in columns or ():
        value = rec.get(column)
        if not value:
            return
        out_rec[column] = value
    out_rec['yuid'] = yuid
    yield out_rec


def reduce_add_info(yuid_key, recs):
    """YT reducer (by yuid) enriching a yuid row with user-agent info.

    Table 0 holds the yuid rows, table 1 the info rows (``ua_profile``,
    ``reg_fp_dates``). The first row of each table is used. Nothing is emitted
    unless both tables contribute a row for the key.
    """
    first_row_by_table = {}
    for record in recs:
        first_row_by_table.setdefault(record['@table_index'], record)

    yuid_rec = first_row_by_table.get(0)
    info_rec = first_row_by_table.get(1)
    if yuid_rec is None or info_rec is None:
        return

    ua_profile = info_rec['ua_profile']
    reg_fp_dates = info_rec['reg_fp_dates']
    yuid_rec['ua_profile'] = ua_profile
    yuid_rec['device_type'] = radius_mr.get_device_type_crypta(ua_profile)
    yuid_rec['hit_rate'] = compute_hit_rate(reg_fp_dates)
    yield yuid_rec


def enrich_with_info_and_filter(yuid_table, out_table, workdir, yuid_dict):
    """Join ``yuid_table`` with ``ua_profile``/``reg_fp_dates`` from ``yuid_dict``.

    Writes to ``out_table`` the rows of ``yuid_table`` that have a dict entry
    with both columns, extended by ``ua_profile``, ``device_type`` and ``hit_rate``
    (see ``reduce_add_info``). ``yuid_table`` is left sorted by yuid.

    :param workdir: YT folder for the temporary info table (with or without a
        trailing slash).
    """
    # Filter yuid ua info only for relevant yuids. We don't reduce the full table
    # for performance reasons; the yuid set is shipped to the mapper instead.
    yuids = set(r['yuid'] for r in yt.read_table(yuid_table, raw=False))
    tmp_info_table = os.path.join(workdir, 'yuid_info_tmp')
    yt.run_map(partial(filter_by_yuids_with_columns,
                       yuids=yuids,
                       columns=['ua_profile', 'reg_fp_dates']),
               yuid_dict,
               tmp_info_table)
    yt.run_merge(tmp_info_table, tmp_info_table, spec={'combine_chunks': True})

    mr.sort_all([yuid_table, tmp_info_table], sort_by='yuid')
    yt.run_reduce(reduce_add_info,
                  [yuid_table, tmp_info_table],
                  out_table, reduce_by='yuid')
    # clean up the intermediate table
    yt.remove(tmp_info_table)


def prepare_vertices(rec):
    """YT mapper emitting ``{yuid, crypta_id}`` for vertices whose ``id_type`` starts with 'yuid'."""
    if not rec['id_type'].startswith('yuid'):
        return
    yield {'yuid': rec['key'], 'crypta_id': rec['crypta_id']}


def prepare_yuid_crypta_id_data(vertices_table, target_yuid_table, workdir, yuid_dict):
    """Build ``<workdir>/yuid_cid_final``: (yuid, crypta_id) rows for target yuids, with UA info.

    :param vertices_table: vertices table with ``id_type``, ``key`` and ``crypta_id`` columns.
    :param target_yuid_table: table of yuids to keep; an absolute YT path is used
        as-is (``os.path.join`` keeps an absolute second argument).
    :param workdir: YT folder for intermediate and output tables.
    :param yuid_dict: yuid dict table with ``ua_profile``/``reg_fp_dates``.
    """
    workdir = workdir if workdir.endswith('/') else workdir + '/'
    all_pairs_table = os.path.join(workdir, 'yuid_cid_all')
    target_table = os.path.join(workdir, target_yuid_table)
    target_pairs_table = os.path.join(workdir, 'yuid_cid')

    yt.run_map(prepare_vertices, vertices_table, all_pairs_table)
    mr.sort_all([target_table, all_pairs_table], sort_by='yuid')
    # presumably keeps the (yuid, crypta_id) rows whose yuid is in the target table - confirm mr.filter_left_by_right
    yt.run_reduce(mr.filter_left_by_right,
                  [all_pairs_table, target_table],
                  [target_pairs_table],
                  reduce_by='yuid')
    yt.run_merge(target_pairs_table, target_pairs_table, spec={'combine_chunks': True})
    # No sort by crypta_id here: enrich_with_info_and_filter only reads this table
    # into a set and then re-sorts it by yuid itself, so such a sort would be discarded.

    # Add user-agent info
    enrich_with_info_and_filter(target_pairs_table,
                                os.path.join(workdir, 'yuid_cid_final'),
                                workdir, yuid_dict)
    # clean up
    yt.remove(all_pairs_table)


def map_devid_yuid_pairs(rec):
    """YT mapper splitting a ``<devid>_<yuid>`` pair key.

    Well-formed keys go to output table 0 as ``{devid, yuid}``. Malformed keys
    go to output table 1 as ``{key}`` instead of failing the job. A key is
    malformed if it has no '_', if either part is empty, or if the devid part
    itself contains '_'.
    """
    key = rec['key']
    devid, separator, yuid = key.rpartition('_')
    if not separator or not devid or not yuid or '_' in devid:
        # catch bad devids/yuids
        yield {'key': key, '@table_index': 1}
    else:
        yield {'devid': devid, 'yuid': yuid, '@table_index': 0}


def reduce_unique_by_key(key, _):
    """YT reducer collapsing each key group into one row: the key itself.

    The group's records are ignored, which makes the reduce a DISTINCT over
    the reduce-by columns.
    """
    distinct_row = key
    yield distinct_row


def prepare_yuid_devid_data(target_yuid_table, workdir, dates):
    """Build ``<workdir>/yuid_devid``: distinct (yuid, devid) pairs for target yuids.

    Pairs are read from the ``dev_yuid_pairs_*`` tables ending in
    ``_<config.INDEVICE>`` of every date in ``dates``. Keys that cannot be split
    go to ``<workdir>/__bad_yuid_devid``. The result is sorted by devid.

    :param target_yuid_table: table with a ``yuid`` column; an absolute YT path
        is used as-is (``os.path.join`` keeps an absolute second argument).
    :param workdir: YT folder for intermediate and output tables.
    :param dates: graph dates (folder names under ``config.GRAPH_FOLDER``).
    """
    workdir = workdir if workdir.endswith('/') else workdir + '/'

    pairs_tables = [os.path.join(config.GRAPH_FOLDER, dt, 'pairs', 'dev_yuid_pairs_' + pair_name) for dt in dates
                    for pair_name in config.DEVID_PAIRS_NAMES_ALL
                    if pair_name.endswith('_' + config.INDEVICE)]

    all_pairs_table = os.path.join(workdir, 'yuid_devid_all')
    bad_pairs_table = os.path.join(workdir, '__bad_yuid_devid')
    target_table = os.path.join(workdir, target_yuid_table)
    result_table = os.path.join(workdir, 'yuid_devid')

    yt.run_map(map_devid_yuid_pairs, pairs_tables, [all_pairs_table, bad_pairs_table])
    # Deduplicate pairs coming from several dates/sources. Only the pairs table needs
    # the (yuid, devid) order; the target table is sorted by yuid right below.
    mr.sort_all([all_pairs_table], sort_by=['yuid', 'devid'])
    yt.run_reduce(reduce_unique_by_key,
                  all_pairs_table,
                  all_pairs_table,
                  reduce_by=['yuid', 'devid'])
    mr.sort_all([all_pairs_table, target_table], sort_by='yuid')
    yt.run_reduce(mr.filter_left_by_right,
                  [all_pairs_table, target_table],
                  [result_table],
                  reduce_by='yuid')
    yt.run_merge(result_table, result_table, spec={'combine_chunks': True})
    # Here we have a list of devids with at least one yuid from target source
    mr.sort_all([result_table], sort_by='devid')
    # clean up
    yt.remove(all_pairs_table)


def filter_exact_pairs(rec, toloka_pairs):
    """YT mapper keeping yuid pairs with equal UA profiles that Toloka also reports.

    :param rec: row with ``pair`` (``'<yuid>_<yuid>'``), ``id1_ua`` and ``id2_ua``.
    :param toloka_pairs: set of ``(canonical_pair, ua_profile)`` tuples.
    :yield: ``{pair, ua_profile}``, where ``pair`` is canonicalised as the two
        yuids in ascending string order joined by '_'.
    """
    ua_profile_a = rec.get('id1_ua')
    ua_profile_b = rec.get('id2_ua')
    yuid_a, yuid_b = rec.get('pair').split('_')
    pair = '_'.join(sorted((yuid_a, yuid_b)))
    if ua_profile_a == ua_profile_b and (pair, ua_profile_a) in toloka_pairs:
        yield {'pair': pair,
               'ua_profile': ua_profile_a}


def write_stats(table, stats):
    """Append ``stats`` (list of dicts) to the YT ``table`` as JSON rows."""
    appending_path = yt.TablePath(table, append=True)
    json_rows = [json.dumps(stat_row) for stat_row in stats]
    yt.write_table(appending_path, json_rows, format='json', raw=True)


def get_yuid_to_uid(graph):
    """Invert ``{uid: [{'yuid': ...}, ...]}`` into ``{yuid: uid}``.

    If a yuid appears under several uids, the one visited last wins.
    """
    return {vertex['yuid']: uid
            for uid, vertices in graph.items()
            for vertex in vertices}


def compute_cid_recall(yuids, norm_coeff):
    """Recall contribution of one crypta_id splice.

    ``sum(hit_rate) * (n - 1) / n / norm_coeff`` where ``n`` is the number of
    vertex rows in ``yuids`` (each a dict with a ``hit_rate`` key).
    """
    total_weight = sum(vertex['hit_rate'] for vertex in yuids)
    vertex_count = len(yuids)
    return total_weight * (vertex_count - 1) / vertex_count / norm_coeff


def calculate_users_recalls(reference_graph, crypta_graph):
    """Hit-rate weighted recall of crypta_id splices for every reference user.

    For each reference user with at least two vertices:
    ``norm = W * (n - 1) / n``, where ``W`` is the sum of the user's vertex
    ``hit_rate`` and ``n`` is the number of the user's vertex rows. The user's
    recall is the sum of ``compute_cid_recall(all vertices of cid, norm)`` over
    every crypta_id containing at least one of the user's yuids.

    Users with fewer than two vertices are omitted. Users with zero norm
    (all hit rates 0) get recall 0.0 instead of a ZeroDivisionError.

    :return: ``{uid: recall}``.
    """
    yuid_to_cid = get_yuid_to_uid(crypta_graph)
    recalls = {}
    for uid, vertices in reference_graph.items():
        vnum = len(vertices)
        if vnum < 2:
            continue
        uid_weight = sum(vertice['hit_rate'] for vertice in vertices)
        yuids = set(vertice['yuid'] for vertice in vertices)
        crypta_ids = set(yuid_to_cid[yuid] for yuid in yuids if yuid in yuid_to_cid)
        norm_coeff = uid_weight * (vnum - 1) / vnum
        recall = 0.0
        if norm_coeff:
            for cid in crypta_ids:
                recall += compute_cid_recall(crypta_graph[cid], norm_coeff)
        recalls[uid] = recall
    return recalls


def graph_bootstrap_samples(reference_graph, crypta_graph, num_bootstrap=100, rng=None):
    """Yield ``num_bootstrap`` resampled ``(reference_graph, crypta_graph)`` pairs.

    For each sample, ``len(reference_graph)`` reference ids are drawn with
    replacement. The sampled reference graph keeps those ids' splices, and the
    sampled crypta graph keeps only rows whose yuid occurs in the sampled
    reference graph.

    NOTE(review): drawn ids are collapsed through ``set()``, so a sample is a
    random subset rather than a true with-replacement bootstrap; variance
    estimates built on it may be biased.

    :param rng: object with a numpy-style ``choice`` (e.g. ``np.random.RandomState``)
        for reproducible sampling; defaults to the global ``np.random``.
    """
    import numpy as np
    if rng is None:
        rng = np.random
    # A real list: np.random.choice needs a 1-D sequence, which a py3 dict view is not.
    tuids = list(reference_graph.keys())
    # Flatten once; every sample filters the same crypta rows.
    crypta_rows = [(cid, rec) for cid, splice in crypta_graph.items() for rec in splice]
    for _ in range(num_bootstrap):
        sample_tuids = set(rng.choice(tuids, len(tuids))) if tuids else set()
        sample_ref_graph = dict((key, val) for key, val
                                in reference_graph.items()
                                if key in sample_tuids)
        sample_yuids = set(rec['yuid']
                           for splice in sample_ref_graph.values()
                           for rec in splice)
        sample_crypta_graph = defaultdict(list)
        for cid, rec in crypta_rows:
            if rec['yuid'] in sample_yuids:
                sample_crypta_graph[cid].append(rec)
        yield (sample_ref_graph, sample_crypta_graph)


def sample_metrics(reference_graph, crypta_graph):
    """Compute recall/precision of ``crypta_graph`` against ``reference_graph``.

    :return: dict with
        - ``precision`` / ``topology_recall``: ``'mean.opt'`` of
          ``calc.average_metrics`` over per-login metrics;
        - ``recall``: per-user recalls (``calculate_users_recalls``) averaged with
          weights proportional to each user's total ``hit_rate``. It is 0.0 when
          that total weight is zero instead of raising ZeroDivisionError.
    """
    users_hit_rates = dict(
        (tuid, sum(yuid['hit_rate'] for yuid in user))
        for tuid, user in reference_graph.items())
    logins_to_metrics = calc.get_metrics_per_login(reference_graph, crypta_graph)
    precisions = [(lm.prec_a, lm.prec_b) for lm in logins_to_metrics.values()]
    topology_recalls = [(lm.rec_a, lm.rec_b) for lm in logins_to_metrics.values()]
    user_recalls = calculate_users_recalls(reference_graph, crypta_graph)
    # Only users that got a recall (>= 2 vertices) take part in the weighting.
    total_hit_rate = sum(hit_rate
                         for uid, hit_rate in users_hit_rates.items()
                         if uid in user_recalls)
    precision, topology_recall, _ = calc.average_metrics(precisions, topology_recalls)
    recall = 0.0
    if total_hit_rate:
        for uid, uid_recall in user_recalls.items():
            weight = users_hit_rates[uid] / total_hit_rate
            recall += weight * uid_recall
    stats = {'topology_recall': topology_recall['mean.opt'],
             'recall': recall,
             'precision': precision['mean.opt']}
    return stats


def apply_imperical_scale(stats, device_type):
    """Hook for an empirical per-device-type correction of ``stats``.

    No scaling is applied at the moment: 'cross' gets a shallow copy of
    ``stats``, every other device type gets the very same dict back.
    NOTE(review): presumably a placeholder for a cross-device scale - confirm.
    """
    if device_type != 'cross':
        return stats
    return dict(stats)


def calculate_all_metrics(workdir, tuid_yuid_table, device_type, vertices_type, bootstrap_size=100):
    """Compute recall/precision stats plus a bootstrap estimate of recall.

    :param workdir: folder containing ``yuid_cid_final`` (crypta side, with ``hit_rate``).
    :param tuid_yuid_table: reference (tuid, yuid, hit_rate) table.
    :param device_type: stored in the stats and passed to ``apply_imperical_scale``.
    :param vertices_type: stored as ``splice_type``.
    :param bootstrap_size: number of bootstrap samples for ``recall_mean``/``recall_var``.
    :return: ``sample_metrics`` dict extended by ``device_type``, ``splice_type``,
        ``recall_mean`` and ``recall_var``. Both bootstrap values are 0.0 when
        there are not enough samples to estimate them.
    """
    import numpy as np
    yuid_refid_recs = list(yt.read_table(tuid_yuid_table, raw=False))
    yuid_cid_recs = list(
        yt.read_table(os.path.join(workdir, 'yuid_cid_final'), raw=False))
    reference_graph, crypta_graph = get_reference_and_cid_splices(
        yuid_refid_recs,
        yuid_cid_recs,
        reference_id='tuid',
        attributes=['hit_rate'])
    stats = apply_imperical_scale(sample_metrics(reference_graph, crypta_graph), device_type)
    stats['device_type'] = device_type
    stats['splice_type'] = vertices_type

    # also add bootstrapped mean/variance estimation of recall
    samples_gen = graph_bootstrap_samples(reference_graph, crypta_graph, bootstrap_size)
    bootstrap_recalls = np.array([
        apply_imperical_scale(sample_metrics(sample_ref_graph, sample_crypta_graph),
                              device_type)['recall']
        for sample_ref_graph, sample_crypta_graph in samples_gen])
    sample_count = len(bootstrap_recalls)
    stats['recall_mean'] = float(np.mean(bootstrap_recalls)) if sample_count else 0.0
    # Unbiased (n - 1) sample variance; undefined for fewer than two samples.
    stats['recall_var'] = float(np.var(bootstrap_recalls, ddof=1)) if sample_count > 1 else 0.0
    return stats


def get_clusters(ref_graph):
    """Group reference ids into connected components linked by shared yuids.

    Two ids are in one cluster when their splices share a yuid, directly or
    transitively. Each cluster is labelled by its smallest id.

    :param ref_graph: ``{tid: [{'yuid': ...}, ...]}``.
    :return: ``{tid: cluster_id}`` for every key of ``ref_graph``; ids with no
        rows map to themselves.
    """
    tids_by_yuid = defaultdict(set)
    yuids_by_tid = defaultdict(set)
    for tid, splice in ref_graph.items():
        for row in splice:
            tids_by_yuid[row['yuid']].add(tid)
            yuids_by_tid[tid].add(row['yuid'])

    cluster_of = dict((tid, tid) for tid in ref_graph)
    seen = set()
    for start in list(yuids_by_tid):
        if start in seen:
            continue
        seen.add(start)
        members = set([start])
        stack = [start]
        while stack:
            current = stack.pop()
            for yuid in yuids_by_tid[current]:
                for neighbour in tids_by_yuid[yuid]:
                    if neighbour in seen:
                        continue
                    seen.add(neighbour)
                    members.add(neighbour)
                    stack.append(neighbour)
        label = min(members)
        for member in members:
            cluster_of[member] = label
    return cluster_of


def calculate_device_coverage(workdir, tuid_yuid_table, device_type, vertices_type):
    """Ratio of distinct devids to Toloka user clusters.

    ``coverage = <distinct devids in <workdir>/yuid_devid> / <clusters of tuids
    linked by shared yuids in tuid_yuid_table>``. It is 0.0 when the Toloka table
    yields no clusters, instead of raising ZeroDivisionError.

    :return: ``{device_type, coverage, splice_type}``.
    """
    yuid_tuid_recs = list(yt.read_table(tuid_yuid_table, raw=False))
    yuid_devid_recs = list(
        yt.read_table(os.path.join(workdir, 'yuid_devid'), raw=False))
    # get_reference_and_cid_splices groups its second input by 'crypta_id';
    # expose the devid under that name to reuse it for devices.
    for rec in yuid_devid_recs:
        rec['crypta_id'] = rec['devid']
    toloka_graph, device_graph = get_reference_and_cid_splices(
        yuid_tuid_recs,
        yuid_devid_recs,
        reference_id='tuid')

    cluster_count = len(set(get_clusters(toloka_graph).values()))
    coverage = 1.0 * len(device_graph) / cluster_count if cluster_count else 0.0

    stats = {'device_type': device_type,
             'coverage': coverage,
             'splice_type': vertices_type}
    return stats


def calculate_mobile_metrics(workdir, tuid_yuid_table):
    """Precision/recall of devid splices against Toloka users, plus device coverage.

    Devids from ``<workdir>/yuid_devid`` play the role of crypta_ids.
    ``device_coverage`` is the number of distinct devids divided by the number of
    reference users.
    """
    reference_rows = list(yt.read_table(tuid_yuid_table, raw=False))
    devid_rows = list(
        yt.read_table(os.path.join(workdir, 'yuid_devid'), raw=False))
    for devid_row in devid_rows:
        devid_row['crypta_id'] = devid_row['devid']
    ref_graph, device_graph = get_reference_and_cid_splices(
        reference_rows,
        devid_rows,
        reference_id='tuid')

    per_login = list(calc.get_metrics_per_login(ref_graph, device_graph).values())
    precision, recall, _ = calc.average_metrics(
        [(metric.prec_a, metric.prec_b) for metric in per_login],
        [(metric.rec_a, metric.rec_b) for metric in per_login])
    return {'device_type': 'mobile',
            'recall': recall['mean.opt'],
            'precision': precision['mean.opt'],
            'device_coverage': 1.0 * len(device_graph) / len(ref_graph),
            'splice_type': 'exact_with_devid_only'}


class TolokaMetrics(yt_luigi.BaseYtTask):
    '''
    Luigi Task to measure recall and precision of a vertices (crypta_id) table
    against Toloka splice answers.

    For every device type that has non-empty Toloka splices for ``date``, it:
    - builds a (yuid, tuid) reference table;
    - keeps yuids seen before their Toloka event;
    - joins them with crypta_ids from ``vertices_config``;
    - computes metrics.
    Results go to a per-vertices-type daily stats table (the task output) and to
    a shared ``stats`` table, and the daily stats are sent to graphite.
    '''
    date = luigi.Parameter()
    # Used below via .vertices_type, .producing_task and .get_vertices_table().
    vertices_config = luigi.Parameter()

    def __init__(self, *args, **kwargs):
        super(TolokaMetrics, self).__init__(*args, **kwargs)
        self.output_folder = os.path.join(self.out_f('toloka_metrics'), self.date)
        self.stats_tables = {
            'daily': os.path.join(self.output_folder, 'daily_stats_' + self.vertices_config.vertices_type),
        }

    def input_folders(self):
        return {
            'dict': config.GRAPH_YT_DICTS_FOLDER,
            'toloka': config.CRYPTA_TOLOKA_FOLDER,
        }

    def output_folders(self):
        return {
            'toloka_metrics': config.CRYPTA_TOLOKA_METRICS_FOLDER,
        }

    def requires(self):
        deps = [self.vertices_config.producing_task]
        # The splice collector dependency is skipped in the development environment.
        if config.CRYPTA_ENV != 'development':
            deps.append(TolokaDesktopSpliceCollector(self.date))
        return deps

    def run(self):
        yuid_dict = os.path.join(self.in_f('dict'), 'yuid_with_all')
        mr.mkdir(self.output_folder)

        if config.CRYPTA_ENV == 'testing':  # to match tolokers via production table
            yuid_dict = '//home/crypta/production/state/graph/dicts/yuid_with_all'

        # list of tuples (report_type, device_type, splice_tables) for daily reports
        splice_stat_config = [
            ('daily', 'desk', [os.path.join(self.in_f('toloka'), 'indevice_desktop', self.date + '-splices')]),
            ('daily', 'mobile', [os.path.join(self.in_f('toloka'), 'indevice_mobile', self.date + '-splices')]),
            ('daily', 'cross', [os.path.join(self.in_f('toloka'), 'cross_device', self.date + '-splices')]),
        ]

        all_stats_table = os.path.join(self.output_folder, 'stats')

        # Stats rows are buffered and written after the loop. Otherwise a failure
        # half-way would leave a partially filled daily table, which is this
        # task's output target.
        report_stats = defaultdict(list)
        all_stats = []
        for report_type, device_type, splice_tables in splice_stat_config:
            existing_tables = [table for table in splice_tables if yt.exists(table)]
            if len(existing_tables) < len(splice_tables):
                logger.warning('Missing toloka splice tables for %s: %s', device_type,
                               sorted(set(splice_tables) - set(existing_tables)))
            if not any(yt.row_count(table) > 0 for table in existing_tables):
                continue

            workdir = os.path.join(self.output_folder,
                                   report_type,
                                   device_type,
                                   self.vertices_config.vertices_type)
            mr.mkdir(workdir)
            tuid_yuid_table = os.path.join(workdir, 'yuid_tuid')
            seen_table = tuid_yuid_table + '_seen'
            yt.run_map(map_toloka_pairs,
                       existing_tables,
                       tuid_yuid_table)
            mr.distinct_by(['tuid', 'yuid', 'ua_profile', 'device_type', 'event_dt'],
                           tuid_yuid_table,
                           tuid_yuid_table)
            mr.sort_all([tuid_yuid_table], sort_by='yuid')
            # Keep only rows whose Toloka event is later than the yuid's first-seen date.
            yt.run_reduce(partial(filter_unseen_yuids,
                                  event_dt_field='event_dt'),
                          [yuid_dict, tuid_yuid_table],
                          seen_table,
                          reduce_by='yuid')
            prepare_yuid_crypta_id_data(self.vertices_config.get_vertices_table(),
                                        seen_table,
                                        workdir,
                                        yuid_dict)
            stats = calculate_all_metrics(workdir,
                                          seen_table,
                                          device_type,
                                          self.vertices_config.vertices_type)
            report_stats[report_type].append(stats)
            all_stats.append(stats)

        for report_type, stats_rows in report_stats.items():
            write_stats(self.stats_tables[report_type], stats_rows)
        if all_stats:
            write_stats(all_stats_table, all_stats)

        if yt.exists(self.stats_tables['daily']):
            send_stats(self.stats_tables['daily'], self.date)

        # Make sure the output target exists even when no splices were available.
        for table in self.stats_tables.values():
            if not yt.exists(table):
                yt.create_table(table)

    def output(self):
        return [yt_luigi.YtTarget(table, allow_empty=True) for table in self.stats_tables.values()]


class TolokaMetricsDevidCoverage(yt_luigi.BaseYtTask):
    '''
    Luigi Task to measure device coverage.

    Uses the mobile (yuid, tuid) table from Toloka metrics and the indevice
    devid-yuid pairs of the same date. It writes ``coverage`` (distinct devids
    per Toloka user cluster) to ``daily_device_stats`` and sends it to graphite.
    '''
    date = luigi.Parameter()
    vertices_config = luigi.Parameter()

    def __init__(self, *args, **kwargs):
        super(TolokaMetricsDevidCoverage, self).__init__(*args, **kwargs)
        self.output_folder = os.path.join(self.out_f('toloka_metrics'), self.date)
        self.stats_tables = {
            'daily': os.path.join(self.output_folder, 'daily_device_stats'),
        }

    def input_folders(self):
        return {
            'dict': config.GRAPH_YT_DICTS_FOLDER,
            'toloka': config.CRYPTA_TOLOKA_FOLDER,
        }

    def output_folders(self):
        return {
            'toloka_metrics': config.CRYPTA_TOLOKA_METRICS_FOLDER,
        }

    def requires(self):
        return [TolokaMetrics(self.date, self.vertices_config), ]

    def run(self):
        mr.mkdir(self.output_folder)

        # tuples (stat_report_type, stats_table, dates)
        stat_config = [
            ('daily', self.stats_tables['daily'], [self.date]),
        ]

        device_type = 'mobile'
        workdir = os.path.join(self.output_folder,
                               device_type,
                               self.vertices_config.vertices_type)

        for report_type, stats_table, dates in stat_config:
            splice_tables = [os.path.join(config.CRYPTA_TOLOKA_FOLDER,
                                          'indevice_mobile',
                                          dt + '-splices') for dt in dates]
            existing_tables = [table for table in splice_tables if yt.exists(table)]
            if len(existing_tables) < len(splice_tables):
                logger.warning('Missing toloka splice tables for device coverage: %s',
                               sorted(set(splice_tables) - set(existing_tables)))
            if not any(yt.row_count(table) > 0 for table in existing_tables):
                continue

            mr.mkdir(workdir)
            tuid_yuid_table = os.path.join(workdir, '_'.join(['dev_coverage', report_type, 'yuid_tuid_seen']))
            # NOTE(review): this reads the unfiltered 'yuid_tuid' table (not 'yuid_tuid_seen')
            # from CRYPTA_TOLOKA_FOLDER/metrics/..., whereas TolokaMetrics writes under
            # CRYPTA_TOLOKA_METRICS_FOLDER; presumably the same folder - confirm.
            yt.run_merge(
                [os.path.join(config.CRYPTA_TOLOKA_FOLDER,
                              'metrics',
                              dt,
                              'daily',
                              device_type,
                              self.vertices_config.vertices_type,
                              'yuid_tuid') for dt in dates],
                tuid_yuid_table)

            prepare_yuid_devid_data(tuid_yuid_table,
                                    workdir,
                                    dates)
            stats = calculate_device_coverage(workdir,
                                              tuid_yuid_table,
                                              device_type,
                                              self.vertices_config.vertices_type)
            write_stats(stats_table, [stats])

        if yt.exists(self.stats_tables['daily']):
            send_stats(self.stats_tables['daily'], self.date, metric_name='toloka_device_coverage')

        # Make sure the output target exists even when no splices were available.
        for table in self.stats_tables.values():
            if not yt.exists(table):
                yt.create_table(table)

    def output(self):
        return [yt_luigi.YtTarget(table, allow_empty=True) for table in self.stats_tables.values()]


class TolokaMetricsHousehold(yt_luigi.BaseYtTask):
    '''
    Luigi Task to measure how a vertices (crypta_id) table and the merged
    household table treat Toloka-labelled households.

    Each Toloka household row has ``shared`` (list of yuids) and ``personal``
    (list of per-person yuid lists). One row of ratios is written to
    ``household_<vertices_type>`` and sent to graphite under
    ``toloka.household.<vertices_type>``.
    '''
    date = luigi.Parameter()

    vertices_table = luigi.Parameter()
    vertices_table_yuid_column = luigi.Parameter()
    vertices_table_cryptaid_column = luigi.Parameter()
    # Optional task producing the vertices table; requires() reads its vertices_config.
    vertices_producing_task = luigi.Parameter()
    vertices_type = luigi.Parameter()

    def __init__(self, *args, **kwargs):
        super(TolokaMetricsHousehold, self).__init__(*args, **kwargs)
        self.output_folder = os.path.join(self.out_f('toloka_metrics'), self.date)
        self.output_table = os.path.join(self.output_folder, 'household_' + self.vertices_type)
        self.toloka_households = os.path.join(self.in_f('toloka'), 'household', self.date + '-households')

    def input_folders(self):
        return {
            'dict': config.GRAPH_YT_DICTS_FOLDER,
            'toloka': config.CRYPTA_TOLOKA_FOLDER,
            'households': config.HH_FOLDER
        }

    def output_folders(self):
        return {
            'toloka_metrics': config.CRYPTA_TOLOKA_METRICS_FOLDER,
            'workdir': os.path.join(self.in_f('toloka'), 'household')
        }

    def requires(self):
        deps = []

        if self.vertices_producing_task is not None:
            deps.append(self.vertices_producing_task)
            deps.append(TolokaMetrics(self.date, self.vertices_producing_task.vertices_config))
            deps.append(TolokaMetricsDevidCoverage(self.date, self.vertices_producing_task.vertices_config))

        if config.CRYPTA_ENV != 'development':
            deps.append(TolokaHouseholdCollector(self.date))

        return deps

    def output(self):
        return yt_luigi.YtTarget(self.output_table)

    def run(self):
        def mk_vertices_mapper(yuids, yuid_column, cryptaid_column):
            """Build a YT mapper emitting {yuid, crypta_id} for vertices whose yuid is in `yuids`."""
            yuids_set = set(yuids)

            def mapper(rec):
                yuid = rec[yuid_column]
                if yuid in yuids_set:
                    yield dict(yuid=yuid, crypta_id=rec[cryptaid_column])
            return mapper

        def mk_hh_mapper(yuids):
            """Build a YT mapper emitting {yuid, hh_id} for household members that are in `yuids`."""
            yuids_set = set(yuids)

            def mapper(rec):
                hhid = rec['key']
                # NOTE(review): assumes 'value' is a comma-separated list of '<prefix>/<yuid>' items.
                hh_yuids = set(x.split('/')[1] for x in rec['value'].split(','))
                # Intersect instead of scanning every Toloka yuid for each household row.
                for yuid in hh_yuids & yuids_set:
                    yield dict(yuid=yuid, hh_id=hhid)
            return mapper

        mr.mkdir(self.output_folder)

        household_table = os.path.join(self.in_f('households'), 'merged_households')
        vertices_table = self.vertices_table

        yuid_to_cryptaid_tbl = os.path.join(self.out_f('workdir'), self.date + '-cryptaids')
        yuid_to_hhid_tbl = os.path.join(self.out_f('workdir'), self.date + '-hh')

        # Every yuid mentioned in the Toloka households, shared or personal.
        yuids = set()
        for r in yt.read_table(self.toloka_households):
            yuids.update(r['shared'])
            for pers in r['personal']:
                yuids.update(pers)

        yt.run_map(mk_vertices_mapper(yuids, self.vertices_table_yuid_column, self.vertices_table_cryptaid_column),
                   vertices_table, yuid_to_cryptaid_tbl)
        yt.run_map(mk_hh_mapper(yuids), household_table, yuid_to_hhid_tbl)

        yuid_to_cid = dict()
        yuid_to_hhid = dict()
        for r in yt.read_table(yuid_to_cryptaid_tbl):
            yuid_to_cid[r['yuid']] = r['crypta_id']
        for r in yt.read_table(yuid_to_hhid_tbl):
            yuid_to_hhid[r['yuid']] = r['hh_id']

        shared_cookies_count = 0
        glued_shared_cookies_count = 0
        glued_people_count = 0
        hh_cookies_count = 0
        total_people_count = 0
        total_cookies_count = len(yuids)

        total_toloka_hh_count = 0
        toloka_hh_has_glued_people_count = 0
        toloka_hh_has_glued_shared_count = 0

        for r in yt.read_table(self.toloka_households):
            # A single person without shared cookies says nothing about households.
            if len(r['personal']) == 1 and len(r['shared']) == 0:
                continue

            total_toloka_hh_count += 1

            # crypta_ids of every person that has at least one known yuid
            people = []
            for grp in r['personal']:
                cryptaids = set()
                for y in grp:
                    cid = yuid_to_cid.get(y)
                    if cid:
                        cryptaids.add(cid)
                if cryptaids:
                    people.append(cryptaids)

            # People of the same household wrongly sharing a crypta_id.
            glued_idxes = set()
            for i in range(len(people)):
                for j in range(i + 1, len(people)):
                    if people[i].intersection(people[j]):
                        glued_idxes.add(i)
                        glued_idxes.add(j)

            total_people_count += len(r['personal'])
            glued_people_count += len(glued_idxes)

            if len(glued_idxes) > 0:
                toloka_hh_has_glued_people_count += 1

            # Shared cookies whose crypta_id coincides with some person's crypta_id.
            shared_cookies = set(r['shared'])
            has_glued_shared = False
            for y in shared_cookies:
                cid = yuid_to_cid.get(y)
                if cid:
                    for grp in people:
                        if cid in grp:
                            has_glued_shared = True
                            glued_shared_cookies_count += 1
                            break
            shared_cookies_count += len(shared_cookies)

            if has_glued_shared:
                toloka_hh_has_glued_shared_count += 1

            # All cookies of this Toloka household: shared plus every person's.
            household_yuids = set(shared_cookies)
            for person_yuids in r['personal']:
                household_yuids.update(person_yuids)

            hhid_counts = defaultdict(int)
            for y in household_yuids:
                hhid = yuid_to_hhid.get(y)
                if hhid:
                    hhid_counts[hhid] += 1

            # Cookies that landed in a household id together with another cookie of this household.
            for k in hhid_counts:
                if hhid_counts[k] > 1:
                    hh_cookies_count += hhid_counts[k]

        def try_div(x, y):
            try:
                return x / y
            except ZeroDivisionError:
                return 0

        glued_people_ratio = try_div(float(glued_people_count), float(total_people_count))
        glued_shared_cookies_ratio = try_div(float(glued_shared_cookies_count), float(shared_cookies_count))
        hh_coverage = try_div(float(hh_cookies_count), float(total_cookies_count))
        hh_with_glued_people_ratio = try_div(float(toloka_hh_has_glued_people_count), total_toloka_hh_count)
        hh_with_glued_shared_ratio = try_div(float(toloka_hh_has_glued_shared_count), total_toloka_hh_count)

        metric_values = dict(
            glued_people_ratio=glued_people_ratio,
            glued_shared_cookies_ratio=glued_shared_cookies_ratio,
            hh_coverage=hh_coverage,
            hh_with_glued_people_ratio=hh_with_glued_people_ratio,
            hh_with_glued_shared_ratio=hh_with_glued_shared_ratio
        )

        yt.write_table(self.output_table, [metric_values])

        metrics_to_graphite = []
        metric_base_name = 'toloka.household.' + self.vertices_type
        for k in metric_values:
            metrics_to_graphite.append((metric_base_name, k, metric_values[k]))
        graphite_sender.to_graphite_sender(metrics_to_graphite, self.date)

if '__main__' == __name__:
    # Manual debug entry point: computes household metrics for the exact clustering
    # by calling run() directly, so luigi dependencies are not executed.
    from matching.human_matching import graph_vertices, graph_clustering
    yt.config.set_proxy('hahn.yt.yandex.net')

    exact_vertices_task = graph_vertices.GraphVerticesExact('2018-01-23',
                                                            vertices_type='exact',
                                                            yuid_pairs_folder='pairs/')
    exact_cluster = graph_clustering.ClusterVertices(exact_vertices_task.vertices_config,
                                                     graph_clustering.ClusteringConfig())
    vertices_config = exact_cluster.vertices_config

    # TolokaMetricsHousehold has no vertices_config parameter, so every vertices
    # parameter is passed explicitly. The column names match the vertices table
    # layout read by prepare_vertices ('key' and 'crypta_id').
    hh = TolokaMetricsHousehold('2018-01-25',
                                vertices_table=vertices_config.get_vertices_table(),
                                vertices_table_yuid_column='key',
                                vertices_table_cryptaid_column='crypta_id',
                                vertices_producing_task=exact_cluster,
                                vertices_type=vertices_config.vertices_type)

    hh.run()
