#!/usr/bin/python
# -*- coding: utf-8 -*-

# Обновление банщика в динамиках на основе данных из dyn_bs_log.

import re
import datetime

import yt.wrapper as yt
import yt.wrapper.schema as yt_schema
import yandex.type_info.typing as ti

import sandbox.projects.irt.common
import irt.common.cnormalizer as cnormalizer
import irt.common.yt
import irt.logging

# Module-level logger registered under the BANNERLAND project.
logger = irt.logging.getLogger(irt.logging.BANNERLAND_PROJECT, __name__)

# YT directory listed by run(); children whose names are exactly eight digits
# (compared against a '%Y%m%d' string) are used as input tables.
DYN_BS_LOG_DIR = '//home/bannerland/logs/dyn_bs_log'
# Columns requested from every dyn_bs_log table that run() feeds to the mapper.
DYN_BS_LOG_COLUMNS = ['Phrase', 'TypeID', 'OrderID', 'ContextType', 'CounterType', 'Timestamp', 'PageOptions']
# YT table joined by `norm` in run(); only its `norm` and `freq` columns are read.
HITS_TABLE = '//home/broadmatching/cdict_generation/counts_full'


def get_bad_dyn_phrases_table(release_type):
    """Return the YT path of the per-phrase ``bad_dyn_phrases`` table.

    :param release_type: path component inserted after ``//home/bannerland/``.
    :return: the table path as a string.
    """
    table_path = f'//home/bannerland/{release_type}/data/trashfilter/bad_dyn_phrases'
    return table_path


def get_grouped_bad_dyn_phrases_table(release_type):
    """Return the YT path of the ``grouped_bad_dyn_phrases`` table.

    :param release_type: path component inserted after ``//home/bannerland/``.
    :return: the table path as a string.
    """
    table_path = f'//home/bannerland/{release_type}/data/trashfilter/grouped_bad_dyn_phrases'
    return table_path


class NormalizeBase:
    """Mixin that gives a job access to the phrase normalizer.

    The dictionaries are first fetched into the current working directory
    (``download_files``), then loaded by ``start`` into ``self.normalizer``;
    ``norm_phr`` is a thin wrapper over ``self.normalizer.normalize``.
    """

    def __init__(self):
        # Local file names of the dictionaries; same names that
        # ``download_files`` writes and ``start`` reads.
        self.files = ['norm_dict', 'stop_dict']

    def download_files(self, yt_client):
        """Copy both normalization dictionaries from YT into local files.

        :param yt_client: object providing ``read_file(path)`` that returns an
            iterable of binary chunks.
        """
        remote_dicts = {
            'norm_dict': '//home/bannerland/data/normalization_dicts/norm_dict',
            'stop_dict': '//home/bannerland/data/normalization_dicts/stop_dict',
        }

        for local_name, yt_path in remote_dicts.items():
            logger.info('Download file \'%s\' into \'%s\'', yt_path, local_name)
            with open(local_name, 'wb') as out:
                out.writelines(yt_client.read_file(yt_path))

    def start(self):
        """Create the normalizer from ``norm_dict`` and load ``stop_dict`` for 'ru'."""
        normalizer = cnormalizer.Normalizer('norm_dict')
        normalizer.load_stop_words('stop_dict', 'ru')
        self.normalizer = normalizer

    def norm_phr(self, text, sort=True, uniq=False):
        """Normalize ``text`` with language 'ru'.

        :param text: phrase to normalize.
        :param sort: forwarded as the last positional argument of ``normalize``.
        :param uniq: forwarded as the third positional argument of ``normalize``.
        :return: whatever ``self.normalizer.normalize`` returns.
        """
        language = 'ru'
        return self.normalizer.normalize(text, language, uniq, sort)


class NormalizeReducer(NormalizeBase):
    """Base class for reducers that need the normalizer; adds nothing to NormalizeBase."""


def set_upload_time(table, yt_client):
    """Set the ``upload_time`` attribute of ``table`` to the current UTC time.

    The value is the ISO 8601 string of a naive ``datetime.utcnow()``.

    :param table: table whose attribute is set.
    :param yt_client: client forwarded to ``irt.common.yt.set_attribute``.
    """
    uploaded_at = datetime.datetime.utcnow().isoformat()
    irt.common.yt.set_attribute(table, 'upload_time', uploaded_at, yt_client=yt_client)


