# coding: utf-8

import bisect
import collections
import json
import hashlib
import struct

import yt.wrapper as yt

import bannerland.archive_workers.full_state
import bannerland.common
import bm.yt_tools
import irt.bannerland.options


# Take the freshest pockets first, because older ones may already have been switched off
# (smart ranking is expected to take this into account automatically).
# Width of the BannerID-derived offset window that PreprocessMapper applies to UpdateTime.
TIME_SHUFFLE_VALUE = 600000
# Output row key under which PreprocessMapper stores the index of the chosen bucket.
BUCKET_FIELD = '_BucketIndex'
# Output row key under which PreprocessMapper stores the negated score used for sorting.
SORT_FIELD = '_SelectTopMinusScore'
# Input column that holds the JSON-encoded list of model predictions.
PREDICTIONS_FIELD = 'SelectionRankPredictions'
# Salt appended to BannerID before hashing when a banner is assigned to a bucket.
SPLIT_SALT = 'SelectionRank'
# Score used by 'prediction' buckets when a row does not have exactly one matching prediction.
DEFAULT_SCORE = -10


def md5int(s):
    """Fold the MD5 digest of a UTF-8 string into a 64-bit integer.

    The 16-byte digest is read as four big-endian 32-bit words ``a0..a3`` and
    combined as ``((a1 ^ a3) << 32) | (a0 ^ a2)``.

    :param s: UTF-8 encoded byte string, or a text string (which is encoded
        as UTF-8 before hashing, so both spellings of the same text give the
        same value).
    :return: non-negative integer below ``2 ** 64``.
    :raises Exception: if ``s`` is a byte string that is not valid UTF-8.
    """
    if isinstance(s, bytes):
        # Byte input is hashed as is, but only after checking it is real UTF-8.
        try:
            s.decode('utf8')
        except UnicodeDecodeError:
            raise Exception("input string for md5int must be UTF8")
        data = s
    else:
        # Text input (``unicode`` on Python 2, ``str`` on Python 3) has no
        # usable ``decode``; hash its UTF-8 representation instead of failing.
        data = s.encode('utf8')

    a = struct.unpack(">LLLL", hashlib.md5(data).digest())
    return (a[1] ^ a[3]) << 32 | (a[0] ^ a[2])


