# -*- coding: utf-8 -*-

from __future__ import division
import json
import typing

import yt.wrapper as yt

from bannerland.yql.tools import list_yql_result, do_yql, get_client as get_yql_client
import irt.common.yt
import irt.common.counterdict
import irt.logging


# Module logger, registered under the Bannerland project.
logger = irt.logging.getLogger(irt.logging.BANNERLAND_PROJECT, __name__)


# Per-task-type thresholds used by validate(); task types not listed here skip validation.
#   delete_tasks_less         - max smoothed share of tasks whose banners_count == 0
#   phrases_per_offer_more    - min smoothed ratio of banner-table rows to distinct offers
#   sd_phrases_per_offer_more - per SimDistance value: min smoothed ratio of that
#                               SimDistance's banner-table rows to distinct offers
# validate() also understands 'banners_per_offer_more' and 'is_full_retarg', which are
# not configured for any task type here.
# The trailing "avg" / "last bad pocket" notes are reference values (presumably observed
# on past pockets) kept for tuning the thresholds.
TASK_TYPE_CONFIG = {
    'perf': {
        'delete_tasks_less': 0.36,      # avg 0.037
        'phrases_per_offer_more': 1.6,  # avg 9.8
        'sd_phrases_per_offer_more': {
            804002: 1.1,   # native       avg 5.5
            804050: 0.3,   # dse          avg 1.4
        }
    },
    'dyn': {
        'delete_tasks_less': 0.36,       # avg 0.072
        'phrases_per_offer_more': 0.42,  # avg 2.1
        'sd_phrases_per_offer_more': {
            700001: 0.03,  # native       last bad pocket: 0.087
            700050: 0.11,  # dse          avg 0.56
        }
    },
}

# compare_banners_table_counters() skips a check when BOTH the old and the new
# count are below this value.
MIN_BANNERS_TO_COMPARE = 1000000


# Counter names produced by BannerStatAggregateMapper / get_banners_table_counters().
SIMDISTANCE_COUNTER_NAME = 'simdistance'
TEMPLATE_INFO_COUNTER_NAME = 'template_info'
ORDERID_COUNTER_NAME = 'orderid'
# Thresholds for compare_banners_table_counters(). Every value is the maximum allowed
# absolute smoothed relative change, (new + 1) / (old + 1) - 1, between old and new table.
#   sub_ratios      - per counter, per key; 'default' applies to keys that are not listed
#   uniq_ratios     - per counter, change in the number of distinct keys ('total_count')
#   row_count_ratio - change in the total row count of the table
VALIDATE_BANNERS_TABLE_CONFIG = {
    'sub_ratios': {
        # keys are str(SimDistance) of each BroadPhrases entry
        SIMDISTANCE_COUNTER_NAME: {
            'default': 0.2,
            '800333': 1,
        },
        # keys are "<offer_source> TAB <title_source> TAB <title_template_type>"
        # taken from BLBannerDetails (an empty field stays an empty string)
        TEMPLATE_INFO_COUNTER_NAME: {
            'feed\t\tno_title__tmpl': 1.0,
            'feed\tdse\tbase': 0.5,
            'feed\tsimple\t': 2,
            'dse\tsimple\t': 2,
            'crawler\tsimple\t': 2,
            'specurl\tsimple\t': 2,
            'site\tsimple\t': 2,
            'links\tsimple\t': 2,
            'site\tnative\tbase': 0.5,
            'default': 0.25,
        },
    },
    'uniq_ratios': {
        ORDERID_COUNTER_NAME: 0.2,
    },
    'row_count_ratio': 0.2,
}


