import os
from collections import defaultdict
from yt.wrapper import TablePath, with_context
from datacloud.geo.lib import general, geo_log_utils


# Fields that must all be present in a record's 'exact_socdem' value for
# OlapMapperWithCounters to emit a row for that record.
_SOCDEM_KEYS = ['age_segment', 'gender', 'income_5_segment']


def olap_transformer_with_counters(yt_client, yuid_to_geo_id_table, yuid_to_socdem_table, audience_dict, output_table):
    """Build per-(geo_id, cat) counter aggregates and write them to ``output_table``.

    Pipeline (all intermediate data lives in a single temporary table):
      1. map ``yuid_to_socdem_table`` with ``OlapMapperWithCounters(audience_dict)``;
      2. sort the result by ``yuid``;
      3. reduce it together with ``yuid_to_geo_id_table`` by ``yuid`` using
         ``geo_id_to_cat_reducer_with_counters``;
      4. aggregate by ``(geo_id, cat)`` with ``state_counters_reducer``;
      5. sort ``output_table`` by ``(geo_id, cat)``.
    """
    with yt_client.TempTable() as stage_table:
        # Steps 1-2: per-yuid category + counters, sorted for the reduce below.
        socdem_mapper = OlapMapperWithCounters(audience_dict)
        yt_client.run_map(socdem_mapper, yuid_to_socdem_table, stage_table)
        yt_client.run_sort(stage_table, sort_by='yuid')

        # Step 3: the staged table must stay first in the input list, since the
        # reducer distinguishes its rows by table index 0. The result
        # overwrites the staged table.
        reduce_inputs = [stage_table, yuid_to_geo_id_table]
        yt_client.run_reduce(
            geo_id_to_cat_reducer_with_counters,
            reduce_inputs,
            stage_table,
            reduce_by='yuid',
        )

        # Step 4: reduce-only map_reduce (no mapper) summing counters per group.
        yt_client.run_map_reduce(
            None,
            state_counters_reducer,
            stage_table,
            output_table,
            reduce_by=['geo_id', 'cat'],
        )

        # Step 5: leave the final table sorted by its grouping key.
        yt_client.run_sort(output_table, sort_by=['geo_id', 'cat'])


class OlapMapperWithCounters(object):
    """Mapper that converts a socdem record into a ``yuid / cat / counters`` row.

    For a record whose ``exact_socdem`` contains every key from
    ``_SOCDEM_KEYS`` the mapper yields one row:
      * ``yuid``     - copied from the record;
      * ``cat``      - ``'<age_segment>#<gender>#<income_5_segment>'``;
      * ``counters`` - ``{'count': 1}`` plus ``{<name>: 1}`` for every id of the
        record that is listed in ``audience_dict`` for the same category key.
    Records without a complete ``exact_socdem`` yield nothing.
    """

    def __init__(self, audience_dict):
        """
        audience_dict: {
            'audience_segments': {7856626: 'shanson',  7720858: 'rock'},
            'longterm_interests': {170: 'dog'},
            'heuristic_common': {1090: 'japan-food'},
            'heuristic_segments': {549: 'android'},
        }

        Category keys missing from ``audience_dict`` (or mapped to ``None``)
        are treated as having no ids of interest.
        """
        self._audience_dict = audience_dict
        # NOTE(review): 'heuristic_segments' is shown in the example above but
        # is not processed here - confirm whether that is intentional.
        keys = ['audience_segments', 'longterm_interests', 'heuristic_common']
        # `or ()` keeps construction from failing with TypeError (set(None))
        # when a category is absent from the config.
        self._audience_sets = {key: set(audience_dict.get(key) or ()) for key in keys}

    # To correctly process empty YT records
    def _extract(self, rec, key, default):
        """Return ``rec[key]``, or ``default`` if the key is absent or its value is None."""
        item = rec.get(key)
        return default if item is None else item

    def __call__(self, rec):
        """Yield at most one ``{'yuid', 'cat', 'counters'}`` row for ``rec``."""
        # Go through _extract so a record lacking the column is skipped the
        # same way as a record holding an explicit None.
        socdem = self._extract(rec, 'exact_socdem', None)
        if not socdem or not all(key in socdem for key in _SOCDEM_KEYS):
            return

        age = socdem['age_segment']
        gender = socdem['gender']
        income = socdem['income_5_segment']
        cat = '#'.join([age, gender, income])
        counters = {'count': 1}

        for key, known_ids in self._audience_sets.items():
            if not known_ids:
                # Nothing configured for this category (the key may even be
                # missing from the config), so there is nothing to match.
                continue
            names = self._audience_dict[key]
            for item in known_ids.intersection(self._extract(rec, key, [])):
                counters[names[item]] = 1

        yield {'yuid': rec['yuid'], 'cat': cat, 'counters': counters}


