# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from collections import Counter
from datetime import datetime
from functools import partial

from sandbox.projects.avia.lib import yt_helpers

log = logging.getLogger(__name__)

# strptime format of date_forward / date_backward in tariff-log records.
DATE_FORMAT = '%Y-%m-%d'
# strptime format of the 'timestamp' field in tariff-log records.
TIMESTAMP_FORMAT = '%Y-%m-%d %H:%M:%S'
# Currency codes accepted by check_currency (second token of the price string).
ACCEPTED_CURRENCIES = ['RUR']
# Minimum n_shows for a route to be kept by take_popular.
POPULAR_THRESHOLD = 5
# Upper cap of the days_to_flight bucket produced by get_bin_for_days.
MAX_DAYS = 64

# YT directories with the daily log tables read by update_price_storage
# (tariffs) and update_shows (ticket shows).
YT_TARIFFS_LOG_DIR = '//home/rasp/logs/rasp-tariffs-log'
YT_SHOW_LOGS_DIR = '//home/rasp/logs/rasp-tickets-show-log'

# Quantile levels; each one yields a q<NN> and a qr<NN> column (NN = int(q * 100)).
# Keep them ascending: get_quantiles walks the sorted prices forward only once.
QUANTILES = [0.33, 0.67]


def get_route_uid(forward, backward):
    """Build a route uid from a forward and an optional backward route.

    The forward route is passed through convert_route. When ``backward`` is
    non-empty, the converted backward route is appended after a ';'.
    """
    forward_uid = convert_route(forward)
    if backward != '':
        return '{0};{1}'.format(forward_uid, convert_route(backward))
    return forward_uid


def take_shows(key, records):
    """Reducer: count distinct shows for one (forward, backward) group.

    The group key is turned into an upper-cased route uid via get_route_uid.
    Show ids are compared only by the part before the first '.', so ids that
    differ only in a '.'-suffix are counted once.

    Yields a single ``{'route_uid': ..., 'n_shows': ...}`` row.
    """
    route_uid = get_route_uid(key['forward'], key['backward']).upper()
    show_ids = {record['show_id'].split('.')[0] for record in records}

    yield {
        'route_uid': route_uid,
        'n_shows': len(show_ids),
    }


def summator(field, key, records):
    """Reducer: sum the integer ``field`` over a group.

    Each record's ``field`` is converted with int() before summing. The
    emitted row is a copy of the group key plus ``field`` set to the total.
    Typically bound with functools.partial, e.g. ``partial(summator, 'n_shows')``.
    """
    total = 0
    for rec in records:
        total += int(rec[field])

    result = dict(key)
    result[field] = total
    yield result


def update_shows(ytc, shows_storage, start_date, end_date):
    """Add unique-show counts from the show logs to the cumulative shows table.

    Each daily show-log table selected by yt_helpers.tables_for_daterange for
    start_date..end_date is reduced by (forward, backward) with take_shows.
    The per-route counts are appended to a temporary table. Those counts are
    then summed per route_uid, together with the current contents of
    shows_storage, and written back to shows_storage. That output has no
    <append=true>, unlike the temporary table.

    If shows_storage does not exist yet, it is left out of the inputs. This
    mirrors how update_price_storage treats its storage table, so the first
    run bootstraps the table instead of handing a missing path to YT.
    """
    source_tables = yt_helpers.tables_for_daterange(
        ytc, YT_SHOW_LOGS_DIR, start_date, end_date
    )

    temp_shows_storage = ytc.create_temp_table()

    # One operation per daily table: take_shows deduplicates show ids only
    # within the table it runs on, so counts from different days add up.
    for daily_table in source_tables:
        ytc.run_map_reduce(
            mapper=None,
            reducer=take_shows,
            reduce_by=['forward', 'backward'],
            source_table=daily_table,
            destination_table='<append=true>' + temp_shows_storage,
        )

    summation_sources = [temp_shows_storage]
    if ytc.exists(shows_storage):
        # Keep the original input order: existing storage first.
        summation_sources.insert(0, shows_storage)

    ytc.run_map_reduce(
        mapper=None,
        reducer=partial(summator, 'n_shows'),
        reduce_by='route_uid',
        source_table=summation_sources,
        destination_table=[shows_storage],
    )


def convert_route(route):
    """Replace every ', ' separator in ``route`` with ';'."""
    return ';'.join(route.split(', '))