class BannerStatAggregateMapper(object):
    """YT mapper that aggregates banner statistics into one CounterDict per job.

    For every input row it updates three counters:

    * ``SIMDISTANCE_COUNTER_NAME``: number of ``BroadPhrases`` entries per
      ``str(SimDistance)`` (missing SimDistance counts as ``'0'``);
    * ``TEMPLATE_INFO_COUNTER_NAME``: number of rows per tab-joined
      ``offer_source``/``title_source``/``title_template_type`` from ``BLBannerDetails``;
    * ``ORDERID_COUNTER_NAME``: marks every seen ``OrderID`` with ``1``, so only the set
      of keys is meaningful (its size is taken later).

    ``finish`` yields a single row ``{'counters': <CounterDict>}``.
    """

    def start(self):
        """Create the empty accumulator (invoked by the YT wrapper before rows are fed)."""
        self.counters = irt.common.counterdict.CounterDict()

    def __call__(self, row):
        # A column can be present with a null value, so fall back on any falsy value
        # rather than only on a missing key (``.get(key, default)`` would return None).
        for phrase in row.get('BroadPhrases') or []:
            self.counters[SIMDISTANCE_COUNTER_NAME][str(phrase.get('SimDistance', 0))] += 1

        details = row.get('BLBannerDetails') or {}
        template_info = '\t'.join(
            details.get(field) or ''
            for field in ('offer_source', 'title_source', 'title_template_type')
        )
        self.counters[TEMPLATE_INFO_COUNTER_NAME][template_info] += 1

        # Only key presence matters here: the number of distinct OrderIDs is taken via len() later.
        self.counters[ORDERID_COUNTER_NAME][str(row['OrderID'])] = 1

    def finish(self):
        """Emit the accumulated counters as one output row."""
        yield {
            'counters': self.counters
        }


def get_banners_table_counters(table, yt_client=None):
    # type: (typing.Union[str, yt.TablePath], typing.Optional[yt.YtClient]) -> irt.common.counterdict.CounterDict
    """
    Calculate statistics from banner table, caching them on the table itself.

    If the table already has a ``banners_table_counters`` attribute, it is returned
    (wrapped in a CounterDict) without running any operation. Otherwise the counters are
    computed by a map with BannerStatAggregateMapper, merged, and stored in that attribute.
    For every counter listed in ``VALIDATE_BANNERS_TABLE_CONFIG['uniq_ratios']`` only the
    number of distinct keys is kept, as ``{'total_count': N}``.

    :param table: Node path
    :param yt_client: YtClient for communicating with server. None for default client
    :return: CounterDict with statistics
    """
    yt_client = yt_client or yt
    cached_counters = irt.common.yt.get_attribute(table, 'banners_table_counters', yt_client, default=None)
    if cached_counters is not None:
        return irt.common.counterdict.CounterDict(cached_counters)

    # Read only the columns BannerStatAggregateMapper actually uses.
    source = yt.TablePath(table, columns=['BroadPhrases', 'BLBannerDetails', 'OrderID'])

    counters = irt.common.counterdict.CounterDict()
    with yt_client.TempTable() as tmp_result:
        yt_client.run_map(
            BannerStatAggregateMapper(),
            source,
            tmp_result,
            spec={
                'auto_merge': {
                    'mode': 'relaxed',
                },
            },
        )
        # Each mapper job emits one partial CounterDict; sum them up.
        for row in yt_client.read_table(tmp_result):
            counters += row['counters']

    # Replace "set of keys" counters by their size to keep the cached attribute small.
    for counter in VALIDATE_BANNERS_TABLE_CONFIG['uniq_ratios']:
        counters[counter] = {'total_count': len(counters[counter])}

    irt.common.yt.set_attribute(table, 'banners_table_counters', counters, yt_client)
    return counters


def _smoothed_ratio(new_value, old_value):
    """Return the add-one smoothed relative change ``(new + 1) / (old + 1) - 1``.

    ``0`` means no change; the +1 on both sides avoids division by zero for empty counts.
    """
    return (new_value + 1) / (old_value + 1) - 1