class CTRStatMapper:
    """Mapper that turns raw dyn_bs_log rows into show/click events.

    A row is kept only when ``ContextType`` equals 7, ``Phrase`` is non-empty,
    ``PageOptions`` contains the substring ``main-serp`` and ``TypeID`` is one
    of the keys of ``ban_config['type_ids']``. Every kept row produces exactly
    one output row.
    """

    def __init__(self, ban_config):
        """
        :param ban_config: dict with a ``'type_ids'`` mapping; only its keys
            are used here, to filter rows by ``TypeID``.
        """
        self._ban_config = ban_config

    def __call__(self, row):
        """Yield one event row for ``row`` or nothing if it is filtered out.

        :param row: mapping with the dyn_bs_log columns.
        :return: generator of dicts with ``OrderID``, ``Phrase`` (stripped),
            ``TypeID``, ``Click`` (1 when ``CounterType`` equals 2, else 0),
            ``Timestamp`` (int, 0 when missing) and ``TimestampDesc``
            (negated timestamp, so an ascending sort yields newest first).
        """
        # A column that is present but null arrives as None, which the .get()
        # default does not cover; coerce it to '' instead of raising TypeError.
        page_options = row.get('PageOptions') or ''
        if (row.get('ContextType', 0) != 7
                or not row.get('Phrase')
                or 'main-serp' not in page_options
                or row.get('TypeID', -1) not in self._ban_config['type_ids']):
            return

        timestamp = int(row.get('Timestamp') or '0')
        yield {
            'OrderID': row['OrderID'],
            'Phrase': row['Phrase'].strip(),
            'TypeID': row['TypeID'],
            'Click': 1 if row.get('CounterType', 0) == 2 else 0,
            'Timestamp': timestamp,
            'TimestampDesc': -timestamp,
        }


class CTRStatReducer(NormalizeReducer):
    """Reducer that computes shows, clicks and CTR per (OrderID, Phrase, TypeID).

    Rows of one group must arrive newest first (non-increasing ``Timestamp``);
    a violation raises. ``self.normalizer`` must be initialised (see
    ``NormalizeBase.start``) before the reducer is called.
    """

    def __init__(self, ban_config):
        """
        :param ban_config: dict with ``'ban_window_begin'``,
            ``'ban_window_end'``, ``'stat_window_size'`` (unix-time seconds)
            and ``'type_ids'`` mapping TypeID to a dict with ``'min_shows'``.
        """
        self._ban_config = ban_config
        super(CTRStatReducer, self).__init__()

    def __call__(self, key, rows):
        """Aggregate one group and yield at most one statistics row.

        :param key: mapping with ``OrderID``, ``Phrase`` and ``TypeID``.
        :param rows: iterable of mapper rows for this key, newest first.
        :return: generator yielding a dict with ``OrderID``, ``Phrase``,
            ``norm``, ``TypeID``, ``Shows``, ``Clicks``, ``CTR`` and
            ``InWindow`` when the group has at least ``min_shows`` shows.
        :raises Exception: if a row is newer than the row before it.
        """
        # TypeID is part of the reduce key, so every row shares key['TypeID'].
        min_shows = self._ban_config['type_ids'][key['TypeID']]['min_shows']
        window_begin = self._ban_config['ban_window_begin']
        window_end = self._ban_config['ban_window_end']
        stat_window_size = self._ban_config['stat_window_size']

        last_timestamp = None  # timestamp of the newest (first) row
        previous_timestamp = None  # timestamp of the row handled just before
        shows = 0
        clicks = 0
        in_window = False
        for row in rows:
            row_timestamp = row['Timestamp']
            if last_timestamp is None:
                last_timestamp = row_timestamp
            elif row_timestamp > previous_timestamp:
                raise Exception('Timestamps are not ordered')
            # Advance on every row so each row is checked against its
            # predecessor, not only against the first row of the group.
            previous_timestamp = row_timestamp

            shows += 1
            clicks += row['Click']

            if window_begin < row_timestamp < window_end:
                in_window = True

            # Stop once enough shows are collected and the rows are both older
            # than the statistics window and older than the ban window start.
            if (shows >= min_shows
                    and last_timestamp - row_timestamp > stat_window_size
                    and row_timestamp < window_begin):
                break

        if shows < min_shows:
            return

        ctr = 1.0 * clicks / shows if shows else 0.0

        key_phrase = key['Phrase']
        # A trailing " ~0" marker is removed before normalization and
        # re-appended to the normalized phrase afterwards.
        stripped_phrase, tilda_zero = re.subn(r'\s+~0$', '', key_phrase)
        norm = self.norm_phr(stripped_phrase)
        if tilda_zero:
            norm += ' ~0'

        yield {
            'OrderID': key['OrderID'],
            'Phrase': key_phrase,
            'norm': norm,
            'TypeID': key['TypeID'],
            'Shows': shows,
            'Clicks': clicks,
            'CTR': ctr,
            'InWindow': in_window,
        }