def get_bin_for_days(days):
    """Bucket an integer day count into a power-of-two bin.

    Values >= MAX_DAYS map to MAX_DAYS and values <= 1 map to 1. Anything in
    between maps to the largest power of two not exceeding it (2, 4, 8, ...).
    ``days`` must be an int, because int.bit_length() is used.
    """
    if days < MAX_DAYS:
        if days > 1:
            return 1 << (days.bit_length() - 1)
        return 1
    return MAX_DAYS


def check_currency(price_str):
    """Return True if the second whitespace-separated token of ``price_str``
    is an accepted currency code.

    Any failure to extract that token (None, non-string, fewer than two
    tokens) yields False instead of raising.
    """
    try:
        _amount, currency = price_str.split()[:2]
    except Exception:
        return False

    return currency in ACCEPTED_CURRENCIES


class RecordChecker(object):
    """Validate raw tariff-log records before aggregation.

    A record passes only if all of the following hold:

    * the second '.'-separated component of its qid is in ``good_services``;
    * it is a 'plane' search on national version 'ru' for exactly one adult,
      no children and no infants;
    * its class_economy_price passes check_currency;
    * date_forward (and date_backward, unless it is the string 'None')
      parse with DATE_FORMAT.
    """

    # (field, required value) pairs compared via record.get(), in this order.
    _REQUIRED_FIELDS = (
        ('type', 'plane'),
        ('national_version', 'ru'),
        ('adults', '1'),
        ('children', '0'),
        ('infants', '0'),
    )

    def __init__(self, good_services):
        # Accepted service names (matched against the qid's second component).
        self.good_services = good_services

    def _service_is_good(self, record):
        # Any malformed or missing qid counts as a rejection.
        try:
            service = record['qid'].split('.')[1]
            return service in self.good_services
        except Exception:
            return False

    @staticmethod
    def _dates_are_valid(record):
        # A missing key or an unparsable date counts as a rejection.
        try:
            datetime.strptime(record['date_forward'], DATE_FORMAT)
            if record['date_backward'] != 'None':
                datetime.strptime(record['date_backward'], DATE_FORMAT)
        except Exception:
            return False
        return True

    def check_record(self, record):
        if not self._service_is_good(record):
            return False

        for field, required in self._REQUIRED_FIELDS:
            if record.get(field) != required:
                return False

        if not check_currency(record.get('class_economy_price')):
            return False

        return self._dates_are_valid(record)


def _get_point_type_code(point_type):
    """Map a point type name to its one-letter code.

    Names starting with 'Settlement' map to 'c' and names starting with
    'Station' map to 's'. Anything else raises ValueError.
    """
    for prefix, code in (('Settlement', 'c'), ('Station', 's')):
        if point_type.startswith(prefix):
            return code
    raise ValueError('Unknown Point type', point_type)