def compare_banners_table_counters(new_table, old_table, yt_client=None):
    # type: (typing.Union[str, yt.TablePath], typing.Union[str, yt.TablePath], typing.Optional[yt.YtClient]) -> bool
    """
    Calculates and compares two tables' statistics, according to config

    Every check compares the absolute smoothed relative change between the old and the
    new value against the matching threshold in VALIDATE_BANNERS_TABLE_CONFIG. All checks
    are run (failures are logged, not raised) so that the log shows every problem at once.

    :param new_table: Node path
    :param old_table: Node path
    :param yt_client: YtClient for communicating with server. None for default client
    :return: True if all checks passed, False otherwise
    """
    yt_client = yt_client or yt
    old_counters = get_banners_table_counters(old_table, yt_client)
    new_counters = get_banners_table_counters(new_table, yt_client)

    logger.info('new table counters: {}'.format(json.dumps(new_counters)))
    logger.info('old table counters: {}'.format(json.dumps(old_counters)))

    success = True

    logger.info('validating sub_ratios')
    for counter_name, counter_config in VALIDATE_BANNERS_TABLE_CONFIG['sub_ratios'].items():
        logger.info('validating %s', counter_name)
        old_sub = old_counters[counter_name]
        new_sub = new_counters[counter_name]
        # Set union works for both Python 2 key lists and Python 3 key views.
        for sub_counter in set(old_sub) | set(new_sub):
            # A key may exist in only one table; treat it as zero on the other side.
            old_value = old_sub.get(sub_counter, 0)
            new_value = new_sub.get(sub_counter, 0)
            if new_value < MIN_BANNERS_TO_COMPARE and old_value < MIN_BANNERS_TO_COMPARE:
                logger.warning('%s %s validation skipped low banners count: %d => %d', counter_name, sub_counter, old_value, new_value)
                continue
            ratio = _smoothed_ratio(new_value, old_value)
            if abs(ratio) > counter_config.get(sub_counter, counter_config['default']):
                success = False
                logger.error('%s %s validation failed: %f', counter_name, str(sub_counter), ratio)
            else:
                logger.info('%s %s ok: %f', counter_name, str(sub_counter), ratio)

    logger.info('validating uniq_ratios')
    for counter_name, counter_value in VALIDATE_BANNERS_TABLE_CONFIG['uniq_ratios'].items():
        logger.info('validating %s', counter_name)
        old_total = old_counters[counter_name]['total_count']
        new_total = new_counters[counter_name]['total_count']
        # NOTE(review): total_count is the number of distinct keys (e.g. OrderIDs), yet it is
        # compared against a banner-count threshold; confirm this skip is intended.
        if new_total < MIN_BANNERS_TO_COMPARE and old_total < MIN_BANNERS_TO_COMPARE:
            logger.warning('%s total_count validation skipped due to small banners count: %d => %d', counter_name, old_total, new_total)
            continue
        ratio = _smoothed_ratio(new_total, old_total)
        if abs(ratio) > counter_value:
            success = False
            logger.error('%s validation failed: %f', counter_name, ratio)
        else:
            logger.info('%s ok: %f', counter_name, ratio)

    logger.info('validating row_count')
    ratio = _smoothed_ratio(yt_client.row_count(new_table), yt_client.row_count(old_table))
    if abs(ratio) > VALIDATE_BANNERS_TABLE_CONFIG['row_count_ratio']:
        success = False
        logger.error('row_count validation failed: %f', ratio)
    else:
        logger.info('row_count ok: %f', ratio)

    return success


def validate_pocket(task_type, pocket, yt_client=yt):
    """Run validate() on the standard tables of a pocket directory.

    The pocket is expected to contain ``tasks_and_offers``, ``tasks.final`` and
    ``generated_banners.final`` tables; their paths are built by appending
    the table name to ``pocket``.
    """
    suffixes = ('/tasks_and_offers', '/tasks.final', '/generated_banners.final')
    tao_path, tasks_path, banners_path = [pocket + suffix for suffix in suffixes]
    validate(task_type, tao_path, tasks_path, banners_path, yt_client=yt_client)