class CutoffReducer:
    """Reducer that selects the phrases to ban for one (OrderID, TypeID) group.

    The first row of a group must come from input table 0 and carry the
    group's ``Money``; groups that start with any other row are dropped.
    """

    # Columns copied from a statistics row into the output row, in order.
    _OUTPUT_COLUMNS = ('CTR', 'Shows', 'Clicks', 'Phrase', 'OrderID')

    def __init__(self, ban_config):
        """
        :param ban_config: dict whose ``'type_ids'`` maps TypeID to a dict
            with ``'ctr'``, ``'hard_ctr'`` and ``'money_fraction'``.
        """
        self._ban_config = ban_config

    def __call__(self, key, rows):
        """Yield one output row per banned phrase.

        A phrase with CTR below ``hard_ctr`` is always banned. A phrase with
        CTR in [``hard_ctr``, ``ctr``) is banned only if it is in the window
        and its ``Money`` still fits into the remaining budget
        (``Money`` of the first row times ``money_fraction``). Only in-window
        banned phrases consume budget.

        :param key: mapping with ``TypeID``.
        :param rows: iterable of rows carrying ``@table_index``.
        """
        stream = iter(rows)
        money_row = next(stream)
        if money_row['@table_index'] != 0:
            return

        thresholds = self._ban_config['type_ids'][key['TypeID']]
        budget = money_row['Money'] * thresholds['money_fraction']
        spent = 0

        for stat in stream:
            ctr = stat['CTR']
            if ctr >= thresholds['ctr']:
                continue

            hard_ban = ctr < thresholds['hard_ctr']
            in_window = stat['InWindow']

            if hard_ban:
                if in_window:
                    spent += stat['Money']
            elif in_window and spent + stat['Money'] <= budget:
                spent += stat['Money']
            else:
                continue

            yield {column: stat[column] for column in self._OUTPUT_COLUMNS}


class RightJoinReducer:
    """Reducer that attaches ``Money`` to statistics rows.

    If the first row of the group comes from input table 0, its ``freq`` is
    used and the row itself is not emitted; otherwise the frequency is 0 and
    the first row is emitted with ``Money`` set to 0. Every following row is
    emitted with ``Money = CTR * freq``. Rows are modified in place and always
    routed to output table 0.
    """

    def __call__(self, key, rows):
        """
        :param key: reduce key (unused).
        :param rows: iterable of dict rows carrying ``@table_index``.
        :return: generator of the updated statistics rows.
        """
        stream = iter(rows)
        head = next(stream)

        if head['@table_index'] == 0:
            frequency = head['freq']
        else:
            frequency = 0
            head.update({'@table_index': 0, 'Money': 0})
            yield head

        for stat in stream:
            stat['@table_index'] = 0
            stat['Money'] = stat['CTR'] * frequency
            yield stat


class MoneyReducer:
    """Reducer that sums ``Money`` over the in-window rows of a group."""

    def __call__(self, key, rows):
        """
        :param key: mapping with ``OrderID`` and ``TypeID``.
        :param rows: iterable of rows with ``Money`` and ``InWindow``.
        :return: generator yielding a single row with ``OrderID``, ``TypeID``
            and the summed ``Money`` (0 when no row is in the window).
        """
        total = 0
        for row in rows:
            if row['InWindow']:
                total += row['Money']

        result = {'OrderID': key['OrderID'], 'TypeID': key['TypeID']}
        result['Money'] = total
        yield result


def get_yt_client(token, pool, proxy):
    """Create a ``yt.YtClient`` for ``proxy``.

    :param token: value stored under the ``token`` config key.
    :param pool: if truthy, stored as ``spec_defaults.pool``; otherwise
        ``spec_defaults`` is not set at all.
    :param proxy: proxy passed to ``yt.YtClient``.
    :return: the configured client.
    """
    pool_settings = {'spec_defaults': {'pool': pool}} if pool else {}
    client_config = {
        'token': token,
        'mount_sandbox_in_tmpfs': True,
        **pool_settings,
    }
    return yt.YtClient(proxy=proxy, config=client_config)


