# coding: utf8

from __future__ import division

import json
import six

import yt.wrapper as yt
import yt.yson as yson

import bm.yt_tools
import irt.bannerland.options


def get_experimental_option(key, task_type, yt_client=yt):
    """
    Read an experimental option stored as an attribute of the task's cypress config root.

    :param key: attribute name to read
    :param task_type: task type passed to irt.bannerland.options.get_cypress_config
    :param yt_client: yt client (or the yt.wrapper module) used for the lookup

    :return: whatever bm.yt_tools.get_attribute returns; the lookup is called with default_value=None
    """
    cypress_config = irt.bannerland.options.get_cypress_config(task_type)
    return bm.yt_tools.get_attribute(cypress_config.root, key, yt_client, default_value=None)


def get_display_info(bs_info):
    """
    Normalize the 'text' dict of bs_info in place and return it.

    Keys 'to' / 'from' are renamed to 'go_to' / 'go_from', obsolete keys are dropped.

    :param bs_info: dict with a 'text' dict inside (the inner dict is modified)

    :return: the same (modified) bs_info['text'] object
    """
    text = bs_info['text']

    # Rename legacy keys, keeping their values.
    for legacy_key, current_key in (('to', 'go_to'), ('from', 'go_from')):
        if legacy_key in text:
            text[current_key] = text.pop(legacy_key)

    # Drop keys that are no longer needed; absent keys are silently skipped.
    for stale_key in ('params_for_direct', 'name', 'body_for_direct', 'second_title_for_direct'):
        text.pop(stale_key, None)

    return text


def get_image_info_direct_format(avatars, meta=None):
    """
    Convert an avatars dict into the {'Images': [...], 'MdsMeta': ...} structure.

    :param avatars: dict {format_name: image_data}; image_data may contain 'url', 'width', 'height',
        'smart-center' (dict) and 'smart-centers' (list of dicts), other keys are ignored
    :param meta: optional object, dumped to a json string under 'MdsMeta' when not None

    :return: dict with 'Images' (one entry per format, ordered by format name) and optionally 'MdsMeta'
    """
    # In Direct's proto schema smart-centers are int32 and width/height are uint64, while here everything is a plain int.
    # TODO: (i-gataullin) take this into account when moving to proto in yt or when using type_v3

    def title_keys(mapping):
        return {name.title(): value for name, value in mapping.items()}

    images = []
    for format_name in sorted(avatars):
        source = avatars[format_name]
        image = {}
        if 'smart-centers' in source:
            image['SmartCenters'] = [title_keys(center) for center in source['smart-centers']]
        if 'smart-center' in source:
            image['SmartCenter'] = title_keys(source['smart-center'])
        for field in ('url', 'width', 'height'):
            if field in source:
                image[field.title()] = source[field]
        image['Format'] = format_name
        images.append(image)

    if meta is None:
        return {'Images': images}
    return {'Images': images, 'MdsMeta': json.dumps(meta, ensure_ascii=False)}


