# -*- coding: utf-8 -*-
import logging
from collections import OrderedDict
from contextlib import contextmanager
from datetime import datetime
from functools import partial
from operator import itemgetter

import yt.wrapper as yt

import travel.avia.stat_admin.lib.yt_helpers as yh
from travel.avia.stat_admin.lib.grouping import group_tree

log = logging.getLogger(__name__)


def calculate_stat_by_partners(begin, end):
    """Return the unique-qid count and per-partner hourly statistics.

    The raw rows come from ``_calculate_statuses_events(begin, end)``; each
    row is converted in place (``date_ymdh`` parsed into a ``datetime``,
    ``count`` into ``int`` and, when present, ``avg_query_time`` into
    ``float``) before being grouped.

    Returns a tuple ``(qids_count, stats)`` where ``stats`` maps each partner
    to the list of rows produced by ``gen_pstat``, sorted by their ``hour``
    value.
    """
    qids_count, raw_rows = _calculate_statuses_events(begin, end)

    typed_rows = []
    for row in raw_rows:
        # Pop and re-insert so the key now holds a datetime instead of text.
        hour_text = row.pop('date_ymdh')
        row['date_ymdh'] = datetime.strptime(hour_text, '%Y-%m-%d %H')
        row['count'] = int(row['count'])
        if 'avg_query_time' in row:
            row['avg_query_time'] = float(row['avg_query_time'])
        typed_rows.append(row)

    # NOTE(review): presumably nests rows as partner -> hour -> status;
    # confirm against group_tree.
    grouped = group_tree(typed_rows, 'partner.date_ymdh.status')

    stats = {}
    for partner, hourly_statuses in grouped.items():
        partner_rows = list(gen_pstat(hourly_statuses))
        partner_rows.sort(key=itemgetter('hour'))
        stats[partner] = partner_rows

    return qids_count, stats


def gen_pstat(by_hour_statuses):
    """Yield one ordered summary row per hour.

    ``by_hour_statuses`` maps an hour to a mapping of status name to a
    sequence of rows; only the first row of each status is read. A status
    (or a field of its first row) that is missing counts as ``0``.

    Each yielded ``OrderedDict`` has the keys ``hour``, ``queries``, ``good``,
    ``avg_query_time``, ``empty``, ``timeout`` and ``error``, in that order.
    """
    counted_statuses = (
        'got_reply', 'empty', 'timeout', 'response_error', 'got_failure',
    )

    def lookup(rows_by_status, status, value_field='count'):
        # Absent status or field is reported as zero rather than an error.
        try:
            return rows_by_status[status][0][value_field]
        except KeyError:
            return 0

    for hour, rows_by_status in by_hour_statuses.items():
        counts = dict(
            (name, lookup(rows_by_status, name)) for name in counted_statuses
        )

        total_queries = (
            counts['got_reply']
            + counts['empty']
            + counts['timeout']
            + counts['response_error']
            + counts['got_failure']
        )

        row = OrderedDict()
        row['hour'] = hour
        row['queries'] = total_queries
        row['good'] = counts['got_reply']
        row['avg_query_time'] = lookup(
            rows_by_status, 'got_reply', 'avg_query_time'
        )
        row['empty'] = counts['empty']
        row['timeout'] = counts['timeout']
        # Failures and response errors are reported as one "error" figure.
        row['error'] = counts['got_failure'] + counts['response_error']
        # 'redir_fails' is not reported yet (TODO).

        yield row


# Passed as ``job_count`` to the YT map and reduce operations started below.
JOB_COUNT = 100


# @debug_cache(
#     FileCache('/tmp/yt/_calculate_statuses_events'), log,
#     enable=True,
#     # reset=True,
# )
def _calculate_statuses_events(begin, end):
    """Aggregate partner query-log events for the given date range on YT.

    The log tables selected by ``yh.tables_for_daterange`` are mapped with
    ``map_to_partners_events`` (restricted to the ``ticket`` and ``api_avia``
    services) into a temporary table, which is then sorted and reduced with
    ``reduce_partners_events`` by partner, status and hour.

    Returns a tuple ``(qids_count, rows)``: the number of distinct ``qid``
    values in the mapped table and the list of reduced rows read back from
    the temporary table in DSV format.
    """
    yh.configure_wrapper(yt)

    log_tables = yh.tables_for_daterange(
        yt, '//home/rasp/logs/rasp-partners-query-log', begin, end
    )
    mapper = partial(map_to_partners_events, ['ticket', 'api_avia'])
    group_columns = ('partner', 'status', 'date_ymdh')

    with yh.temp_table(yt) as stats_table:
        log.debug('Temp table: %s', stats_table)

        with comment_process(log.debug, 'Map stats'):
            yt.run_map(
                mapper,
                source_table=log_tables,
                destination_table=stats_table,
                format=yt.DsvFormat(),
                job_count=JOB_COUNT,
            )

        # Distinct qids must be counted before the reduce step overwrites
        # the mapped rows in the same table.
        qids_count = count_uniq_keys(yt, stats_table, ['qid'])

        with comment_process(log.debug, 'Sort stats'):
            yt.run_sort(source_table=stats_table, sort_by=list(group_columns))

        with comment_process(log.debug, 'Reduce stats'):
            yt.run_reduce(
                reduce_partners_events,
                stats_table,
                stats_table,
                reduce_by=list(group_columns),
                format=yt.DsvFormat(),
                job_count=JOB_COUNT,
            )

        # Materialise the rows while the temporary table is still available.
        reduced_rows = list(
            yt.read_table(stats_table, format=yt.DsvFormat(), raw=False)
        )
        return qids_count, reduced_rows


