# -*- coding: utf-8 -*-
import argparse
import hashlib
import os

import yt.wrapper as yt

import bm.yt_tools


class TokensMapper(bm.yt_tools.NormalizeMapper):
    def __init__(self, remainder, count):
        super(TokensMapper, self).__init__(remove_stop_words=False)
        self.remainder = remainder
        self.count = count

    def __call__(self, row):
        out = {}
        out['bid'] = row['bid']
        out['md5_bid'] = hashlib.md5(str(row['bid'])).hexdigest()
        if int(out['md5_bid'][-8:], 16) % self.count == self.remainder:
            tokens = set()
            if row.get('body') is not None:
                words = self.norm_phr(row['body'], sort=False).split()
                for i in range(len(words) - 1):
                    tokens.add('__mw_{}_{}'.format(words[i], words[i + 1]))
                tokens.update(words)
            if row.get('title') is not None:
                words = self.norm_phr(row['title'], sort=False).split()
                for i in range(len(words) - 1):
                    tokens.add('__mw_{}_{}'.format(words[i], words[i + 1]))
                tokens.update(words)
            if row.get('cid') is not None:
                tokens.add('campaign_id_{}'.format(row['cid']))
            if row.get('uid') is not None:
                tokens.add('client_id_{}'.format(row['uid']))
                tokens.add('uid_{}'.format(row['uid']))
            if row.get('domain') is not None:
                domain = row['domain']
                if domain.count('.') >= 2:
                    domain = '.'.join(domain.split('.')[-2:])
                tokens.add('domain_{}'.format(domain.replace('.', '_')))
            if row.get('lang') is not None:
                tokens.add('lang_{}'.format(row['lang']))
            if row.get('minicategs_ids'):
                tokens.update(['categ_' + f for f in row['minicategs_ids'].split(',')])
            else:
                tokens.add('categ_uncategorized')
            if row.get('Flags'):
                tokens.update(['flag_' + f for f in row['Flags'].split(',')])
            if row.get('is_active'):
                tokens.add('active_flag')
            out['tokens'] = ' '.join(sorted(tokens))
            yield out


@yt.with_context
def map_index_pairs(row, context):
    index = context.row_index
    for token in row['tokens'].split():
        yield {'token': token, 'index': index + 1}


MAX_ROW_INDEXES = 10 ** 6


def reduce_index(key, rows):
    # разделяем, чтобы не было слишком длинных строк в YT
    # потом соединяем при подготовке бинарных данных для бендера
    indexes = []
    num = 0
    for k, row in enumerate(rows):
        if k > 0 and k % MAX_ROW_INDEXES == 0:
            yield {'token': key['token'], 'num': num, 'indexes': ' '.join(indexes)}
            num += 1
            indexes = []
        indexes.append(str(row['index']))
    if len(indexes) > 0:
        yield {'token': key['token'], 'num': num, 'indexes': ' '.join(indexes)}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--banners_extended_table', required=True)
    parser.add_argument('--bender_tokens', required=True)
    parser.add_argument('--bender_index', required=True)
    parser.add_argument('--count_shards', required=True)
    parser.add_argument('--shard_index', required=True)
    args = parser.parse_args()
    banners_extended_table = args.banners_extended_table
    bender_tokens = args.bender_tokens
    bender_index = args.bender_index
    count_shards = int(args.count_shards)
    shard_index = int(args.shard_index)

    yt_pool = os.getenv('YT_POOL') or 'catalogia'
    yt_config = yt.default_config.get_config_from_env()
    yt_config['remote_temp_tables_directory'] = '//home/catalogia/tmp'
    yt_config['spec_defaults'] = {'pool': yt_pool}
    yt_config['create_table_attributes'] = {'compression_codec': 'brotli_5'}
    yt_config['mount_sandbox_in_tmpfs'] = True  # словари поднимаются в памяти, чтобы не долбиться в диск
    yt_client = yt.YtClient(config=yt_config)

    with yt_client.TempTable() as tmp_index_pairs:

        # 1) соединяем токены в одну строку, добавляем md5(bid) для воспроизводимой перетасовки
        tokens_mapper = TokensMapper(shard_index - 1, count_shards)
        yt_client.run_map(
            tokens_mapper,
            banners_extended_table,
            bender_tokens,
            yt_files=tokens_mapper.yt_files)

        # 2) сортируем по md5(bid) - имитация случайной перетасовки
        yt_client.run_sort(bender_tokens, sort_by=['md5_bid'])

        # 3) мапим в пары токен - номер (порядковы номер баннера в перетасованной таблице)
        yt_client.run_map(
            map_index_pairs,
            bender_tokens,
            tmp_index_pairs,
            job_io={'control_attributes': {'enable_row_index': True}})

        # 4) сортировка для reduce
        yt_client.run_sort(tmp_index_pairs, sort_by=['token', 'index'])

        # 5) редьюсим по токенам
        yt_client.run_reduce(
            reduce_index,
            tmp_index_pairs,
            bender_index,
            reduce_by=['token'],
            sort_by=['token', 'index'],
            memory_limit=16 * 1024 ** 3)

        # 6) сортируем индекс по токенам
        yt_client.run_sort(bender_index, sort_by=['token', 'num'])


if __name__ == '__main__':
    main()