def get_model_info(model_card):
    """
    Build a model info dict from a model card.

    Top-level card fields are copied under new column names with a type conversion obtained from
    bm.yt_tools.convert_yt_to_py; model_card['attributes']['market'] is normalized in place
    (numeric fields cast to float/int, 'AdvType' moved to the top level) and stored under 'Market'.

    :param model_card: dict with an 'attributes' dict that contains a 'market' dict (modified in place)

    :return: dict with converted columns, 'Market', and 'LocationID' when attributes['location'] is truthy
    """
    # (key in model_card, output column name, yt type of the output column)
    columns = (
        ('model_id', 'MarketModelID', 'uint64'),
        ('model_hid', 'MarketCategoryID', 'uint64'),
        ('bl_type_direct_allowed', 'AdvTypeDirectAllowed', 'boolean'),
        ('bl_phrase_template_id', 'BLPhraseTemplateID', 'uint64'),
    )
    result = {}
    for source_key, column_name, yt_type in columns:
        if source_key not in model_card:
            continue
        to_py = bm.yt_tools.convert_yt_to_py(yt_type, use_yson=True)
        result[column_name] = to_py(model_card[source_key])

    attributes = model_card['attributes']
    market = attributes['market']
    if 'AdvType' in market:
        result['AdvType'] = market.pop('AdvType')

    for field in ('PriceMin', 'PriceAvg', 'PriceMax', 'Rating', 'MaxRating'):
        if field not in market:
            continue
        raw = market[field]
        if isinstance(raw, six.string_types):
            # accept a decimal comma in string values
            raw = raw.replace(",", ".")
        market[field] = float(raw)

    for field in ('RatingCount', 'OfferClass', 'OfferCount', 'IsNew'):
        if field in market:
            market[field] = int(market[field])

    result['Market'] = market

    location = attributes.get('location')
    if location:
        result['LocationID'] = yson.YsonUint64(location)
    return result


class _GetTopReducer(object):
    """
    Reducer that emits distinct object ids of a key group until the limit for this key is reached.

    The limit is max_objects_dict[key value] when present, otherwise max_objects.
    Each output row is {object_field: object_id}; objects are taken in the order rows arrive.
    """

    def __init__(self, key_field, object_field, max_objects, max_objects_dict):
        self.key_field = key_field
        self.object_field = object_field
        self.max_objects = max_objects
        self.max_objects_dict = max_objects_dict

    def __call__(self, key, rows):
        limit = self.max_objects_dict.get(key[self.key_field], self.max_objects)
        emitted = set()
        for record in rows:
            if len(emitted) >= limit:
                # limit reached, the rest of the group is not read
                return
            candidate = record[self.object_field]
            if candidate in emitted:
                continue
            yield {self.object_field: candidate}
            emitted.add(candidate)


