# -*- coding: utf-8

import logging
from datetime import timedelta

from django.db.models import Q

from common.models.currency import Price
from common.models.schedule import Supplier, RThread
from common.models.tariffs import ThreadTariff
from common.models.transport import TransportType
from common.models_utils import fetch_related
from travel.rasp.library.python.common23.date import environment
from common.utils.date import RunMask
from travel.rasp.morda.morda.order.views.partners.main import PARTNER_MODULES


# Module-level logger registered under the explicit name 'rasp.tariffs.bus'
# (rather than __name__).
log = logging.getLogger('rasp.tariffs.bus')


def can_buy_from(request, supplier_code, segment,
                 order_data=None, is_static_price=True):
    """Decide whether ``segment`` may be sold through the given supplier.

    :param request: current request; ``request.NATIONAL_VERSION`` is read
        when ``is_static_price`` is true.
    :param supplier_code: supplier code (for dynamic prices the partner
        code is passed here instead).
    :param segment: segment whose ``departure`` is compared against
        ``environment.now_aware()`` (presumably an aware datetime — confirm).
    :param order_data: passed through to a partner module's own check.
    :param is_static_price: when true, the ``Supplier`` row is loaded and
        its sale flags are checked first.
    :returns: ``False`` when sale is not possible; otherwise the outcome of
        the matching rule (a bool, or whatever the partner module returns).
    """
    departure = segment.departure
    current_time = environment.now_aware()
    supplier = None

    if is_static_price:
        try:
            supplier = Supplier.objects.get(code=supplier_code)
        except (Supplier.DoesNotExist, Supplier.MultipleObjectsReturned):
            return False
        if not supplier.is_sale_enabled(request.NATIONAL_VERSION):
            return False

    # RASP-9937: an already departed segment is never for sale.
    if departure < current_time:
        return False

    gap = departure - current_time

    # Hard-coded sale windows for specific supplier/partner codes.
    # The lambdas defer evaluation until the matching code is known.
    fixed_windows = {
        # RASP-11127
        'buscomua': lambda: gap >= timedelta(hours=1),
        'swdfactory': lambda: gap > timedelta(minutes=5),
        # For dynamic prices the partner code arrives here.
        'ukrmintrans_bus': lambda: timedelta(hours=1) < gap < timedelta(days=45),
    }
    window_rule = fixed_windows.get(supplier_code)
    if window_rule is not None:
        return window_rule()

    # Partners with their own module decide for themselves.
    partner_module = PARTNER_MODULES.get(supplier_code)
    if partner_module is not None:
        return partner_module.can_buy_from(request, supplier_code, {
            'now': current_time,
            'segment': segment,
            'order_data': order_data,
            'is_static_price': is_static_price
        })

    # Generic suppliers: a sale URL is required (and a dynamic-price call
    # without a loaded supplier cannot be sold here).
    if supplier is None or not supplier.sale_url_template:
        return False

    start_days = supplier.sale_start_days
    stop_hours = supplier.sale_stop_hours

    # A falsy limit means "no limit" on that side of the window.
    return (
        (not start_days or gap < timedelta(days=start_days))
        and (not stop_hours or gap > timedelta(hours=stop_hours))
    )


