#!/usr/bin/env python
# coding=utf-8

from bm.yt_tools import get_cdict_generation_params

# Shared generation settings, loaded once at import time. This module reads the
# keys 'max_generation_words', 'max_subphrase_words', 'max_phrase_length' and
# 'min_phrase_frequency' from it.
COMMON_PARAMS = get_cdict_generation_params()


class SubphrasesExtractor():
    """Callable that expands one phrase row into rows for its word subsets.

    Instances are passed to ``run_map`` as the mapper. For every input row the
    call yields an optional ``' ~0'`` row for the row's ``norm`` followed by one
    row per word subset of the row's ``full_norm``.
    """

    def __init__(self):
        # masks[n] lists the index combinations for an n-word phrase. Slot 0 is
        # a placeholder so the list can be indexed directly by word count.
        self.masks = [0]
        for size in range(1, COMMON_PARAMS['max_generation_words'] + 1):
            self.masks.append(self.generate_masks(size))

    def generate_masks(self, size):
        """Return the sorted list of index masks for a phrase of ``size`` words.

        Each mask is an ascending list of word positions; the masks themselves
        are ordered by ordinary list comparison.
        """
        ordered = [sorted(mask) for mask in self.generate_masks_rec(list(range(size)))]
        ordered.sort()
        return ordered

    def generate_masks_rec(self, m):
        """Recursively build the non-empty subsets of the index list ``m``.

        Subsets longer than ``COMMON_PARAMS['max_subphrase_words']`` are
        discarded at every recursion level above the single-element base case.
        ``m`` must contain at least one element.
        """
        if len(m) == 1:
            return [[m[0]]]
        tail_masks = self.generate_masks_rec(m[1:])
        head = m[0]
        # Subsets without the head, the head alone, then subsets with the head.
        candidates = tail_masks + [[head]] + [mask + [head] for mask in tail_masks]
        return [mask for mask in candidates
                if len(mask) <= COMMON_PARAMS['max_subphrase_words']]

    def get_all_subphrases(self, words):
        """Yield every sub-phrase of ``words`` as a space-joined string.

        ``len(words)`` selects the precomputed mask list, so it must not exceed
        the number of mask lists built in ``__init__``.
        """
        for mask in self.masks[len(words)]:
            picked = [words[position] for position in mask]
            if picked:
                yield ' '.join(picked)

    def __call__(self, r):
        """Yield the output rows for the input row ``r``.

        ``r`` is read through the keys 'full_norm', 'norm', 'freq_query',
        'freq' and 'regions'. Every yielded row carries 'norm', 'freq_query',
        'freq' and 'regions'.
        """
        # Keep at most max_generation_words words, then sort them so that the
        # generated sub-phrases do not depend on the original word order.
        norm_words = sorted(r['full_norm'].split(' ')[:COMMON_PARAMS['max_generation_words']])
        phrase_limit = COMMON_PARAMS['max_phrase_length']

        # emit ~0
        if len(r['norm']) < phrase_limit:
            yield {
                "norm": (r['norm'] + ' ~0').decode('utf8'),
                "freq_query": r['freq_query'],
                "freq": r['freq'],
                "regions": r['regions'],
            }

        # emit all subphrases
        total_len = len(r['full_norm'])
        for subphrase in self.get_all_subphrases(norm_words):
            subph_len = len(subphrase)
            if subph_len >= phrase_limit:
                continue
            # Only a sub-phrase as long as the whole full_norm keeps the
            # query frequency; every shorter one gets 0.
            if total_len == subph_len:
                query_freq = r['freq_query']
            else:
                query_freq = 0
            yield {
                "norm": subphrase,
                "freq_query": query_freq,
                "freq": r['freq'],
                "regions": r['regions'],
            }


def stat_aggregator(key, rows):
    """Reduce all rows that share ``key`` into one aggregated row.

    ``key`` is copied into the result. 'freq_query' and 'freq' are summed over
    ``rows`` (each value is passed through ``int``). 'regions' is parsed as
    space-separated ``region:count`` records whose counts are summed per
    region; the result is written back in the same format, ordered by the
    region key as a string.

    Empty records in 'regions' (an empty string, or repeated spaces) are
    skipped instead of raising ``ValueError``.

    The aggregated row is yielded only when its total 'freq' is greater than 1.
    """
    res = dict(key)
    res['freq_query'] = res['freq'] = 0
    regions = {}
    for row in rows:
        res['freq_query'] += int(row['freq_query'])
        res['freq'] += int(row['freq'])
        for record in row['regions'].split(' '):
            if not record:
                # ''.split(' ') returns [''], which cannot be unpacked below.
                continue
            (r, h) = record.split(':')
            regions[r] = regions.get(r, 0) + int(h)

    res['regions'] = ' '.join(str(r) + ':' + str(regions[r]) for r in sorted(regions))
    if res['freq'] > 1:
        yield res


