# coding: utf-8

from __future__ import unicode_literals

from collections import defaultdict
from datetime import timedelta

from common.dynamic_settings.core import DynamicSetting
from common.dynamic_settings.default import conf
from common.models.currency import Price, Currency
from common.models.tariffs import ThreadTariff
from travel.rasp.library.python.common23.date import environment
from common.utils.currency_converter import NATIONAL_CURRENCY_RATES_GEOID, CURRENCY_RATES_GEOID
from common.utils.date import MSK_TZ, FuzzyDateTime
from common.utils.locations import set_lang_param
from route_search.helpers import LimitConditions
from travel.rasp.morda_backend.morda_backend.tariffs.static.partners import PARTNER_MODULES


# Dynamic feature flag read by get_static_tariffs: when truthy, a partner module's
# ``add_fee`` hook (if the module defines one) is applied to static tariff prices.
# The first DynamicSetting argument (False) is presumably the default value and
# cache_time=60 presumably seconds -- confirm against common.dynamic_settings.core.
# Description (Russian): "Enables adding a fee to partner tariffs in the
# tariffs/static-tariffs handler".
# NOTE(review): "комисси" in the description looks like a typo for "комиссии".
conf.register_settings(
    MORDA_BACKEND_ADD_TARIFF_FEE=DynamicSetting(
        False,
        cache_time=60,
        description='Включает добавление комисси к тарифам партнеров в ручке tariffs/static-tariffs')
)


def get_static_tariffs(query):
    """
    Yield one static tariff item per (matching tariff, requested date) pair that has a departure.

    If ``query.dates`` is empty, today's date in ``query.point_from``'s time zone is used.

    Each yielded item is a dict with:
      * ``key`` -- built by ``make_static_tariff_key`` from the tariff and its local departure;
      * ``classes`` -- ``{tariff.t_type.code: class_tariff}``, where ``class_tariff`` holds
        ``price`` and ``several_prices`` and, when ``can_buy_from`` allows it, ``order_url``
        plus (if the partner module provides ``get_order_request``) ``partner_order_request``;
      * ``supplier`` -- the tariff's supplier.

    :type query: morda_backend.tariffs.static.serialization.StaticTariffQuery
    """
    # Snapshot the clock and the dynamic fee flag once per call, so every item in one
    # response is checked against the same moment and priced under the same fee policy
    # (the setting is cached with a TTL and could otherwise change mid-iteration).
    now = environment.now_aware()
    add_fee_enabled = conf.MORDA_BACKEND_ADD_TARIFF_FEE

    if not query.dates:
        # Default to "today" as seen in the departure point's own time zone.
        query = query._replace(dates=[now.astimezone(query.point_from.pytz).date()])

    limit_conditions = LimitConditions(
        query.point_from,
        query.point_to,
        query.transport_types
    )
    tariffs_qs = ThreadTariff.objects.all()
    if query.thread_uid:
        tariffs_qs = tariffs_qs.filter(thread_uid=query.thread_uid)
    tariffs_qs = limit_conditions.filter_query_set(tariffs_qs, index_prefix='')
    tariffs = list(tariffs_qs.prefetch_related('settlement_from', 'settlement_to',
                                               'station_from', 'station_to', 'supplier'))

    for tariff in tariffs:
        # The partner module and its optional hooks depend only on the supplier,
        # so resolve them once per tariff rather than once per date.
        static_module = PARTNER_MODULES.get(tariff.supplier.code)
        add_fee = getattr(static_module, 'add_fee', None) if add_fee_enabled else None
        get_order_request = getattr(static_module, 'get_order_request', None)

        for query_date in query.dates:
            departure = tariff.get_local_departure(query_date)
            if departure is None:
                continue

            amount = add_fee(tariff.tariff) if add_fee else tariff.tariff
            # Build a fresh Price for every yielded item so items never share state.
            class_tariff = {
                'price': Price(amount, tariff.currency),
                'several_prices': tariff.is_min_tariff
            }

            # Fetched per item (not hoisted): the partner hook receives it and might mutate it.
            order_data = tariff.get_order_data()
            if can_buy_from(tariff.supplier, now, departure, query.national_version, order_data):
                if get_order_request:
                    class_tariff['partner_order_request'] = get_order_request(departure, now, order_data)
                class_tariff['order_url'] = get_order_url(tariff, departure,
                                                          query.point_from, query.point_to,
                                                          query.national_version)

            yield {
                'key': make_static_tariff_key(tariff, departure),
                'classes': {tariff.t_type.code: class_tariff},
                'supplier': tariff.supplier,
            }


def make_thread_key(departure, station_from_id, station_to_id, thread_uid):
    """
    Build a static tariff key: ``'static <from_id> <to_id> <thread_uid> <MMDD>'``.

    ``departure`` may be a FuzzyDateTime (its ``.dt`` is used) or a datetime-like value.
    When the resolved value is falsy (e.g. None), the date part is an empty string.
    NOTE: the date part is month+day only, so keys do not include the year.
    """
    if isinstance(departure, FuzzyDateTime):
        departure = departure.dt

    if departure:
        day_part = '{:%m%d}'.format(departure)
    else:
        day_part = ''

    return 'static {} {} {} {}'.format(station_from_id, station_to_id, thread_uid, day_part)


def make_static_tariff_key(tariff, departure):
    """Return the static key for ``tariff`` departing at ``departure`` (see ``make_thread_key``)."""
    key = make_thread_key(
        departure,
        station_from_id=tariff.station_from_id,
        station_to_id=tariff.station_to_id,
        thread_uid=tariff.thread_uid,
    )
    return key