def add_tariffs(segments, request=None, currency_rates=None):
    """Attach static trip prices to the given segments.

    Only segments of water, bus or train transport types that carry a
    ``thread`` are considered. For each one, the first stored tariff whose
    ``year_days`` run mask includes ``segment.start_date`` is converted to a
    ``Price``, passed through ``calc_final_price`` and stored via
    ``segment.display_info.set_tariff``. If several masks match, the one
    picked depends on dict iteration order.

    :param segments: iterable of segments; not mutated as a collection,
        but matching segments get their ``display_info`` updated.
    :param request: optional request; when falsy, ``linked`` is that falsy
        value instead of the result of ``can_buy_from``.
    :param currency_rates: forwarded to ``calc_final_price``.
    :returns: None.
    """
    priced_ttype_ids = TransportType.WATER_TTYPE_IDS + [TransportType.BUS_ID, TransportType.TRAIN_ID]
    priced_segments = [
        seg for seg in segments
        if seg.t_type.id in priced_ttype_ids and getattr(seg, 'thread', None)
    ]
    if not priced_segments:
        return

    tariffs_by_route = get_tarriffs_dict(priced_segments)
    # Preload thread suppliers in bulk so the loop below does not query per thread.
    fetch_related([seg.thread for seg in priced_segments], 'supplier', model=RThread)

    for seg in priced_segments:
        route_key = (seg.thread.uid, seg.station_from.id, seg.station_to.id)
        supplier_code = seg.thread.supplier.code

        # First tariff whose run mask covers the segment's start date.
        matching_tariff = next(
            (
                thread_tariff
                for mask_days, thread_tariff in tariffs_by_route.get(route_key, {}).items()
                if RunMask.runs_at(mask_days, seg.start_date)
            ),
            None,
        )
        if not matching_tariff:
            continue

        base_price = Price(matching_tariff.tariff, matching_tariff.currency)
        final_price = calc_final_price(supplier_code, base_price, currency_rates)
        order_data = matching_tariff.get_order_data()

        # Keeps the `request and ...` form: a falsy request is passed through as-is.
        linked = request and can_buy_from(
            request,
            supplier_code,
            seg,
            order_data=order_data,
            is_static_price=True
        )

        seg.display_info.set_tariff(
            final_price,
            linked=linked,
            order_data=order_data,
            is_min_tariff=matching_tariff.is_min_tariff
        )


def get_tarriffs_dict(segments):
    """Load ``ThreadTariff`` rows for the given segments and index them.

    A single OR-ed query is built over the distinct
    ``(thread uid, station_from, station_to)`` routes of ``segments``.

    :param segments: iterable of segments; each needs ``thread.uid`` and
        ``station_from`` / ``station_to`` objects with an ``id``.
    :returns: dict mapping ``(thread_uid, station_from_id, station_to_id)``
        to ``{year_days: ThreadTariff}``. Empty when ``segments`` is empty.
    """
    # One Q clause per distinct route. Several segments sharing the same
    # thread and stations would otherwise add duplicate clauses to the SQL.
    route_filters = {}
    for segment in segments:
        route_key = (segment.thread.uid, segment.station_from.id, segment.station_to.id)
        if route_key not in route_filters:
            route_filters[route_key] = Q(
                thread_uid=segment.thread.uid,
                station_from=segment.station_from,
                station_to=segment.station_to,
            )

    # An empty Q() imposes no condition at all. Without this guard, an empty
    # input would load the entire ThreadTariff table.
    if not route_filters:
        return {}

    query = Q()
    for route_filter in route_filters.values():
        query |= route_filter

    tariffs_dict = {}
    for thread_tariff in ThreadTariff.objects.filter(query):
        tariff_key = (thread_tariff.thread_uid, thread_tariff.station_from_id, thread_tariff.station_to_id)
        # A later row with the same year_days replaces an earlier one.
        tariffs_dict.setdefault(tariff_key, {})[thread_tariff.year_days] = thread_tariff

    return tariffs_dict


def calc_final_price(supplier_code, tariff, currency_rates):
    """Return ``tariff`` adjusted by the supplier's partner module, if any.

    If ``PARTNER_MODULES`` has no entry for ``supplier_code``, or the entry
    lacks a ``calc_final_price`` attribute, ``tariff`` is returned unchanged.
    Otherwise the module's ``calc_final_price(tariff, currency_rates)``
    result is returned.
    """
    partner_module = PARTNER_MODULES.get(supplier_code)
    if hasattr(partner_module, 'calc_final_price'):
        return partner_module.calc_final_price(tariff, currency_rates)
    return tariff