def _build_ban_config(now, ban_window_days, buffer_days):
    """Build the config dict shared by the mapper and the reducers.

    :param now: naive local ``datetime`` used as the single reference point.
    :param ban_window_days: length of the ban window in days.
    :param buffer_days: number of most recent days excluded from the window.
    :return: dict with unix-time bounds of the ban window, the statistics
        window size in seconds and per-TypeID thresholds.
    """
    ban_window_begin = now - datetime.timedelta(days=ban_window_days + buffer_days)
    ban_window_end = now - datetime.timedelta(days=buffer_days)
    # datetime.timestamp() is the portable replacement for strftime('%s').
    ban_window_begin_ts = int(ban_window_begin.timestamp())
    ban_window_end_ts = int(ban_window_end.timestamp())
    current_ts = int(now.timestamp())

    return {
        'stat_window_size': current_ts - ban_window_begin_ts,
        'ban_window_begin': ban_window_begin_ts,
        'ban_window_end': ban_window_end_ts,
        'type_ids': {
            1: {
                'min_shows': 150,
                'ctr': 0.04,
                'hard_ctr': 0.01,
                'money_fraction': 0.1,
            },
            2: {
                'min_shows': 500,
                'ctr': 0.005,
                'hard_ctr': 0.001,
                'money_fraction': 0.1,
            }
        }
    }