def get_order_url(tariff, departure, point_from, point_to, national_version):
    """
    Build the relative ``/buy/?...`` URL for ordering ``tariff`` at ``departure``.

    Departure and arrival are converted to ``MSK_TZ`` and passed as naive datetimes;
    arrival is departure plus ``tariff.duration`` minutes. The title is
    ``"<from> - <to>"``, using each endpoint station's settlement title, or the
    station's own title when it has no settlement. The query string is produced
    by ``set_lang_param`` for ``national_version``.
    """
    depart_msk = departure.astimezone(MSK_TZ)
    arrive_msk = depart_msk + timedelta(minutes=tariff.duration)

    endpoint_titles = []
    for station in (tariff.station_from, tariff.station_to):
        titled_point = station.settlement or station
        endpoint_titles.append(titled_point.L_title())

    # Key order is kept stable so the generated query string does not change.
    url_params = {
        'station_from': tariff.station_from_id,
        'station_to': tariff.station_to_id,
        'departure': depart_msk.replace(tzinfo=None),
        'arrival': arrive_msk.replace(tzinfo=None),
        'title': '{} - {}'.format(*endpoint_titles),
        'date': departure.date(),
        'number': tariff.number,
        't_type': tariff.t_type.code,
        'thread': tariff.thread_uid,
        'point_from': point_from.point_key,
        'point_to': point_to.point_key,
        'tariff': Price(tariff.tariff, tariff.currency)
    }
    query_string = set_lang_param(url_params, national_version=national_version)
    return '/buy/?' + query_string


def can_buy_from(supplier, now, departure, national_version, order_data=None):
    """
    Decide whether a ticket from ``supplier`` departing at ``departure`` is buyable at ``now``.

    Rules, checked in order:
      1. Sales must be enabled for ``national_version`` and the departure not already past.
      2. If the supplier has a partner module, the decision is delegated to its
         ``can_buy_from``; without ``order_data`` the answer is False.
      3. Otherwise the supplier needs a ``sale_url_template``, and the time left until
         departure must be under ``sale_start_days`` days and over ``sale_stop_hours``
         hours (each limit is ignored when falsy).
    """
    if not supplier.is_sale_enabled(national_version):
        return False
    if departure < now:
        return False

    partner_module = PARTNER_MODULES.get(supplier.code)
    if partner_module is not None:
        if order_data is None:
            return False
        return partner_module.can_buy_from(departure, now, order_data)

    if not supplier.sale_url_template:
        return False

    time_left = departure - now
    start_days = supplier.sale_start_days
    stop_hours = supplier.sale_stop_hours

    sale_opened = not start_days or time_left < timedelta(days=start_days)
    sale_not_closed = not stop_hours or time_left > timedelta(hours=stop_hours)
    return sale_opened and sale_not_closed


def make_static_segment_keys(segment):
    """
    Return a one-element list with the static tariff key for ``segment``.

    We assume ``segment.departure`` is a time in the station's time zone; this is
    needed so that the keys of tariffs without a thread match the segment keys.
    """
    segment_key = make_thread_key(
        departure=segment.departure,
        station_from_id=segment.station_from.id,
        station_to_id=segment.station_to.id,
        thread_uid=segment.thread.uid,
    )
    return [segment_key]


def get_static_min_tariffs(query):
    """
    Return the cheapest static tariff per date-less key and transport-type code.

    Rates are fetched for the national version's geo id (falling back to
    ``CURRENCY_RATES_GEOID``) with base ``'RUB'``. Each tariff's price is rebased
    with them before comparison.

    :returns: list of ``{'key': <min tariff key>, 'classes': {t_type code: {'price': Price}}}``
    """
    rates_geo_id = NATIONAL_CURRENCY_RATES_GEOID.get(query.national_version, CURRENCY_RATES_GEOID)
    ordered_currencies = list(Currency.get_ordered_queryset(query.national_version))
    _unused, rates = Currency.fetch_rates(ordered_currencies, rates_geo_id, 'RUB')

    conditions = LimitConditions(query.point_from, query.point_to, query.transport_types)
    queryset = conditions.filter_query_set(ThreadTariff.objects.all(), index_prefix='')

    cheapest_by_key = defaultdict(dict)
    for tariff in queryset:
        key = make_static_min_tariff_key(tariff)
        t_type_code = tariff.t_type.code
        candidate = tariff.get_price()
        candidate.rebase(rates)

        classes = cheapest_by_key[key]
        if t_type_code in classes:
            known = classes[t_type_code]['price']
            # Same rule as min(candidate, known): keep the old one only if strictly
            # cheaper, so the newcomer wins ties.
            if known < candidate:
                candidate = known
        classes[t_type_code] = {'price': candidate}

    return [{'key': key, 'classes': classes} for key, classes in cheapest_by_key.items()]


def _make_min_static_key(station_from_id, station_to_id, thread_uid):
    """Build the date-less key shared by static min tariffs, segments and threads."""
    return 'static {} {} {}'.format(station_from_id, station_to_id, thread_uid)


def make_static_min_tariff_key(tariff):
    """Return the date-less min-tariff key for ``tariff`` (stations + thread uid)."""
    return _make_min_static_key(tariff.station_from_id, tariff.station_to_id, tariff.thread_uid)


def make_min_static_segment_keys(segment):
    """Return a one-element list with the date-less min-tariff key for ``segment``."""
    return [
        _make_min_static_key(segment.station_from.id, segment.station_to.id, segment.thread.uid)
    ]


def make_min_static_thread_keys(station_from_id, station_to_id, thread):
    """Return a one-element list with the date-less min-tariff key for ``thread`` between two stations."""
    return [
        _make_min_static_key(station_from_id, station_to_id, thread.uid)
    ]