@with_context
def geo_id_to_cat_reducer_with_counters(_, recs, context):
    """Attach the category and counters of a key group to each of its geo rows.

    Rows coming from input table 0 supply ``cat`` and ``counters``; every row
    from any other input table is emitted as ``{'geo_id', 'cat', 'counters'}``
    using the most recently seen table-0 values. Rows from other tables that
    arrive before any table-0 row (or while ``cat`` is empty) are dropped.
    """
    category = None
    category_counters = None
    for row in recs:
        is_category_row = context.table_index == 0
        if is_category_row:
            # Remember the latest category data for this key group.
            category = row['cat']
            category_counters = row['counters']
            continue
        if not category:
            # No category known for this group: nothing to join with.
            continue
        yield {
            'geo_id': row['geo_id'],
            'cat': category,
            'counters': category_counters,
        }


def state_counters_reducer(_, recs):
    """Sum the ``counters`` dicts of all records in one reduce group.

    Yields a single row ``{'geo_id', 'cat', 'counters'}`` where ``geo_id`` and
    ``cat`` are taken from the last record of the group and ``counters`` maps
    every counter name seen in the group to the sum of its values.
    An empty group yields nothing.
    """
    totals = defaultdict(int)
    last_rec = None
    for last_rec in recs:
        for name, value in last_rec['counters'].items():
            totals[name] += value

    if last_rec is None:
        # Do not rely on a loop variable that was never bound.
        return

    # Emit a plain dict so the output does not carry defaultdict semantics.
    yield {'geo_id': last_rec['geo_id'], 'cat': last_rec['cat'], 'counters': dict(totals)}


def get_polygon_base(yt_client, input_table, output_table):
    """Build a per-``geo_id`` table from two derived tables and write it to ``output_table``.

    Steps, in order:
      1. collect distinct ``geo_id`` values of ``input_table``
         (``general.unique_reducer``);
      2. run ``geo_log_utils.upscale_h3`` and ``enhance_with_boundary`` for
         resolutions 9, 8, 7, 6, then sort the result by ``geo_id``;
      3. run ``geo_log_utils.geo_hash_to_center`` and ``enhance_with_geobase``
         on the distinct ids, then sort the result by ``geo_id``;
      4. combine both results with ``general.unite_tables`` keyed by ``geo_id``.

    Requires the ``YQL_TOKEN`` environment variable; raises ``KeyError`` if it
    is not set.
    """
    token = os.environ['YQL_TOKEN']
    join_column = 'geo_id'
    h3_levels = [9, 8, 7, 6]

    with yt_client.TempTable() as upscaled:
        with yt_client.TempTable() as centers:
            with yt_client.TempTable() as distinct_ids:
                # Step 1: only the geo_id column is read from the input.
                yt_client.run_map_reduce(
                    None,
                    general.unique_reducer,
                    TablePath(input_table, columns=[join_column]),
                    distinct_ids,
                    reduce_by=join_column,
                )

                # Step 2: boundary enhancement rewrites the upscaled table in place.
                geo_log_utils.upscale_h3(yt_client, distinct_ids, upscaled, h3_levels)
                geo_log_utils.enhance_with_boundary(yt_client, upscaled, upscaled, h3_levels)
                yt_client.run_sort(upscaled, sort_by=join_column)

                # Step 3: geobase enhancement rewrites the centers table in place.
                geo_log_utils.geo_hash_to_center(yt_client, distinct_ids, centers)
                geo_log_utils.enhance_with_geobase(yt_client, token, centers, centers)
                yt_client.run_sort(centers, sort_by=join_column)

                # Step 4: both inputs are sorted by the join column at this point.
                general.unite_tables(yt_client, [centers, upscaled], output_table, join_column)