def run(task, yt_client, release_type, stat_days=365, ban_window_days=60):
    """Rebuild the bad-phrase tables and dump the grouped table to a file.

    Steps:
      1. Map-reduce the daily dyn_bs_log tables of the last ``stat_days`` days
         into per-(OrderID, Phrase, TypeID) CTR statistics.
      2. Join the statistics with ``HITS_TABLE`` by ``norm`` to get ``Money``.
      3. Sum ``Money`` per (OrderID, TypeID) and pick the phrases to ban.
      4. Group banned phrases per OrderID, compact both output tables and set
         their ``upload_time`` attribute.
      5. Write the grouped table as schemaful DSV into a file and, when
         ``task`` is given, publish it through the sandbox helpers.

    :param task: sandbox task, or ``None`` to only write the local file
        ``bad_dyn_phrases``.
    :param yt_client: YT client used for every operation.
    :param release_type: path component of the output tables; any value other
        than ``'stable'`` switches the resource sub type to the TEST one.
    :param stat_days: how many days back input tables are taken.
    :param ban_window_days: length of the ban window in days.
    """
    output_bad_dyn_phrases = get_bad_dyn_phrases_table(release_type)
    bad_dyn_phrases_schema = yt_schema.TableSchema() \
        .add_column('OrderID', ti.Int64) \
        .add_column('Phrase', ti.String) \
        .add_column('Shows', ti.Uint64) \
        .add_column('Clicks', ti.Uint64) \
        .add_column('CTR', ti.Double)
    output_bad_dyn_phrases = yt.TablePath(output_bad_dyn_phrases, attributes={'schema': bad_dyn_phrases_schema})

    output_grouped_bad_dyn = get_grouped_bad_dyn_phrases_table(release_type)
    grouped_bad_dyn_schema = yt_schema.TableSchema() \
        .add_column('OrderID', ti.String) \
        .add_column('Phrases', ti.String)
    output_grouped_bad_dyn = yt.TablePath(output_grouped_bad_dyn, attributes={'schema': grouped_bad_dyn_schema})

    buffer_days = 7

    # Take the current time once so that all derived bounds are consistent.
    now = datetime.datetime.today()
    ctr_stat_begin_table = (now - datetime.timedelta(days=stat_days)).strftime('%Y%m%d')
    ban_config = _build_ban_config(now, ban_window_days, buffer_days)

    with yt_client.TempTable() as ctr_stat, \
            yt_client.TempTable() as ctr_stat_hits, \
            yt_client.TempTable() as campaign_money:

        logger.info('1. Update banner\'s table \'bad_dyn_phrases\'')

        # Daily tables are named YYYYMMDD, so string comparison orders them by date.
        ctr_stat_tables = [
            yt.TablePath(yt.ypath_join(DYN_BS_LOG_DIR, name), columns=DYN_BS_LOG_COLUMNS)
            for name in yt_client.list(DYN_BS_LOG_DIR)
            if re.match(r'^\d{8}$', name) and name >= ctr_stat_begin_table
        ]

        reducer = CTRStatReducer(ban_config)
        reducer.download_files(yt_client)

        # TimestampDesc in sort_by makes the reducer see the newest rows first.
        yt_client.run_map_reduce(
            CTRStatMapper(ban_config),
            reducer,
            ctr_stat_tables,
            ctr_stat,
            sort_by=['OrderID', 'Phrase', 'TypeID', 'TimestampDesc'],
            reduce_by=['OrderID', 'Phrase', 'TypeID'],
            spec={'max_failed_job_count': 0},
            reduce_local_files=[reducer.files, ],
        )

        yt_client.run_sort(ctr_stat, sort_by=['norm'])

        # The hits table goes first so that RightJoinReducer reads `freq`
        # before the statistics rows of the same `norm`.
        yt_client.run_reduce(
            RightJoinReducer(),
            [HITS_TABLE+'{norm,freq}', ctr_stat],
            ctr_stat_hits,
            reduce_by=['norm'],
            format=yt.YsonFormat(control_attributes_mode='row_fields'),
        )

        yt_client.run_sort(ctr_stat_hits, sort_by=['OrderID', 'TypeID', 'CTR'])

        yt_client.run_reduce(
            MoneyReducer(),
            ctr_stat_hits,
            yt.TablePath(campaign_money, sorted_by=['OrderID', 'TypeID']),
            reduce_by=['OrderID', 'TypeID'],
        )

        # campaign_money goes first: CutoffReducer expects the money row
        # (table index 0) at the head of every group.
        yt_client.run_reduce(
            CutoffReducer(ban_config),
            [campaign_money, ctr_stat_hits],
            output_bad_dyn_phrases,
            reduce_by=['OrderID', 'TypeID'],
            format=yt.YsonFormat(control_attributes_mode='row_fields'),
        )

        yt_client.run_sort(output_bad_dyn_phrases, sort_by=['OrderID'])

        def aggregate_phrases(key, rows):
            # One row per OrderID: unique phrases, sorted, comma-separated.
            phrases = set(row['Phrase'] for row in rows)
            yield {
                'OrderID': str(key['OrderID']),
                'Phrases': ','.join(sorted(phrases))
            }

        yt_client.run_reduce(
            aggregate_phrases,
            output_bad_dyn_phrases,
            output_grouped_bad_dyn,
            reduce_by=['OrderID'],
        )

        yt_client.run_sort(output_grouped_bad_dyn, sort_by=['OrderID'])
        # In-place sorted merges with combine_chunks compact both output tables.
        yt_client.run_merge(output_bad_dyn_phrases, output_bad_dyn_phrases, mode='sorted', spec={'merge_by': ['OrderID'], 'combine_chunks': True})
        yt_client.run_merge(output_grouped_bad_dyn, output_grouped_bad_dyn, mode='sorted', spec={'merge_by': ['OrderID'], 'combine_chunks': True})

        set_upload_time(output_bad_dyn_phrases, yt_client)
        set_upload_time(output_grouped_bad_dyn, yt_client)

        logger.info('2. Start preparing file \'bad_dyn_phrases\'...')

        file_bad_dyn_phrases = 'rt-research/broadmatching/work/dynstat/bad_dyn_phrases'
        bad_dyn_data = yt_client.read_table(
            output_grouped_bad_dyn,
            format='<columns=[OrderID;Phrases;]>schemaful_dsv',
            raw=True
        ).read()

        sub_type = 'dyn_stat'
        if release_type != 'stable':
            sub_type = f'_{sub_type}_TEST'

        if task is not None:
            resource = sandbox.projects.irt.common.create_irt_data(task, sub_type, 'file_bad_dyn_phrases для банщика в динамиках', [file_bad_dyn_phrases])
            file_bad_dyn_phrases = resource.filenames[0]
        else:
            file_bad_dyn_phrases = 'bad_dyn_phrases'

        with open(file_bad_dyn_phrases, 'wb') as f:
            f.write(bad_dyn_data)

        if task is not None:
            sandbox.projects.irt.common.do_bmgendict_copy(task, resource)
        else:
            # Logged only after the file has actually been written.
            logger.info('Result was written to file: %s', file_bad_dyn_phrases)