def filter_records(records, good_services):
    """Turn raw tariff-log rows into compact price rows, keeping the cheapest.

    Meant to run as an aggregating mapper: update_price_storage wraps it in
    yt.wrapper.aggregator, so it receives an iterable of rows rather than a
    single row. Rows rejected by RecordChecker are skipped, and so are rows
    missing any field read below (KeyError). For each (qid, route_uid) pair
    only the row with the lowest class_economy_price is yielded.

    :param records: iterable of raw tariff-log rows (dict-like).
    :param good_services: accepted services, forwarded to RecordChecker.
    :return: generator of rows with point ids/type codes, upper-cased
        route_uid, hours_to_flight, bucketed days_to_flight,
        departure_weekday, numeric class_economy_price and qid.
    """
    record_checker = RecordChecker(good_services)
    cheapest_by_key = {}
    for record in records:
        # NOTE(review): ValueError from strptime on 'timestamp', from float()
        # on the price, or from _get_point_type_code on an unknown point type
        # is not caught here and would fail the mapper -- confirm intended.
        try:
            if not record_checker.check_record(record):
                continue

            date_forward = datetime.strptime(record['date_forward'], DATE_FORMAT)
            timestamp = datetime.strptime(record['timestamp'], TIMESTAMP_FORMAT)
            hours_to_flight = to_hours(date_forward - timestamp)

            new_record = {
                'point_from_id': record['object_from_id'],
                'point_to_id': record['object_to_id'],
                'point_from_type': _get_point_type_code(record['object_from_type']),
                'point_to_type': _get_point_type_code(record['object_to_type']),
                'route_uid': record['route_uid'].upper(),
                'hours_to_flight': hours_to_flight,
                'days_to_flight': get_bin_for_days(hours_to_flight // 24),
                'departure_weekday': date_forward.weekday(),
                'class_economy_price': float(record['class_economy_price'].split()[0]),
                'qid': record['qid'],
            }

            # Keyed by the raw route_uid; the reducer (take_min_by_price)
            # regroups by the upper-cased value emitted above.
            key = record['qid'], record['route_uid']
        except KeyError:
            # Best-effort: rows lacking a required field are dropped.
            continue

        current = cheapest_by_key.get(key)
        if current is None or current['class_economy_price'] > new_record['class_economy_price']:
            cheapest_by_key[key] = new_record

    for record in cheapest_by_key.values():
        yield record


def take_min_by_price(_key, records):
    """Reducer: emit the cheapest record of the group without its qid.

    Prices are compared as floats. On ties the first record seen wins, as
    with builtin min(). The chosen record is modified in place.
    """
    def price_of(rec):
        return float(rec['class_economy_price'])

    cheapest = min(records, key=price_of)
    # qid was only needed to group rows up to this point.
    del cheapest['qid']
    yield cheapest


class GEQTaker(object):
    """Map callable that passes through records whose ``field`` >= ``value``.

    Both sides are converted with int() before comparing. Rows read from YT
    may carry numbers as strings (e.g. DSV). A str/int comparison is always
    true in Python 2 and a TypeError in Python 3, so converting keeps the
    numeric meaning. Field values are assumed to be integer-like.
    """

    def __init__(self, field, value):
        self.field = field  # column to compare
        self.value = value  # threshold; anything int() accepts

    def __call__(self, record):
        if int(record[self.field]) >= int(self.value):
            yield record


class SetFilter(object):
    """Map callable that passes through records whose ``field`` is in ``target_set``."""

    def __init__(self, field, target_set):
        self.field = field  # column to look up
        self.target_set = target_set  # allowed values of that column

    def __call__(self, record):
        value = record[self.field]
        if value not in self.target_set:
            return
        yield record


def take_popular(ytc, yt_shows_table, source_table, destination_table):
    """Copy to destination_table the rows of source_table on popular routes.

    A route is popular when its ``n_shows`` in yt_shows_table is at least
    POPULAR_THRESHOLD. The popular route uids are read (upper-cased) into a
    local set. source_table is then filtered by ``route_uid`` membership in
    that set. Source rows are matched as-is, so their route_uid is expected
    to be upper-case already.
    """
    import yt.wrapper

    # Step 1: keep the shows-table rows with n_shows >= POPULAR_THRESHOLD.
    popular_rows_table = ytc.create_temp_table()
    ytc.run_map(
        binary=GEQTaker('n_shows', POPULAR_THRESHOLD),
        source_table=yt_shows_table,
        destination_table=popular_rows_table,
    )

    # Step 2: pull their route uids into local memory. The set then travels
    # with the SetFilter callable (presumably pickled by yt.wrapper).
    dsv_format = yt.wrapper.DsvFormat()
    popular_route_uids = {
        row['route_uid'].upper()
        for row in ytc.read_table(popular_rows_table, format=dsv_format)
    }

    # Step 3: filter the source table; 1 GiB per job, presumably sized to
    # hold popular_route_uids.
    ytc.run_map(
        binary=SetFilter('route_uid', popular_route_uids),
        source_table=source_table,
        destination_table=destination_table,
        memory_limit=1024 ** 3,
    )


def to_hours(td):
    """Return the whole number of hours in timedelta ``td``, rounded toward -inf."""
    hours, _remainder = divmod(td.total_seconds(), 3600)
    return int(hours)


def take_distinct_prices(_key, records):
    """Reducer/combiner: yield the first record seen for each distinct price.

    Prices are compared after float() conversion, so '10' and 10.0 count as
    the same price.
    """
    seen = set()
    for rec in records:
        price = float(rec['class_economy_price'])
        if price in seen:
            continue
        seen.add(price)
        yield rec


def get_quantiles(prices, quantiles):
    """Compute count-weighted quantiles of a price distribution.

    :param prices: list of (price, count) pairs sorted by price ascending,
        as built by quantiles_reducer.
    :param quantiles: quantile levels. They must be ascending, because a
        single cursor moves forward through ``prices`` across all levels.
    :return: dict mapping 'q<NN>' (NN = int(q * 100)) to the first price at
        which the running count reaches round-half-up(q * total count).
        When that target is 0, the smallest price is used.
    """
    total = sum(int(count) for _, count in prices)

    result = {}
    cursor = 0
    accumulated = 0
    for q in quantiles:
        target = int(q * total + 0.5)
        while accumulated < target:
            accumulated += prices[cursor][1]
            cursor += 1
        # cursor points one past the last consumed entry (or 0 if none).
        result['q{}'.format(int(q * 100))] = prices[max(cursor - 1, 0)][0]

    return result


def get_raw_quantile(lst, quantile):
    """Return the element of ``lst`` nearest to the ``quantile`` position.

    The position is quantile * (len(lst) - 1), rounded half-up via int().
    Element counts/weights are ignored. An empty ``lst`` raises IndexError.
    """
    position = int(quantile * (len(lst) - 1) + 0.5)
    return lst[position]


def quantiles_reducer(key, records):
    """Reducer: compute price quantiles for one storage key group.

    ``records`` are presumably rows produced by count_prices. Each has
    ``class_economy_price`` and ``number`` (how many times that price was
    counted). The emitted row holds the group key, with ids and numeric key
    fields cast to int and ``route_uid`` upper-cased, plus for each q in
    QUANTILES:

    * ``q<NN>``  -- count-weighted quantile (get_quantiles);
    * ``qr<NN>`` -- quantile over the distinct sorted prices, ignoring
      counts (get_raw_quantile).
    """
    prices = [
        (float(record['class_economy_price']), int(record['number']))
        for record in records
    ]
    prices = sorted(prices, key=lambda x: x[0])

    # Cast key fields to the types declared by create_quantile_table.
    answer = {
        'point_from_type': key['point_from_type'],
        'point_from_id': int(key['point_from_id']),
        'point_to_type': key['point_to_type'],
        'point_to_id': int(key['point_to_id']),
        'days_to_flight': int(key['days_to_flight']),
        'route_uid': key['route_uid'].upper(),
        'departure_weekday': int(key['departure_weekday']),
    }

    answer.update(get_quantiles(prices, QUANTILES))

    try:
        raw_quantiles = {
            'qr{}'.format(int(q * 100)): get_raw_quantile(prices, q)[0]
            for q in QUANTILES
        }
    except IndexError:
        # get_raw_quantile can only fail by indexing outside ``prices``
        # (e.g. an out-of-range quantile level). The qr* columns are then
        # omitted instead of failing the whole reduce job.
        pass
    else:
        answer.update(raw_quantiles)

    yield answer


def create_quantile_table(ytc, table_path):
    """Create the quantile output table with an explicit schema.

    Columns, in order:
    1. the grouping key columns;
    2. one double column 'q<NN>' per QUANTILES entry;
    3. one double column 'qr<NN>' per QUANTILES entry (NN = int(q * 100)).

    These match the keys of the rows built by quantiles_reducer. The table
    is created with optimize_for=scan.
    """
    log.info('Create quantile table %s', table_path)

    key_columns = (
        ('point_from_type', 'string'),
        ('point_from_id', 'int64'),
        ('point_to_type', 'string'),
        ('point_to_id', 'int64'),
        ('route_uid', 'string'),
        ('days_to_flight', 'int64'),
        ('departure_weekday', 'int64'),
    )
    schema = [{'name': name, 'type': col_type} for name, col_type in key_columns]

    # Weighted quantile columns first, then the raw (unweighted) ones.
    for template in ('q{}', 'qr{}'):
        for q in QUANTILES:
            schema.append({'name': template.format(int(q * 100)), 'type': 'double'})

    ytc.create(
        'table',
        table_path,
        attributes={'schema': schema, 'optimize_for': 'scan'},
    )
    log.info('Created quantile table %s', table_path)


def update_quantiles(ytc, source_table, destination_table):
    """Recompute the quantile table from a counted price table.

    destination_table is created with the quantile schema if it is missing.
    It then receives one quantiles_reducer row per
    (points, weekday, days_to_flight, route_uid) group of source_table.
    """
    group_columns = [
        'point_from_type', 'point_from_id',
        'point_to_type', 'point_to_id',
        'departure_weekday', 'days_to_flight', 'route_uid',
    ]

    if not ytc.exists(destination_table):
        create_quantile_table(ytc, destination_table)

    ytc.run_map_reduce(
        mapper=None,
        reducer=quantiles_reducer,
        source_table=source_table,
        destination_table=destination_table,
        reduce_by=group_columns,
    )


def update_price_storage(ytc, start_date, end_date, good_services, storage_table):
    """Fold tariff logs for a date range into the counted price storage.

    The tariff-log tables are selected by yt_helpers.tables_for_daterange for
    start_date..end_date. Pipeline:

    1. filter_records (as an aggregating mapper) + take_min_by_price:
       cheapest valid price per (route_uid, qid).
    2. take_distinct_prices: drop repeated prices within a group.
    3. count_prices: count occurrences of each price per
       (points, weekday, days_to_flight, route_uid) key, merged with the
       existing storage_table when it exists, and write the result to
       storage_table.
    """
    import yt.wrapper

    group_columns = [
        'point_from_type', 'point_from_id',
        'point_to_type', 'point_to_id',
        'departure_weekday', 'days_to_flight', 'route_uid',
    ]

    log_tables = yt_helpers.tables_for_daterange(
        ytc, YT_TARIFFS_LOG_DIR, start_date, end_date
    )

    # 1. Filter raw logs and keep the cheapest offer per (route_uid, qid).
    cheapest_rows = ytc.create_temp_table()
    filter_mapper = yt.wrapper.aggregator(
        partial(filter_records, good_services=good_services)
    )
    ytc.run_map_reduce(
        mapper=filter_mapper,
        reducer=take_min_by_price,
        source_table=log_tables,
        destination_table=cheapest_rows,
        reduce_by=['route_uid', 'qid'],
        mapper_memory_limit=10 * 1024 * 1024 * 1024,
    )

    # 2. Keep one row per distinct price.
    # NOTE(review): filter_records does not emit date_forward/date_backward,
    # so these columns are presumably absent (null) here and grouping is
    # effectively by route_uid + hours_to_flight -- confirm this is intended.
    distinct_rows = ytc.create_temp_table()
    ytc.run_map_reduce(
        mapper=None,
        reducer=take_distinct_prices,
        source_table=cheapest_rows,
        destination_table=distinct_rows,
        reduce_by=['route_uid', 'date_forward', 'date_backward', 'hours_to_flight'],
        reduce_combiner=take_distinct_prices,
    )

    # 3. Count prices, merging rows with the same departure_weekday and
    #    days_to_flight; include the existing storage if there is one.
    count_inputs = [distinct_rows]
    if ytc.exists(storage_table):
        count_inputs.append(storage_table)

    ytc.run_map_reduce(
        mapper=None,
        reducer=count_prices,
        reduce_combiner=count_prices,
        source_table=count_inputs,
        destination_table=storage_table,
        reduce_by=group_columns,
    )


def count_prices(key, records):
    """Reducer/combiner: count how many times each price occurs in a group.

    Each record adds its ``number`` (default 1 when absent, i.e. for rows
    that were never counted) to the tally of its float price. One row is
    yielded per distinct price: the group key, with point ids cast to int,
    plus ``class_economy_price`` and the summed ``number``.
    """
    base = dict(key)
    for id_field in ('point_from_id', 'point_to_id'):
        base[id_field] = int(base[id_field])

    tally = Counter()
    for rec in records:
        tally[float(rec['class_economy_price'])] += int(rec.get('number', 1))

    for price, count in tally.items():
        out = dict(base)
        out['class_economy_price'] = price
        out['number'] = count
        yield out


def export_yt_table_to_dsv(ytc, input_table, columns, output_filename):
    """Dump the given ``columns`` of a YT table to a local schemaful-DSV file.

    The table is read in raw mode, and its chunks are written to the file
    as-is in binary mode.
    """
    import yt.wrapper

    dsv_format = yt.wrapper.SchemafulDsvFormat(columns=columns)
    stream = ytc.read_table(input_table, raw=True, format=dsv_format)
    with open(output_filename, 'wb') as out:
        out.writelines(stream.chunk_iter())
