#!/usr/bin/env python
# coding=utf-8

from bm.yt_tools import NormalizeBase, get_cdict_generation_params

# Parameters returned by get_cdict_generation_params(), fetched once at import
# time. This module reads the 'min_query_frequency' and 'max_query_words' entries.
COMMON_PARAMS = get_cdict_generation_params()


def prepare_orig_file(key, rows):
    """Reducer: pick the most frequent original query text for one key.

    Scans ``rows`` and remembers the ``orig_text`` of the row with the largest
    ``freq_query`` (the first such row wins on ties). At most one record is
    yielded: the key fields plus ``freq_query`` and ``orig_text``. Nothing is
    yielded when ``norm`` is empty, when the best ``freq_query`` is below
    ``COMMON_PARAMS['min_query_frequency']``, or when the chosen text has more
    than ``COMMON_PARAMS['max_query_words']`` space-separated words.
    """
    result = dict(key)
    # -1 is the "nothing seen yet" marker for the running maximum.
    result['freq_query'] = best_freq = -1
    for row in rows:
        candidate = int(row['freq_query'])
        if candidate > best_freq:
            best_freq = candidate
            result['orig_text'] = row['orig_text']
    result['freq_query'] = best_freq

    # Checks stay in this order: 'orig_text' is only looked up once the
    # earlier checks have passed.
    if len(result['norm']) == 0:
        return
    if best_freq < COMMON_PARAMS['min_query_frequency']:
        return
    if len(result['orig_text'].split(' ')) > COMMON_PARAMS['max_query_words']:
        return
    yield result


def aggr_func_counts(key, rows):
    """Reducer: sum the counters of all rows that share one key.

    For every row the ``freq`` and ``freq_query`` values are converted to int
    and added up. The ``regions`` field is parsed as space-separated
    ``<int>:<int>`` pairs; the second number is accumulated per first number.

    Yields exactly one record: the key fields plus ``freq``, ``freq_query`` and
    ``regions``, where ``regions`` is re-serialized in the same
    ``<int>:<int>`` format, ordered by the first number ascending.

    Empty chunks in ``regions`` (an empty string or doubled spaces) are
    skipped instead of raising ``ValueError`` from ``int('')``.
    """
    res = dict(key)
    freq = freq_query = 0
    regions = {}
    for row in rows:
        freq += int(row['freq'])
        freq_query += int(row['freq_query'])
        for record in row['regions'].split(' '):
            if not record:
                # ''.split(' ') gives [''], which has no pair to parse.
                continue
            region, count = map(int, record.split(':'))
            regions[region] = regions.get(region, 0) + count
    res['freq'] = freq
    res['freq_query'] = freq_query
    res['regions'] = ' '.join(
        '{}:{}'.format(region, regions[region]) for region in sorted(regions)
    )
    yield res


class NormMapper(NormalizeBase):
    """Mapper that turns one hits row into at most one normalized-query row.

    The input row is read through three columns: ``OrigSanitized`` (query
    text) plus a hit-count column and a per-region column whose names depend
    on ``is_mobile``. Normalization is delegated to ``self.norm_phr``, which
    is not defined in this class (presumably inherited from NormalizeBase).
    """

    def __init__(self, is_mobile=False):
        """Select which hit columns to read: mobile ones or the regular ones."""
        super(NormMapper, self).__init__()
        if is_mobile:
            column_names = ("MobileHits", "MobileRegionHits")
        else:
            column_names = ("Hits", "RegionHits")
        self.freq_key_name, self.freq_region_key_name = column_names

    def __call__(self, r):
        """Yield a record with norm, freq, freq_query, regions, orig_text, full_norm.

        Rows whose hit count equals 0 produce nothing, as do rows whose
        truncated text normalizes to no non-empty words.
        """
        orig_text = r["OrigSanitized"]
        freq = r[self.freq_key_name]
        if freq == 0:
            return
        regions = r[self.freq_region_key_name]

        word_limit = COMMON_PARAMS['max_query_words']
        query_words = orig_text.split(' ')
        # Hits count towards freq_query only when the whole query fits the limit.
        if len(query_words) <= word_limit:
            freq_query = freq
        else:
            freq_query = 0
        truncated_text = ' '.join(query_words[:word_limit]).replace('!', '')

        # Normalize the full text and drop repeated words, keeping the order
        # in which each word first appears.
        seen_words = set()
        unique_words = []
        for word in self.norm_phr(orig_text.replace('!', ''), sort=False).split(' '):
            if word in seen_words:
                continue
            seen_words.add(word)
            unique_words.append(word)
        full_norm = ' '.join(unique_words)

        # The key is built from the truncated text: unique non-empty
        # normalized words in sorted order.
        distinct_norms = set(self.norm_phr(truncated_text, sort=False).split(' '))
        norms = sorted(word for word in distinct_norms if len(word) > 0)
        if len(norms) == 0:
            return
        yield {
            'norm': ' '.join(norms),
            'freq': freq,
            'freq_query': freq_query,
            'regions': regions,
            'orig_text': orig_text,
            'full_norm': full_norm,
        }


def generate_norms(yt_client, hits_dir, dest_table, dest_table_full, is_mobile=False, dest_table_orig=None):
    """Build the normalized-query tables from the hits table in ``hits_dir``.

    The source is the last table yielded by ``yt_client.search`` over
    ``hits_dir``; an AssertionError is raised when the search yields nothing.
    The source is mapped with ``NormMapper`` into a temporary table sorted by
    ``norm, full_norm``, which is then reduced into:

    * ``dest_table_full`` - ``aggr_func_counts`` grouped by ``norm, full_norm``;
    * ``dest_table`` - ``aggr_func_counts`` grouped by ``norm``;
    * ``dest_table_orig`` - ``prepare_orig_file`` grouped by ``norm``, only
      when that argument is truthy.

    Each destination table is sorted by ``norm`` right after it is produced.
    """
    # get hits table - last table in directory
    found_tables = list(yt_client.search(hits_dir, node_type=["table"]))
    source_table = found_tables[-1] if found_tables else ''
    assert source_table != ""

    mapper = NormMapper(is_mobile=is_mobile)

    # (reducer, destination table, reduce keys), in execution order.
    reduce_jobs = [
        (aggr_func_counts, dest_table_full, ["norm", "full_norm"]),
        (aggr_func_counts, dest_table, ["norm"]),
    ]
    if dest_table_orig:
        reduce_jobs.append((prepare_orig_file, dest_table_orig, ["norm"]))

    with yt_client.TempTable() as tmp_table:
        map_spec = {"data_size_per_job": int(0.25 * 1024 ** 3)}
        yt_client.run_map(mapper, source_table, tmp_table, yt_files=mapper.yt_files, spec=map_spec)
        yt_client.run_sort(tmp_table, sort_by=["norm", "full_norm"])

        for reducer, destination, reduce_keys in reduce_jobs:
            yt_client.run_reduce(reducer, tmp_table, destination, reduce_by=reduce_keys)
            yt_client.run_sort(destination, sort_by=["norm"], spec={"combine_chunks": True})
