# -*- coding: utf-8 -*-
import logging
from collections import OrderedDict
from contextlib import contextmanager
from datetime import datetime
from functools import partial
from operator import itemgetter

import yt.wrapper as yt
from django.conf import settings  # noqa

import travel.avia.admin.lib.yt_helpers as yh
from travel.avia.admin.lib.grouping import group_tree

log = logging.getLogger(__name__)


def calculate_stat_by_partners(begin, end):
    """Build per-partner, per-hour query statistics for a date range.

    Args:
        begin: start of the range, forwarded to ``_calculate_statuses_events``.
        end: end of the range, forwarded to ``_calculate_statuses_events``.

    Returns:
        A ``(qids_count, percentile_by_partner, stats)`` tuple. The first two
        items are returned unchanged from ``_calculate_statuses_events``;
        ``stats`` maps each partner to the list of rows produced by
        ``gen_pstat``, sorted by their ``'hour'`` value.
    """
    qids_count, percentile_by_partner, records = _calculate_statuses_events(begin, end)

    for record in records:
        # The hour bucket arrives as a 'YYYY-MM-DD HH' string; turn it into a
        # datetime so that the rows can be sorted chronologically below.
        record['date_ymdh'] = datetime.strptime(record.pop('date_ymdh'), '%Y-%m-%d %H')
        record['count'] = int(record['count'])

    by_partner = group_tree(records, 'partner.date_ymdh.status')

    # .items() (rather than the Python-2-only .iteritems()) matches how
    # gen_pstat walks the nested level of the same tree.
    return qids_count, percentile_by_partner, {
        partner: sorted(gen_pstat(by_hour_statuses), key=itemgetter('hour'))
        for partner, by_hour_statuses in by_partner.items()
    }


def gen_pstat(by_hour_statuses):
    """Yield one ordered statistics row per hour.

    Args:
        by_hour_statuses: mapping ``hour -> {status: [record, ...]}``; only
            the first record of every status list is consulted.

    Yields:
        An ``OrderedDict`` with the keys ``hour``, ``queries``, ``good``,
        ``timing_90pc``, ``empty``, ``timeout`` and ``error`` (in that order).
        A status or field that is absent contributes ``0``.
    """
    def lookup(status_map, status, value_field='count'):
        # A missing status or a missing field both count as zero.
        try:
            return status_map[status][0][value_field]
        except KeyError:
            return 0

    for hour, status_map in by_hour_statuses.items():
        good = lookup(status_map, 'got_reply')
        empty = lookup(status_map, 'empty')
        timeout = lookup(status_map, 'timeout')
        response_error = lookup(status_map, 'response_error')
        got_failure = lookup(status_map, 'got_failure')

        row = OrderedDict()
        row['hour'] = hour
        row['queries'] = good + empty + timeout + response_error + got_failure
        row['good'] = good
        row['timing_90pc'] = lookup(status_map, 'got_reply', 'timing_90pc')
        row['empty'] = empty
        row['timeout'] = timeout
        row['error'] = got_failure + response_error
        # TODO: add a 'redir_fails' column.
        yield row


# NOTE(review): not referenced by any code visible in this part of the file;
# presumably a job-count setting for the YT operations — confirm before use.
JOB_COUNT = 100


def filter_by_status(status, record):
    """Mapper: pass ``record`` through only when its ``'status'`` equals ``status``.

    Yields:
        ``record`` itself when it matches; nothing otherwise.
    """
    if record['status'] != status:
        return
    yield record


def compute_percentile(p, key, records):
    """Reducer: yield the ``p`` percentile of ``query_time`` for one key group.

    Args:
        p: fraction of the sorted sample used to pick the element (the caller
            in this file passes ``0.9``).
        key: reduce key; its items are copied into the emitted row.
        records: iterable of rows, each holding a ``'query_time'`` value.

    Yields:
        One dict with the key's items plus ``'percentile'``. Nothing is
        yielded for an empty group.
    """
    timings = sorted(r['query_time'] for r in records)
    if not timings:
        return

    # int(p * n) equals n when p == 1.0, which is one past the end of the
    # list, so clamp to the last valid index.
    index = min(int(p * len(timings)), len(timings) - 1)

    answer = dict(key)
    answer['percentile'] = timings[index]
    yield answer


# @debug_cache(
#     FileCache('/tmp/yt/_calculate_statuses_events'), log,
#     enable=True,
#     # reset=True,
# )
def _calculate_statuses_events(begin, end):
    """Run the YT pipeline that aggregates partner query logs for a date range.

    Args:
        begin: start of the range, passed to ``yh.safe_tables_for_daterange``.
        end: end of the range, passed to ``yh.safe_tables_for_daterange``.

    Returns:
        A ``(qids_count, percentile_by_partner, records)`` tuple:

        * ``qids_count`` -- row count of the mapped events reduced by ``qid``
          (i.e. the number of distinct ``qid`` values);
        * ``percentile_by_partner`` -- partner -> the 0.9 percentile of
          ``query_time`` over its ``got_reply`` events, divided by 1000;
        * ``records`` -- list of rows produced by ``reduce_partners_events``
          (``partner``, ``status``, ``date_ymdh``, ``count``).
    """
    yh.configure_wrapper(yt)

    source_tables = yh.safe_tables_for_daterange(
        yt, '//home/rasp/logs/rasp-partners-query-log', begin, end
    )

    with yh.temp_table(yt) as stats_table, \
            yh.temp_table(yt) as counts_table, \
            yh.temp_table(yt) as percentile_table:
        log.debug('Temp table: %s', stats_table)

        with comment_process(log.debug, 'Map stats'):
            yt.run_map(
                map_to_partners_events,
                source_table=source_tables, destination_table=stats_table,
            )

        # The aggregated counts are written to their own table. The two
        # operations started below still need the raw events (query_time, qid)
        # from stats_table; overwriting stats_table in place from an
        # asynchronous operation would make their input depend on timing.
        operation = yt.run_map_reduce(
            mapper=None,
            reducer=reduce_partners_events,
            source_table=stats_table,
            destination_table=counts_table,
            reduce_by=['partner', 'status', 'date_ymdh'],
            sync=False,
        )

        percentile_operation = yt.run_map_reduce(
            mapper=partial(filter_by_status, 'got_reply'),
            reducer=partial(compute_percentile, 0.9),
            reduce_by=['partner'],
            source_table=stats_table,
            destination_table=percentile_table,
            sync=False,
        )

        # Runs synchronously while the two operations above are in flight.
        qids_count = count_uniq_keys(yt, stats_table, ['qid'])
        operation.wait()

        percentile_operation.wait()
        percentile_by_partner = {
            # Float divisor: avoids Python 2 integer truncation should the
            # value be read back as an int.
            r['partner']: r['percentile'] / 1000.0
            for r in yt.read_table(percentile_table, format=yt.JsonFormat())
        }

        records = list(yt.read_table(counts_table, format=yt.JsonFormat(), raw=False))
        return qids_count, percentile_by_partner, records


