# coding: utf-8

import logging
import random
import time

import yt.wrapper as yt

import irt.bannerland.options
from bm.yt_tools import get_attribute

import irt.common.yt
from irt.monitoring.solomon import sensors as solomon_sensors

from .full_state import FSWorker


@yt.with_context
class DeltasReducer:
    """Reducer that compares the old and the new state of one key and emits delta rows.

    Rows read from input table 0 are taken as the old state; rows read from any
    other input table are taken as the new state.  Emitted rows carry an
    ``@table_index`` field: 0 for created/changed rows (annotated with
    ``DeltaSource``) and 1 for deletion markers (key columns plus ``Deleted``).
    """

    def __init__(self, unsignif_fields, resend_by_begin_time_ratio=0, force_resend_percent=0, force_resend_percent_step=10, force_resend_done=1):
        """Store the comparison and resend settings.

        :param unsignif_fields: iterable of field names ignored when the old and the new row are compared.
        :param resend_by_begin_time_ratio: threshold compared against ``random.uniform(0, 1)``; a row whose
            ``BannerlandBeginTime`` grew by more than ``48 * 3600`` is marked as changed when the random
            value is below it.
        :param force_resend_percent: exclusive upper bound of the ``BannerID % 100`` bucket range that is
            force-resent when a row has no other differences.
        :param force_resend_percent_step: width of that bucket range (lower bound is ``percent - step``).
        :param force_resend_done: truthy value disables the forced resend.
        """
        self.unsignif_fields = set(unsignif_fields)
        self.resend_by_begin_time_ratio = resend_by_begin_time_ratio
        self.force_resend_percent = force_resend_percent
        self.force_resend_percent_step = force_resend_percent_step
        self.force_resend_done = force_resend_done

    def __call__(self, key, rows, context):
        """Yield at most one delta row for ``key``.

        :param key: mapping with the reduce-key columns; copied into the deletion marker.
        :param rows: iterable of rows sharing ``key``.
        :param context: operation context; ``context.table_index`` is read for every row.
        """
        prev_state = None
        next_state = None
        for row in rows:
            # The last row seen from each side wins.
            if context.table_index == 0:
                prev_state = row
            else:
                next_state = row

        has_prev = prev_state is not None
        has_next = next_state is not None
        prev_in_grut = has_prev and prev_state.get('InGrut')
        next_in_grut = has_next and next_state.get('InGrut')

        # Nothing is sent when the new version is in GRUT and the old one is
        # either absent (new banner already in GRUT) or in GRUT as well.
        if next_in_grut and (not has_prev or prev_in_grut):
            return

        # New version moved into GRUT while the old one was not there: delete the banner.
        moved_into_grut = bool(has_prev and next_in_grut and not prev_in_grut)
        # Old version was in GRUT while the new one is not: the banner has to be (re)created.
        returned_from_grut = bool(has_next and prev_in_grut and not next_in_grut)

        if not has_next or moved_into_grut:
            tombstone = dict(key)
            tombstone.update({'Deleted': True, '@table_index': 1})
            yield tombstone
            return

        next_state['@table_index'] = 0

        if not has_prev:
            next_state['DeltaSource'] = {'type': 'new'}
            yield next_state
            return

        # Only fields present in the new row are compared.
        compared = set(next_state) - self.unsignif_fields - {'@table_index'}
        changed = []
        for field in compared:
            if prev_state.get(field) != next_state.get(field):
                changed.append(field)

        prev_begin = prev_state.get('BannerlandBeginTime', 0)
        next_begin = next_state.get('BannerlandBeginTime', 0)
        # 48 * 3600 is presumably 48 hours expressed in seconds - confirm the unit of BannerlandBeginTime.
        begin_time_jumped = prev_begin and next_begin and next_begin - prev_begin > 48 * 3600
        if begin_time_jumped and random.uniform(0, 1) < self.resend_by_begin_time_ratio:
            changed.append('BannerlandBeginTime')

        if returned_from_grut:
            changed.append('InGrut')

        if not changed and not self.force_resend_done:
            # Forced resend applies only to otherwise unchanged rows whose
            # BannerID bucket lies in [percent - step, percent).
            upper = self.force_resend_percent
            lower = upper - self.force_resend_percent_step
            if lower <= next_state['BannerID'] % 100 < upper:
                changed.append('__FORCE_SEND__')

        if changed:
            next_state['DeltaSource'] = {'type': 'diff', 'fields': changed}
            yield next_state


