# coding=utf-8

import logging
import re
import datetime

import yt.wrapper as yt

import sandbox.projects.release_machine.core.const as rm_const

import irt.common.yt

# Number of shards link rows are spread over; FilterShardMapper assigns Shard values in [0, MAX_SHARD).
MAX_SHARD = 32

# Raw links path; identical to the path get_source_links() returns for stable releases.
# NOTE(review): not referenced elsewhere in this module — possibly imported by other code; confirm before removing.
SOURCE_LINKS = '//home/bannerland/data/dyn_phrase_sources/raw_links'


def get_filtered_links_table(release_type):
    """Return the YT path of the filtered links table for ``release_type``.

    ``release_type`` is substituted as a directory component under ``//home/bannerland``.
    """
    # TODO(optozorax): move this out to the common module
    template = '//home/bannerland/{}/data/dse/preparing/filtered_links'
    return template.format(release_type)


def get_source_links(release_type):
    """Return the YT path of the raw links table to read for ``release_type``.

    Stable releases read the production table; every other release type reads
    the table under ``testing``.
    """
    is_stable = release_type == rm_const.ReleaseStatus.stable
    return (
        '//home/bannerland/data/dyn_phrase_sources/raw_links'
        if is_stable
        else '//home/bannerland/testing/data/dyn_phrase_sources/raw_links'
    )


def set_upload_time(table, yt_client):
    """Set the ``upload_time`` attribute of ``table`` to the current UTC time.

    The value is an ISO-8601 string without a UTC offset (naive UTC), the same
    format ``datetime.datetime.utcnow().isoformat()`` produced.

    Args:
        table: YT path of the table to stamp.
        yt_client: client used to write the attribute.
    """
    # utcnow() is deprecated since Python 3.12; take an aware UTC time and drop
    # tzinfo so the stored string keeps its original offset-less format.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    irt.common.yt.set_attribute(table, 'upload_time', now_utc.isoformat(), yt_client=yt_client)


def inner_join_reducer(_, rows):
    """Inner-join reducer: merge the first row of a key into every following row.

    The first row of the key must come from input table 0 (checked through the
    ``@table_index`` row field); otherwise the key yields nothing. Each later row
    is updated in place with all fields of that first row and yielded. This
    includes ``@table_index``, so yielded rows carry ``@table_index == 0``.
    """
    row_iter = iter(rows)
    head = next(row_iter)
    if head['@table_index'] == 0:
        for other in row_iter:
            other.update(head)
            yield other


# Matches any run of 15 or more word characters; FilterShardMapper drops rows whose Text matches.
filter_longword_re = re.compile(r'\w{15,}', re.UNICODE)


class FilterShardMapper:
    """YT mapper that drops link rows containing over-long words and shards the rest.

    Rows whose ``Text`` contains a match of ``filter_longword_re`` are skipped.
    Every other row is emitted with ``Domain``, ``TargetUrl`` and ``Text`` copied
    and a ``Shard`` value taken from a counter modulo ``MAX_SHARD``. Any exception
    raised while handling a row is logged and that row is skipped.
    """

    def __init__(self):
        # Count of rows that passed the long-word filter so far; drives Shard.
        self.index = 0

    def __call__(self, row):
        # Corrupted data may leak into 'links'; log it and skip the row instead of failing.
        try:
            text = row['Text']
            match = filter_longword_re.search(text)
            if match is None:
                self.index += 1
                yield {
                    'Domain': row['Domain'],
                    'TargetUrl': row['TargetUrl'],
                    'Text': text,
                    'Shard': self.index % MAX_SHARD,
                }
        except Exception as err:
            logging.exception(err)


def count_mapper(row):
    """Emit one counting record (Domain, Text, Shard, count=1) for a sharded link row."""
    record = {field: row[field] for field in ('Domain', 'Text', 'Shard')}
    record['count'] = 1
    yield record


def count_reducer(key, rows):
    """Sum counts for one (Domain, Text) key and emit the total once per shard seen.

    Every yielded row carries the same total ``count``. There is one row for each
    distinct ``Shard`` value that occurred among ``rows``.
    """
    total = 0
    seen_shards = set()
    for rec in rows:
        total += rec['count']
        seen_shards.add(rec['Shard'])
    yield from (
        {'Domain': key['Domain'], 'Text': key['Text'], 'Shard': shard_id, 'count': total}
        for shard_id in seen_shards
    )