@contextmanager
def comment_process(logger_method, message, *args, **kwargs):
    """Context manager that logs ``'Start ' + message`` and ``'Done ' + message``.

    Args:
        logger_method: callable used for logging (e.g. ``log.debug``).
        message: message text; it is prefixed with ``'Start '`` / ``'Done '``.
        *args, **kwargs: forwarded unchanged to ``logger_method``.

    The ``'Done '`` line is emitted only when the wrapped body finishes
    without raising.
    """
    def announce(prefix):
        logger_method(prefix + message, *args, **kwargs)

    announce('Start ')
    yield
    announce('Done ')


def reduce_yield_first(key, records):
    """Reducer: emit only the first record of a key group.

    Args:
        key: reduce key (unused).
        records: iterable of the rows sharing ``key``.

    Yields:
        The first row of ``records``; nothing when ``records`` is empty.
    """
    # for/return instead of next(iter(...)): a StopIteration escaping a
    # generator body is turned into RuntimeError on Python 3.7+ (PEP 479).
    for record in records:
        yield record
        return


def count_uniq_keys(yt, table, keys):
    """Count the distinct combinations of ``keys`` in ``table``.

    The table is reduced by ``keys`` into a temporary table, keeping one row
    per key group, and that table's ``row_count`` attribute is returned.

    Args:
        yt: YT client/module used to run the operation and read the attribute.
        table: source table.
        keys: list of column names to reduce by.
    """
    with yh.temp_table(yt) as distinct_table:
        with comment_process(log.debug, 'Map reduce keys: %r', keys):
            yt.run_map_reduce(
                mapper=None,
                reducer=reduce_yield_first,
                reduce_by=keys,
                source_table=table,
                destination_table=distinct_table,
            )

        # Read the attribute while the temporary table still exists.
        row_count = yt.get_attribute(distinct_table, 'row_count')
        return row_count


def _get_partner_from_importer(importer):
    """Extract the partner code from an importer of the form ``module[partner_code]``.

    Codes starting with ``dohop`` or ``amadeus`` are collapsed to that prefix.

    Returns:
        The partner code, or ``None`` when there is no non-empty
        ``[...]`` section (first ``[`` followed by the first ``]``).
    """
    open_pos = importer.find('[')
    close_pos = importer.find(']')

    if open_pos < 0 or close_pos < 0:
        return None
    if close_pos - open_pos <= 1:
        # Empty brackets, or the first ']' precedes the first '['.
        return None

    code = importer[open_pos + 1:close_pos]
    for family in ('dohop', 'amadeus'):
        if code.startswith(family):
            return family
    return code


def map_to_partners_events(record):
    """Mapper: turn a raw query-log record into a partner event row.

    Records are dropped (nothing is yielded) when ``importer``, ``qid``,
    ``status`` or ``iso_eventtime`` is missing or empty, when ``query_time``
    is missing or not convertible to ``float``, or when no partner code can
    be extracted from ``importer``.

    Yields:
        A dict with ``partner``, ``status``, ``date_ymdh`` (the first 13
        characters of ``iso_eventtime``), ``query_time`` (float) and ``qid``.
    """
    importer = record.get('importer')
    qid = record.get('qid')
    status = record.get('status')
    # Fall back to '' so that a missing timestamp is dropped by the check
    # below instead of raising TypeError on the slice.
    date_ymdh = (record.get('iso_eventtime') or '')[:13]

    if not all([importer, qid, status, date_ymdh]):
        return

    try:
        query_time = float(record['query_time'])
    except (KeyError, TypeError, ValueError):
        # Missing, None or unparsable timing: skip the record like any other
        # incomplete one rather than failing the whole job.
        return

    variants_len = record.get('variants_len')

    # A reply that carried zero variants is accounted as 'empty'.
    if status == 'got_reply' and variants_len and not int(variants_len):
        status = 'empty'

    partner = _get_partner_from_importer(importer)
    if not partner:
        return

    yield {
        'partner': partner,
        'status': status,
        'date_ymdh': date_ymdh,
        'query_time': query_time,
        'qid': qid,
    }


def reduce_partners_events(key, records):
    """Reducer: emit the key's items plus the number of rows in the group.

    Args:
        key: reduce key; its items are copied into the emitted row.
        records: iterable of the rows sharing ``key``.

    Yields:
        A single dict: the key's items with an added ``'count'``.
    """
    total = sum(1 for _ in records)

    row = dict(key)
    row['count'] = total
    yield row