class DeltasFSWorker(FSWorker):
    """Full-state worker that builds delta tables between two full states and exports them."""

    def validate_deltas_tables(self, fs_dir, tables):
        """Fail when the produced delta tables hold more rows than allowed.

        :param fs_dir: full-state directory; validation is skipped when ``get_attribute``
            returns a truthy value for ``not_validate_deltas`` on it.
        :param tables: mapping ``table_type -> {delta_type: table path}``.
        :raises RuntimeError: if the summed row count exceeds the ``deltas_validate_up_limit`` option.
        """
        client = self.yt_client

        skip_validation = get_attribute(fs_dir, 'not_validate_deltas', client, None)
        if skip_validation:
            return

        total_rows = 0
        for tables_by_delta_type in tables.values():
            for table_path in tables_by_delta_type.values():
                table_rows = client.row_count(table_path)
                logging.info('found %d rows in table %s', table_rows, table_path)
                total_rows += table_rows

        max_rows = irt.bannerland.options.get_option('deltas_validate_up_limit')
        if total_rows > max_rows:
            message = 'validate_deltas_tables failed, found {} rows, limit {} rows'.format(total_rows, max_rows)
            raise RuntimeError(message)

    def do_work(self, fs_dir):
        """Build, validate and export the delta tables for ``fs_dir``.

        Steps, in order:

        1. reduce ``<fs_dir>/prev_fs_dir/final/<table_type>`` against ``<fs_dir>/final/<table_type>``
           with ``DeltasReducer`` into ``<fs_dir>/delta_<table_type>_diff`` and ``..._del``;
        2. validate the total number of delta rows;
        3. inside one transaction move every delta table into the versioned export directory,
           leave a link at the old path, set ``DeltaTimestamp`` / ``expiration_time`` and push
           the row count to Solomon;
        4. advance the forced-resend attributes stored on the cypress root.

        :param fs_dir: path of the full-state directory being processed.
        """
        version = 'v1'
        table_types = ['banners']
        delta_types = ['diff', 'del']
        client = self.yt_client

        current_final = yt.ypath_join(fs_dir, 'final')
        previous_final = yt.ypath_join(fs_dir, 'prev_fs_dir', 'final')

        yt_root = irt.bannerland.options.get_cypress_config(self.task_type).root

        columns_option = 'dyn_result_columns' if self.task_type == 'dyn' else 'perf_result_columns'
        all_columns = irt.bannerland.options.get_option(columns_option)

        unsignif = [column['name'] for column in all_columns if column.get('dont_check_delta')]
        # renamed {avatars} field, remove after CS switches to caesar
        unsignif.extend(['Avatars', 'Images'])

        delta_tables = {}
        for table_type in table_types:
            resend_by_begin_time_ratio = irt.common.yt.get_attribute(
                yt_root, '{}__resend_by_begin_time_ratio'.format(table_type), client, default=0)
            force_resend_percent = irt.common.yt.get_attribute(yt_root, 'force_resend_percent', client, default=0)
            force_resend_percent_step = irt.common.yt.get_attribute(yt_root, 'force_resend_percent_step', client, default=10)
            force_resend_done = irt.common.yt.get_attribute(yt_root, 'force_resend_done', client, default=1)

            tables = {}
            for delta_type in delta_types:
                tables[delta_type] = yt.ypath_join(fs_dir, 'delta_{}_{}'.format(table_type, delta_type))

            previous_table = yt.ypath_join(previous_final, table_type)
            current_table = yt.ypath_join(current_final, table_type)

            # The diff table consists of new rows, so its schema is derived from the current table.
            src_schema = client.get_attribute(current_table, 'schema')
            key_columns = [column for column in src_schema if column.get('sort_order')]
            del_schema = key_columns + [{'name': 'Deleted', 'type': 'boolean'}]

            # Same columns as the source, but without sort_order, plus the DeltaSource annotation.
            diff_schema = []
            for column in src_schema:
                unsorted_column = column.copy()
                unsorted_column.pop('sort_order', None)
                diff_schema.append(unsorted_column)
            diff_schema.append({'name': 'DeltaSource', 'type': 'any'})

            reducer = DeltasReducer(
                unsignif_fields=unsignif,
                resend_by_begin_time_ratio=resend_by_begin_time_ratio,
                force_resend_percent=force_resend_percent,
                force_resend_percent_step=force_resend_percent_step,
                force_resend_done=force_resend_done,
            )
            # Output order matters: the reducer tags diff rows with @table_index 0 and del rows with 1.
            output_tables = [
                yt.TablePath(tables['diff'], attributes={'schema': diff_schema}),
                yt.TablePath(tables['del'], attributes={'schema': del_schema}),
            ]
            client.run_reduce(
                reducer,
                [previous_table, current_table],
                output_tables,
                reduce_by=[column['name'] for column in key_columns],
                output_format=yt.YsonFormat(control_attributes_mode='row_fields'),
                spec={"job_count": 1000, "auto_merge": {"mode": "relaxed"}},
            )
            delta_tables[table_type] = tables

        self.validate_deltas_tables(fs_dir, delta_tables)

        cypress_conf = self.get_cypress_config()
        export_dir = yt.ypath_join(cypress_conf.get_path('deltas_export'), version)
        utc_seconds_now = int(time.time())
        # use 14 days as expiration timedelta (in milliseconds)
        expiration_ms = 1000 * (utc_seconds_now + 14 * 24 * 3600)

        with client.Transaction():
            for table_type in table_types:
                for idx, delta_type in enumerate(delta_types):
                    local_table = delta_tables[table_type][delta_type]
                    # Each delta type gets its own timestamp: now + position in delta_types.
                    timestamp = utc_seconds_now + idx
                    dst_table = yt.ypath_join(export_dir, table_type, '{}.{}'.format(timestamp, delta_type))

                    client.move(local_table, dst_table)
                    # Keep the old path usable by linking it to the exported table.
                    client.link(dst_table, local_table)
                    # caesar reader will sort by this attr
                    client.set_attribute(dst_table, 'DeltaTimestamp', timestamp)
                    client.set_attribute(dst_table, 'expiration_time', expiration_ms)

                    solomon_client = solomon_sensors.SolomonAgentSensorsClient()
                    rows_count = client.row_count(local_table)
                    sensor_labels = {
                        'task_type': self.task_type,
                        'table_type': table_type,
                        'delta_type': delta_type,
                    }
                    solomon_client.push_single_sensor(
                        cluster="yt_{}".format(self.yt_cluster),
                        service="bannerland_yt",
                        sensor="deltas_rows_count",
                        labels=sensor_labels,
                        value=rows_count,
                    )

        # Forced resend bookkeeping: mark it done once the percent reaches 100,
        # otherwise move the bucket window forward by one step while it is active.
        if force_resend_percent >= 100:
            irt.common.yt.set_attribute(yt_root, 'force_resend_done', 1, client)
        elif not force_resend_done:
            next_percent = force_resend_percent + force_resend_percent_step
            irt.common.yt.set_attribute(yt_root, 'force_resend_percent', next_percent, client)