class PreprocessMapper(object):
    """Mapper that assigns every banner to a bucket and computes its sort key.

    ``buckets`` is a list of dicts such as
    ``[{'type': 'prediction', 'percent': 10, 'ModelID': 1}, {'type': 'update_time', 'percent': 90}]``;
    the ``percent`` values must add up to exactly 100.

    For each accepted input row the mapper yields a row with ``OrderID``,
    ``BannerID``, the bucket index (``BUCKET_FIELD``) and the negated score
    (``SORT_FIELD``).
    """

    def __init__(self, buckets, bannerid_throttle_percent=100):
        """
        :param buckets: bucket configuration, see the class docstring.
        :param bannerid_throttle_percent: rows whose ``BannerID % 100`` is greater
            than or equal to this value are dropped; 100 keeps every row.
        :raises ValueError: if the bucket percentages do not sum to 100.
        """
        self.buckets = buckets
        self.bucket_thresholds = self._get_thresholds(buckets)
        self.bannerid_throttle_percent = bannerid_throttle_percent

    def _get_thresholds(self, buckets):
        """Return the running sums of bucket percentages (the last one is always 100).

        :raises ValueError: if the percentages do not sum to 100.
        """
        percent_sum = 0
        thresholds = []
        for bucket in buckets:
            percent_sum += bucket['percent']
            thresholds.append(percent_sum)
        if percent_sum != 100:
            raise ValueError('Bad bucket percentage!')
        return thresholds

    def _get_bucket_index(self, BannerID):
        """Map a banner to a bucket index using a salted hash of its id.

        The hash remainder lies in ``[0, 99]`` and the last threshold is 100,
        so the index returned by ``bisect_right`` is always a valid position
        in ``self.buckets``.
        """
        remainder = md5int('{}{}'.format(BannerID, SPLIT_SALT)) % 100
        return bisect.bisect_right(self.bucket_thresholds, remainder)

    def __call__(self, row):
        """Yield at most one preprocessed row for the given input row.

        :raises ValueError: if the selected bucket has an unknown ``type``.
        """
        # Throttling: keep only the configured share of BannerID remainders.
        if row['BannerID'] % 100 >= self.bannerid_throttle_percent:
            return

        bucket_index = self._get_bucket_index(row['BannerID'])
        bucket_info = self.buckets[bucket_index]
        if bucket_info['type'] == 'update_time':
            # Spread UpdateTime so that tasks from neighbouring pockets do not overwrite
            # each other in turn and get into the state evenly.
            # Floor division keeps the score an integer even under true division.
            score = row['UpdateTime'] + (TIME_SHUFFLE_VALUE // 2) - (row['BannerID'] % TIME_SHUFFLE_VALUE)
        elif bucket_info['type'] == 'prediction':
            # Both spellings of the model id key are accepted, in the config and in predictions.
            model_id = bucket_info.get('ModelId', bucket_info.get('ModelID'))
            raw_predictions = row.get(PREDICTIONS_FIELD)
            predictions = json.loads(raw_predictions) if raw_predictions else []
            predictions = [pred for pred in predictions if pred.get('ModelId', pred.get('ModelID')) == model_id]
            # Zero or several predictions for the model both fall back to the default score.
            if len(predictions) == 1:
                score = predictions[0]['Prediction']
            else:
                score = DEFAULT_SCORE
        else:
            raise ValueError('Unknown bucket type!')

        # The score is stored negated; presumably the consumer sorts ascending so the
        # highest score comes first - confirm against limit_object_count.
        yield {'OrderID': row['OrderID'], 'BannerID': row['BannerID'], BUCKET_FIELD: bucket_index, SORT_FIELD: -score}


class SelectTopFSWorker(bannerland.archive_workers.full_state.FSWorker):
    """Full-state worker that limits the number of banners kept per order.

    Per-order limits are combined from the ``banner_count_limits`` option:
    explicit per-order values, per-client defaults and per-domain totals.
    The actual selection is delegated to ``bannerland.common.limit_object_count``.
    """

    def __init__(self, input_name, output_name, **kwargs):
        """
        :param input_name: name of the input table inside the full-state directory.
        :param output_name: name of the output table inside the full-state directory.
        :param kwargs: forwarded to the base class, except for the optional
            ``bannerid_throttle_percent`` (default 100), which is consumed here.
        """
        self.input_name = input_name
        self.output_name = output_name
        throttle_percent = kwargs.pop('bannerid_throttle_percent', 100)
        self.bannerid_throttle_percent = throttle_percent
        super(SelectTopFSWorker, self).__init__(**kwargs)

    def do_work(self, fs_dir):
        """Build per-order limits and run the top selection for ``fs_dir``.

        The bucket configuration that was used is written to the
        ``SelectTopFSWorker.buckets`` attribute of ``fs_dir`` afterwards.
        """
        src_table = yt.ypath_join(fs_dir, self.input_name)
        dst_table = yt.ypath_join(fs_dir, self.output_name)
        limits_config = irt.bannerland.options.get_option('banner_count_limits')

        # Bucket configuration, for example:
        # [{'type': 'prediction', 'percent': 10, 'ModelID': 1}, {'type': 'update_time', 'percent': 90}]
        buckets_attr = 'SelectTopFSWorker.buckets'
        buckets = bannerland.common.get_experimental_option(
            buckets_attr,
            task_type=self.task_type,
            yt_client=self.yt_client,
        )
        if buckets is None:
            # No experiment configured: a single bucket ordered by update time.
            buckets = [{'type': 'update_time', 'percent': 100}]

        # OrderID has to survive preprocessing because the selection is keyed by it.
        preprocess_input_fields = ['OrderID', 'BannerID', 'UpdateTime', PREDICTIONS_FIELD]

        domain_limits = limits_config['by_orig_domain']
        # Copy so that the option value itself is never modified.
        order_limits = limits_config['by_order_id'].copy()
        clients_limits = limits_config['default_by_client_id']
        default_limit = limits_config['default']

        per_order_from_domains = self.expand_domain_limits(src_table, domain_limits)
        per_order_from_clients = self.expand_clients_limits(src_table, clients_limits)

        # Client limits are applied first, then domain limits; every step keeps the
        # strictest limit known for the order so far.
        for expanded in (per_order_from_clients, per_order_from_domains):
            for order, cap in expanded.items():
                order_limits[order] = min(cap, order_limits.get(order, default_limit))

        bucket_weights = tuple(bucket['percent'] for bucket in buckets)
        preprocess_mapper = PreprocessMapper(buckets, self.bannerid_throttle_percent)

        bannerland.common.limit_object_count(
            input=src_table,
            output=dst_table,
            key_field='OrderID',
            object_field='BannerID',
            bucket_field=BUCKET_FIELD,
            bucket_weights=bucket_weights,
            preprocess_mapper=preprocess_mapper,
            preprocess_input_fields=preprocess_input_fields,
            max_objects=default_limit,
            max_objects_dict=order_limits,
            sort_fields=[SORT_FIELD],
            yt_client=self.yt_client,
            yt_spec={'title': 'select_top'},
        )
        self.yt_client.set(yt.ypath_join(fs_dir, '@' + buckets_attr), buckets)

    def expand_clients_limits(self, input, client_limits):
        """Turn per-client limits into a ``{OrderID: limit}`` dict.

        :param input: path of the table to scan (``ClientID`` and ``OrderID`` columns are read).
        :param client_limits: ``{ClientID: {'limit': ..., 'start_from': ...}}``; ``start_from``
            is optional and defaults to 0. Only orders with ``OrderID`` strictly greater
            than ``start_from`` receive the client limit.
        """
        limit_field = 'ClientID'

        def client_filter_mapper(row):
            client_id = row[limit_field]
            if client_id not in client_limits:
                return
            if row['OrderID'] > client_limits[client_id].get('start_from', 0):
                yield row

        with self.yt_client.TempTable() as uniq_orders_for_clients:
            # NOTE(review): FirstReducer presumably keeps one row per (ClientID, OrderID) - confirm.
            self.yt_client.run_map_reduce(
                client_filter_mapper,
                bm.yt_tools.FirstReducer(),
                yt.TablePath(input, columns=[limit_field, 'OrderID']),
                uniq_orders_for_clients,
                reduce_by=[limit_field, 'OrderID']
            )
            # The temporary table must be read before the ``with`` block is left.
            limit_by_order = {}
            for row in self.yt_client.read_table(uniq_orders_for_clients):
                limit_by_order[row['OrderID']] = client_limits[row[limit_field]]['limit']
            return limit_by_order

    def expand_domain_limits(self, input, domain_limits):
        """Turn per-domain limits into a ``{OrderID: limit}`` dict.

        Domains whose total counted rows fit into their limit produce no entries.
        For every other domain each of its orders gets a share of the domain limit
        proportional to the order's share of the counted rows.

        :param input: path of the table to scan (``OrigDomain``, ``OrderID`` and
            ``BannerID`` columns are read).
        :param domain_limits: ``{OrigDomain: limit}``.
        """
        def domain_filter_mapper(row):
            if row['OrigDomain'] in domain_limits:
                yield row

        def order_count_reducer(key, rows):
            rows_in_order = 0
            for _ in rows:
                rows_in_order += 1
            yield {
                'OrigDomain': key['OrigDomain'],
                'OrderID': key['OrderID'],
                'OrderCount': rows_in_order,
            }

        with self.yt_client.TempTable() as uniq_bannerids:
            with self.yt_client.TempTable() as orderid_counts_table:
                # Step 1: NOTE(review) FirstReducer presumably leaves one row per
                # (OrigDomain, OrderID, BannerID) for the limited domains - confirm.
                self.yt_client.run_map_reduce(
                    domain_filter_mapper,
                    bm.yt_tools.FirstReducer(),
                    yt.TablePath(input, columns=['OrigDomain', 'OrderID', 'BannerID']),
                    uniq_bannerids,
                    reduce_by=['OrigDomain', 'OrderID', 'BannerID']
                )

                # Step 2: count the remaining rows per (OrigDomain, OrderID).
                self.yt_client.run_map_reduce(
                    None,
                    order_count_reducer,
                    uniq_bannerids,
                    orderid_counts_table,
                    reduce_by=['OrigDomain', 'OrderID'],
                )

                counts_by_domain = collections.defaultdict(dict)
                total_by_domain = collections.defaultdict(int)
                for row in self.yt_client.read_table(orderid_counts_table):
                    domain = row['OrigDomain']
                    counts_by_domain[domain][row['OrderID']] = row['OrderCount']
                    total_by_domain[domain] += row['OrderCount']

        domain_limits_expanded = {}
        for domain, order_counts in counts_by_domain.items():
            domain_total = total_by_domain[domain]
            domain_limit = domain_limits[domain]
            if domain_total <= domain_limit:
                continue
            for order, order_count in order_counts.items():
                # Cut every order by the same ratio so that together they fit into the domain limit.
                domain_limits_expanded[order] = round(domain_limit * (1.0 * order_count / domain_total))
        return domain_limits_expanded