def count_reduce_combiner(key, rows):
    """Pre-aggregate counts per shard for one (Domain, Text) key.

    Yields one row per distinct ``Shard``, in first-seen order, whose ``count`` is
    the sum of that shard's counts. The output has the same columns as the
    mapper's, so count_reducer can consume it.
    """
    per_shard = {}
    for rec in rows:
        shard_id = rec['Shard']
        per_shard[shard_id] = per_shard.get(shard_id, 0) + rec['count']
    for shard_id, subtotal in per_shard.items():
        yield {
            'Domain': key['Domain'],
            'Text': key['Text'],
            'Shard': shard_id,
            'count': subtotal,
        }


def min_reducer(key, rows):
    """Keep the least frequent text(s) for one (Domain, TargetUrl) key.

    Yields one row per distinct ``Text`` whose ``count`` equals the minimum count
    among ``rows``. If ``rows`` is empty, nothing is yielded.
    """
    lowest = float('inf')
    winners = set()
    for rec in rows:
        current = rec['count']
        if current < lowest:
            lowest, winners = current, {rec['Text']}
        elif current == lowest:
            winners.add(rec['Text'])
    for text in winners:
        yield {'Domain': key['Domain'], 'TargetUrl': key['TargetUrl'], 'Text': text}


def get_yt_client(token, pool, proxy):
    """Build a YtClient for ``proxy`` authenticated with ``token``.

    The config always enables ``mount_sandbox_in_tmpfs``. When ``pool`` is truthy,
    it is placed in ``spec_defaults`` as the default pool.
    """
    client_config = dict(token=token, mount_sandbox_in_tmpfs=True)
    if pool:
        client_config.update(spec_defaults={'pool': pool})
    return yt.YtClient(proxy=proxy, config=client_config)


def run(yt_client, release_type):
    """Build the filtered links table for ``release_type``.

    All operations run inside one YT transaction, and intermediate results live
    in temp tables:
      1. Map the source links through FilterShardMapper. This drops texts with
         15+ character words and assigns each remaining row a Shard.
      2. Map-reduce by (Domain, Text) to count how many link rows share each
         text. One count row is emitted per shard the text occurs in.
      3. Sort both tables by (Domain, Text, Shard) and reduce-join them. Every
         link row then receives its text's count.
      4. Sort by (Domain, TargetUrl) and reduce, keeping only the text(s) with
         the minimal count for each target URL.
      5. Sort the result into the output table and stamp it with upload_time.

    Args:
        yt_client: client used for every operation.
        release_type: selects the input path (get_source_links) and the output
            path (get_filtered_links_table).
    """
    output_filtered_links = get_filtered_links_table(release_type)

    with yt_client.Transaction():
        # Temp tables are removed when this block exits; only output_filtered_links persists.
        with yt_client.TempTable() as sharded_links, \
                yt_client.TempTable() as counts, \
                yt_client.TempTable() as joined_counts_links, \
                yt_client.TempTable() as filtered_links_unsorted:
            yt_client.run_map(
                FilterShardMapper(),
                get_source_links(release_type),
                sharded_links
            )

            yt_client.run_map_reduce(
                count_mapper,
                count_reducer,
                sharded_links,
                counts,
                reduce_by=['Domain', 'Text'],
                reduce_combiner=count_reduce_combiner,
                format=yt.YsonFormat(control_attributes_mode='row_fields'),
            )
            # Both join inputs must be sorted by the join key before the reduce below.
            yt_client.run_sort(sharded_links, sort_by=['Domain', 'Text', 'Shard'])
            yt_client.run_sort(counts, sort_by=['Domain', 'Text', 'Shard'])

            # counts is input table 0: inner_join_reducer keeps a key only if its first
            # row has @table_index == 0 (visible via row_fields mode).
            # NOTE(review): relies on YT delivering a key's rows in table-index order — confirm.
            yt_client.run_reduce(
                inner_join_reducer,
                [counts, sharded_links],
                joined_counts_links,
                reduce_by=['Domain', 'Text', 'Shard'],
                format=yt.YsonFormat(control_attributes_mode='row_fields'),
            )
            yt_client.run_sort(joined_counts_links, sort_by=['Domain', 'TargetUrl'])

            yt_client.run_reduce(
                min_reducer,
                joined_counts_links,
                filtered_links_unsorted,
                reduce_by=['Domain', 'TargetUrl'],
                format=yt.YsonFormat(control_attributes_mode='row_fields'),
            )
            # Sorting writes the final, sorted result into the persistent output table.
            yt_client.run_sort(filtered_links_unsorted, output_filtered_links, sort_by=['Domain', 'TargetUrl'])

        # Still inside the transaction, after the temp tables have been dropped.
        set_upload_time(output_filtered_links, yt_client)