def limit_object_count(input,
                       output,
                       key_field,
                       object_field,
                       max_objects,
                       max_objects_dict=None,
                       bucket_field=None,
                       bucket_weights=(),
                       sort_fields=(),
                       preprocess_mapper=None,
                       preprocess_input_fields=None,
                       yt_spec=None,
                       yt_client=yt):
    """
    Filter rows in yt table to limit distinct objects count for each key.

    Runs three stages inside one transaction:
      1. map-reduce over input (through preprocess_mapper, if given) that keeps the first row of every
         (key, object) pair and routes it to the table of its bucket;
      2. for each bucket, a reduce by key that collects object ids until the bucket's share of the limit
         (int(limit * bucket_weight / sum(bucket_weights))) is reached;
      3. a reduce by object over the selected-object tables plus the original input that emits input rows
         of the selected objects.

    NOTE: stage 3 groups by object_field only, so every input row of an object selected for at least
    one key is kept, whatever its key.

    :param input: input table
    :param output: output table; created with 'optimize_for' and 'schema' (without sort_order) of input.
        Also writes output.underlimit.bucket<N> and output.first_occurence.bucket<N> (neither is removed here)
    :param key_field: key column name
    :param object_field: object column name, contains object_id - some (hashable) value that identifies object
    :param max_objects: object count limit
    :param max_objects_dict: dict {key: limit} for special keys
    :param bucket_field: bucket index column name (may be None if only one bucket)
    :param bucket_weights: array of bucket weights (replaced by a single bucket when bucket_field is None)
    :param sort_fields: sort inside group of given key
    :param preprocess_mapper: run before count objects (e.g. to add sort field), defaults to None. Only key_field, object_field and sort_fields are required in this mapper's output
    :param preprocess_input_fields: columns of input to read in the first stage, defaults to None (all columns)
    :param yt_spec: base operation spec; it is copied, the caller's dict is not modified.
        Its 'title' (default 'limit_object_count') gets a per-stage suffix
    :param yt_client: yt client (or the yt.wrapper module)

    :return: None
    """

    if max_objects_dict is None:
        max_objects_dict = {}
    # Work on a private copy: setdefault below must not leak into the caller's spec.
    yt_spec = {} if yt_spec is None else dict(yt_spec)
    yt_spec.setdefault('title', 'limit_object_count')

    if bucket_field is None:
        bucket_weights = (1.0, )

    def take_first_and_distribute_by_buckets(key, rows):
        for row in rows:
            # with bucket_field=None nothing is popped and everything goes to bucket 0
            bucket_index = row.pop(bucket_field, 0)
            row['@table_index'] = bucket_index
            yield row
            break

    underlimit_objects_tables = ['{}.underlimit.bucket{}'.format(output, idx) for idx in range(len(bucket_weights))]  # will keep them
    first_occurence_tables = ['{}.first_occurence.bucket{}'.format(output, idx) for idx in range(len(bucket_weights))]
    with yt_client.Transaction():
        input_for_preprocess = input
        if preprocess_input_fields is not None:
            input_for_preprocess = yt_client.TablePath(input_for_preprocess, columns=list(preprocess_input_fields))

        get_first_occurence_spec = yt_spec.copy()
        get_first_occurence_spec['title'] += '.get_first_occurence'

        # to select top we need only first occurence of object
        yt_client.run_map_reduce(
            preprocess_mapper,
            take_first_and_distribute_by_buckets,
            input_for_preprocess,
            first_occurence_tables,
            reduce_by=[key_field, object_field],
            reduce_combiner=bm.yt_tools.FirstReducer(),
            sort_by=[key_field, object_field] + list(sort_fields),
            spec=get_first_occurence_spec,
            reduce_output_format=yt.YsonFormat(control_attributes_mode='row_fields'),
        )

        # This stage gets its own title suffix so it can be told apart from the first one.
        get_overlimit_spec = yt_spec.copy()
        get_overlimit_spec['title'] += '.get_overlimit'

        sum_weights = sum(bucket_weights)
        for bucket_weight, first_table, underlimit_table in zip(bucket_weights, first_occurence_tables, underlimit_objects_tables):
            # each bucket gets a share of the limit proportional to its weight
            coefficient = bucket_weight / sum_weights
            reducer = _GetTopReducer(
                key_field=key_field,
                object_field=object_field,
                max_objects=int(max_objects * coefficient),
                max_objects_dict={k: int(v * coefficient) for k, v in max_objects_dict.items()},
            )
            yt_client.run_map_reduce(
                None,
                reducer,
                first_table,
                underlimit_table,
                reduce_by=[key_field],
                sort_by=[key_field] + list(sort_fields) + [object_field],  # object - for determinism
                spec=get_overlimit_spec,
            )

        def reduce_filter_overlimit(key, rows):
            good_object = False
            for row in rows:
                # Compare with None explicitly: a falsy key (0, '') in the input table
                # must not be mistaken for a row of the object table.
                if row.get(key_field) is None:
                    # this is object table, it goes first because None is first in sort_by
                    good_object = True
                    continue
                # this is main (input) table
                if good_object:
                    yield row
                else:
                    break

        filter_overlimit_spec = yt_spec.copy()
        filter_overlimit_spec['title'] += '.filter_overlimit'
        saved_attrs = {attr: bm.yt_tools.get_attribute(input, attr, yt_client) for attr in ['optimize_for', 'schema']}
        # output of the reduce is not declared sorted, so sort_order is stripped from the copied schema
        for col in saved_attrs['schema']:
            col.pop('sort_order', None)

        yt_client.run_map_reduce(
            None,
            reduce_filter_overlimit,
            underlimit_objects_tables + [input],
            yt_client.TablePath(output, attributes=saved_attrs),
            reduce_by=[object_field],
            sort_by=[object_field, key_field],
            spec=filter_overlimit_spec,
        )