def validate(task_type, tao_table, tasks_table, banners_table, yt_client=yt):
    """
    Sanity-check a generated pocket against the thresholds in TASK_TYPE_CONFIG.

    Only the checks whose keys are present in ``TASK_TYPE_CONFIG[task_type]`` are run;
    task types without a config are accepted without any checks.

    :param task_type: key into TASK_TYPE_CONFIG (e.g. 'perf', 'dyn')
    :param tao_table: tasks-and-offers table; distinct (product_md5, task_id) pairs are
        counted as offers
    :param tasks_table: tasks table with a ``banners_count`` column
    :param banners_table: generated banners table; every row counts as one phrase and
        rows are grouped by ``SimDistance`` / filtered by ``Text``
    :param yt_client: YtClient (or the ``yt`` module) used for table access; its
        ``spec_defaults.pool`` is also used as the YQL pool
    :raises Exception: when any configured check fails
    """
    if task_type not in TASK_TYPE_CONFIG:
        # temporary fix - should adjust thresholds
        return

    thresholds = TASK_TYPE_CONFIG[task_type]
    yql_client = get_yql_client()
    yt_pool = yt_client.config['spec_defaults'].get('pool')

    tasks_count = yt_client.row_count(tasks_table)
    yql_query = """
                select count(distinct(`product_md5`, `task_id`)) as offers_count from `{input_table}`
            """.format(input_table=tao_table)
    yql_result = list_yql_result(do_yql(yql_client, yql_query, yt_pool=yt_pool))
    offers_count = yql_result[0]['offers_count']
    phrases_count = yt_client.row_count(banners_table)

    # The ratios below add constants to numerator and denominator: this avoids division
    # by zero and damps the effect of very small tables.
    if 'delete_tasks_less' in thresholds:
        # A task with no banners counts as deleted.
        del_tasks = sum(1 for row in yt_client.read_table(tasks_table) if row['banners_count'] == 0)
        del_tasks_ratio = (del_tasks + 2) / (tasks_count + 100)

        logger.warning('deleted tasks: %s of %s, smoothed ratio: %s', del_tasks, tasks_count, del_tasks_ratio)
        if del_tasks_ratio > thresholds['delete_tasks_less']:
            raise Exception('Too many deleted tasks!')

    if 'banners_per_offer_more' in thresholds:
        banners_count = sum(row['banners_count'] for row in yt_client.read_table(tasks_table))
        banners_per_offer = (banners_count + 10) / (offers_count + 100)
        logger.warning('banners_per_offer: %s banners per %s offers, smoothed ratio: %s', banners_count, offers_count, banners_per_offer)
        if banners_per_offer < thresholds['banners_per_offer_more']:
            raise Exception('Too few banners per offer')

    if 'phrases_per_offer_more' in thresholds:
        phrases_per_offer = (phrases_count + 5000) / (offers_count + 1000)
        logger.warning('phrases: %s per %s offers, smoothed ratio: %s', phrases_count, offers_count, phrases_per_offer)
        if phrases_per_offer < thresholds['phrases_per_offer_more']:
            raise Exception('Too few phrases per offer')

    if 'sd_phrases_per_offer_more' in thresholds:
        yql_query = """
            select SimDistance, count(*) as SimDistance_count from `{input_table}`
            group by SimDistance
        """.format(input_table=banners_table)
        yql_result = list_yql_result(do_yql(yql_client, yql_query, yt_pool=yt_pool))
        sd_count = {row['SimDistance']: row['SimDistance_count'] for row in yql_result}

        for sd, min_ratio in thresholds['sd_phrases_per_offer_more'].items():
            # A SimDistance with no phrases is absent from the GROUP BY result; count it as
            # zero so the smoothed ratio decides, instead of failing with KeyError.
            sd_phrases = sd_count.get(sd, 0)
            sd_per_offer = (sd_phrases + 3000) / (offers_count + 1000)
            logger.warning('%d phrases: %s per %s offer, smoothed ratio: %s', sd, sd_phrases, offers_count, sd_per_offer)
            if sd_per_offer < min_ratio:
                # Format explicitly: Exception does not %-interpolate its arguments.
                raise Exception('Too few phrases per offer from SD %d' % sd)

    if 'is_full_retarg' in thresholds:
        # Only the existence of a non-retargeting phrase matters, so fetch at most one row.
        yql_query = """
            select *
            from `{input_table}`
            WHERE NOT (Text LIKE 'offerid%')
            LIMIT 1;
        """.format(input_table=banners_table)
        yql_result = list_yql_result(do_yql(yql_client, yql_query, yt_pool=yt_pool))
        if yql_result:
            raise Exception('not all phrases retarg')
        logger.warning('is_full_retarg: result OK')