def mapper_counts(r):
    """Yield the ('norm', 'freq') projection of sufficiently frequent rows.

    A row passes when its 'freq' is at least
    ``COMMON_PARAMS['min_phrase_frequency']``; other rows produce nothing.
    """
    frequency = r['freq']
    threshold = COMMON_PARAMS['min_phrase_frequency']
    if frequency >= threshold:
        yield dict(norm=r['norm'], freq=frequency)


def mapper_geo(r):
    """Yield the per-region counts of a sufficiently frequent row.

    Rows whose 'freq' is below ``COMMON_PARAMS['min_phrase_frequency']`` are
    dropped. For the rest, 'regions' is parsed as space-separated
    ``region:count`` records of integers and only records whose count is
    strictly greater than the same threshold are kept, in input order. A row
    is yielded only if at least one record survives.

    Empty records in 'regions' (an empty string, or repeated spaces) are
    skipped instead of raising ``ValueError`` from ``int('')``.
    """
    threshold = COMMON_PARAMS['min_phrase_frequency']
    if r['freq'] >= threshold:
        regions = []
        for record in r['regions'].split(' '):
            if not record:
                # ''.split(' ') returns [''], which int() cannot parse.
                continue
            fields = [int(part) for part in record.split(':')]
            # NOTE(review): the row check above is '>=' while this one is a
            # strict '>' - kept as in the original; confirm it is intended.
            if fields[1] > threshold:
                regions.append(fields)
        # A concrete list keeps this emptiness check valid even where
        # filter() would return a lazy iterator.
        if regions:
            yield {'norm': r['norm'],
                   'regions': ' '.join(str(c[0]) + ':' + str(c[1]) for c in regions)}


def generate_counts(yt_client, src_table, dst_table_counts_full, dst_table_counts, dst_table_geo=None):
    """Build the sub-phrase statistics tables from ``src_table``.

    Steps, in order:
      1. map ``src_table`` through ``SubphrasesExtractor`` into a temp table;
      2. reduce the temp table by 'norm' with ``stat_aggregator`` into
         ``dst_table_counts_full`` and sort it by 'norm';
      3. map ``dst_table_counts_full`` through ``mapper_counts`` into
         ``dst_table_counts`` and sort it by 'norm';
      4. if ``dst_table_geo`` is truthy, map ``dst_table_counts_full`` through
         ``mapper_geo`` into it and sort it by 'norm'.
    """

    def sort_by_norm(table):
        # Every output table is sorted in place with the same options.
        yt_client.run_sort(table, sort_by=["norm"], spec={"combine_chunks": True})

    with yt_client.TempTable() as tmp_table:
        yt_client.run_map(SubphrasesExtractor(), src_table, tmp_table)
        yt_client.run_map_reduce(None, stat_aggregator, tmp_table, dst_table_counts_full, reduce_by=["norm"])
        sort_by_norm(dst_table_counts_full)

    # this step used to be filter_counts
    yt_client.run_map(mapper_counts, dst_table_counts_full, dst_table_counts)
    sort_by_norm(dst_table_counts)

    if not dst_table_geo:
        return

    yt_client.run_map(mapper_geo, dst_table_counts_full, dst_table_geo)
    sort_by_norm(dst_table_geo)


def merge_counts(yt_client, src_tables, dst_table):
    """Merge several statistics tables into ``dst_table``.

    ``src_tables`` are reduced by 'norm' with ``stat_aggregator`` into a temp
    table, which is then sorted by 'norm' into ``dst_table``.
    """
    with yt_client.TempTable() as merged:
        yt_client.run_map_reduce(None, stat_aggregator, src_tables, merged, reduce_by=["norm"])
        yt_client.run_sort(merged, dst_table, sort_by=["norm"], spec={"combine_chunks": True})