@contextmanager
def comment_process(logger_method, message, *args, **kwargs):
    """Log ``'Start ' + message`` on entry and ``'Done ' + message`` on exit.

    ``logger_method`` is called with the prefixed message followed by
    ``*args`` and ``**kwargs``. The "Done" call is made only when the managed
    body finishes without raising.
    """
    def emit(prefix):
        logger_method(prefix + message, *args, **kwargs)

    emit('Start ')
    yield
    emit('Done ')


def reduce_yield_first(key, records):
    """Reducer that keeps only the first record of each key group.

    ``key`` is accepted to match the reducer call signature and is not used.
    An empty ``records`` iterable yields nothing; calling ``next()`` on it
    inside a generator would leak ``StopIteration``, which Python 3.7+
    turns into ``RuntimeError`` (PEP 479).
    """
    for record in records:
        yield record
        return


def count_uniq_keys(yt, table, keys):
    """Return the number of distinct ``keys`` combinations found in ``table``.

    The rows are projected onto ``keys`` into a temporary table, which is
    sorted and reduced with ``reduce_yield_first`` so that one row remains
    per combination; the table's ``row_count`` attribute is the result.

    ``yt`` is the YT wrapper object used to run the operations, ``table`` the
    source table and ``keys`` the list of column names.
    """
    def map_only_keys(record):
        # Columns absent from the record are emitted as None.
        yield dict((column, record.get(column)) for column in keys)

    with yh.temp_table(yt) as keys_table:
        with comment_process(log.debug, 'Map keys: %r', keys):
            yt.run_map(
                map_only_keys,
                source_table=table,
                destination_table=keys_table,
                format=yt.DsvFormat(),
                job_count=JOB_COUNT,
            )

        with comment_process(log.debug, 'Sort keys: %r', keys):
            yt.run_sort(source_table=keys_table, sort_by=keys)

        with comment_process(log.debug, 'Reduce keys: %r', keys):
            yt.run_reduce(
                reduce_yield_first,
                keys_table,
                keys_table,
                reduce_by=keys,
                format=yt.DsvFormat(),
                job_count=JOB_COUNT,
            )

        # Read the attribute before the temporary table goes away.
        unique_count = yt.get_attribute(keys_table, 'row_count')
        return unique_count


def map_to_partners_events(allowed_services, record):
    """Mapper: convert one query-log record into a partner event row.

    A record is skipped (nothing is yielded) when ``importer``, ``qid``,
    ``status`` or ``iso_eventtime`` is missing or empty, when the service
    cannot be taken from the second ``:``-separated part of ``qid``, or when
    that service is not in ``allowed_services``.

    The yielded dict has the keys ``partner`` (importer name up to the first
    ``[`` with trailing digits removed), ``status``, ``date_ymdh`` (first 13
    characters of ``iso_eventtime``), ``query_time`` and ``qid``. A
    ``got_reply`` status whose ``variants_len`` parses to zero is reported
    as ``empty``.

    Raises ``KeyError`` if an otherwise valid record has no ``query_time``.
    """
    importer = record.get('importer')
    qid = record.get('qid')
    status = record.get('status')

    # Slice only when a timestamp is present: slicing a missing (None) value
    # would raise TypeError before the record could be skipped.
    iso_eventtime = record.get('iso_eventtime')
    if iso_eventtime:
        date_ymdh = '%s' % iso_eventtime[:13]
    else:
        date_ymdh = None

    if not all([importer, qid, status, date_ymdh]):
        return

    # Malformed qids are dropped on purpose rather than failing the job.
    try:
        service = qid.split(':')[1]
    except (AttributeError, IndexError, TypeError):
        return

    if service not in allowed_services:
        return

    importer_code = importer.split('[')[0]
    variants_len = record.get('variants_len')

    # A reply that carried zero variants is counted as an empty response.
    if status == 'got_reply' and variants_len and not int(variants_len):
        status = 'empty'

    yield {
        'partner': importer_code.rstrip('0123456789'),
        'status': status,
        'date_ymdh': date_ymdh,
        'query_time': record['query_time'],
        'qid': qid,
    }


def reduce_partners_events(key, records):
    """Reducer: aggregate the events of one (partner, status, hour) group.

    Records whose ``query_time`` is missing or cannot be converted to
    ``float`` are ignored and do not contribute to ``count``.

    Yields a single dict with ``partner``, ``status`` and ``date_ymdh`` taken
    from ``key``, the number of counted records in ``count`` and
    ``avg_query_time``, the mean ``query_time`` divided by 1000
    (presumably milliseconds to seconds - confirm against the log format).
    Nothing is yielded when no record could be counted, which avoids a
    division by zero.
    """
    count = 0
    query_time_sum = 0.0
    for r in records:
        try:
            query_time = float(r['query_time'])
        except (KeyError, TypeError, ValueError):
            # Unparseable timings are skipped on purpose.
            continue
        query_time_sum += query_time
        count += 1

    if not count:
        # No usable record in this group: emit no row instead of dividing by
        # zero.
        return

    yield {
        'partner': key['partner'],
        'status': key['status'],
        'date_ymdh': key['date_ymdh'],
        'count': count,
        'avg_query_time': (query_time_sum / 1000) / count,
    }
