# coding: utf-8

from collections import defaultdict

from django.db.models import Q
from xml.etree import ElementTree

from common.models.currency import Price
from common.models.tariffs import AeroexTariff
from common.models_abstract.schedule import ExpressType
from common.views.tariffs import StaticTariffs, AuxTariffs

from travel.rasp.export.export.views.utils import add_subelement


def create_tariff_info(tariff):
    """Wrap the ``tariff`` value of a tariff record into a ``StaticTariffs``."""
    price = Price(tariff.tariff)
    return StaticTariffs(price)


# Belarus has no kopecks, so the tariff is converted to an integer there.
def prepare_tarrif(price, currency):
    """Format ``price`` as a string according to the currency code.

    For ``'BYR'`` the price is passed through ``int()`` (any fractional part
    is dropped); for every other currency it is rendered with exactly two
    decimal places.
    """
    if currency == 'BYR':
        return str(int(price))
    return format(price, '.2f')


def add_suburban_tariffs(segments):
    """Attach tariff data to suburban segments and collect tariffs per station pair.

    Segments whose ``t_type.code`` is ``'suburban'`` are grouped by their
    ``(station_from, station_to)`` pair and tariffs are looked up once per
    pair with ``get_tariffs_by_stations``. For every suburban segment whose
    tariff code is present among the pair's tariffs, two attributes are set
    on the segment:

    * ``tariff_info`` -- built by ``create_tariff_info`` from the tariff that
      matches the segment's code;
    * ``aux_tariffs`` -- an ``AuxTariffs`` built from that same tariff and
      filled with all tariffs of the pair, grouped by ``type.category`` and
      ordered by ``type.order``.

    All other segments are left untouched.

    Returns a dict mapping each ``(station_from, station_to)`` pair that has
    at least one suburban segment to its ``{tariff type code: tariff}`` dict.
    """
    # Suburban trains grouped by (departure station, arrival station) pair.
    suburban_trains = defaultdict(list)
    for segment in segments:
        if segment.t_type.code == 'suburban':
            stations = segment.station_from, segment.station_to
            suburban_trains[stations].append(segment)

    tariffs_by_stations = {}
    # Every key in the defaultdict was created by an append, so each group
    # holds at least one segment and needs no emptiness check.
    for stations, suburban_segments in suburban_trains.items():
        tariffs = get_tariffs_by_stations(*stations)
        tariffs_by_stations[stations] = tariffs

        # The ordering depends only on the station pair, so sort once here
        # instead of once per segment.
        ordered_tariffs = sorted(tariffs.values(), key=lambda t: t.type.order)

        for segment in suburban_segments:
            thread = segment.thread
            tariff_type = thread.tariff_type

            # An explicit tariff type on the thread wins; otherwise the code
            # is derived from the thread's express type.
            if tariff_type:
                code = tariff_type.code
            elif thread.express_type in (ExpressType.AEROEXPRESS, ExpressType.EXPRESS):
                code = 'express'
            else:
                code = 'etrain'

            if code not in tariffs:
                continue

            thread_code_tariff = tariffs[code]
            segment.tariff_info = create_tariff_info(thread_code_tariff)

            aux_tariffs = AuxTariffs(thread_code_tariff)
            segment.aux_tariffs = aux_tariffs
            for tariff in ordered_tariffs:
                aux_tariffs.setdefault(tariff.type.category, []).append(tariff)

    return tariffs_by_stations


def get_tariffs_by_stations(station_from, station_to):
    """Return the tariffs between two stations as ``{tariff type code: tariff}``.

    Matches ``AeroexTariff`` rows stored for the requested direction, plus
    rows stored for the opposite direction that have ``reverse=True``.
    """
    same_direction = Q(station_from=station_from, station_to=station_to)
    opposite_direction = Q(station_to=station_from, station_from=station_to, reverse=True)

    # Manual tariffs take priority over precalculated ones.
    # The tariff with the largest id is the correct one.
    # With this ordering, later rows overwrite earlier ones that share a
    # type code.
    ordered_tariffs = AeroexTariff.objects.filter(
        same_direction | opposite_direction
    ).order_by('-precalc', 'id')

    tariffs_by_code = {}
    for tariff in ordered_tariffs:
        tariffs_by_code[tariff.type.code] = tariff
    return tariffs_by_code


def build_tariffs_by_stations(tariffs_by_stations):
    """Render per-station-pair tariffs as a ``tariffs_by_stations`` XML element.

    ``tariffs_by_stations`` maps ``(station_from, station_to)`` pairs to dicts
    of tariffs. Each pair becomes a ``tariffs`` child carrying the stations'
    ``'esr'`` codes in its ``station_from`` / ``station_to`` attributes. Each
    tariff whose tariff info has exactly one place becomes a nested ``tariff``
    element with ``currency``, ``tariff`` and ``type`` attributes; tariffs
    with any other number of places are skipped.

    Returns the root ``ElementTree.Element``.
    """
    root = ElementTree.Element('tariffs_by_stations')

    for stations, tariffs_by_type in tariffs_by_stations.items():
        station_from, station_to = stations
        group_el = add_subelement(root, 'tariffs')
        group_el.attrib['station_from'] = station_from.get_code('esr')
        group_el.attrib['station_to'] = station_to.get_code('esr')

        for tariff in tariffs_by_type.values():
            places = create_tariff_info(tariff).places
            if len(places) != 1:
                continue

            price = places[0].tariff
            tariff_el = add_subelement(group_el, 'tariff')

            # Prefer the currency stored on the tariff itself and fall back
            # to the currency of the price object.
            currency = tariff.currency or price.currency
            attrs = tariff_el.attrib
            attrs['currency'] = currency
            attrs['tariff'] = prepare_tarrif(tariff.tariff, currency)
            attrs['type'] = tariff.type.code

    return root


def set_segment_tariff(segment_element, segment):
    """Copy a segment's base tariff onto its XML element as attributes.

    Reads the optional ``tariff_info`` and ``aux_tariffs`` attributes of
    ``segment``. When both are present and ``tariff_info`` has exactly one
    place, sets on ``segment_element.attrib``:

    * ``currency`` -- the base tariff's currency, falling back to the
      currency of the single place's tariff;
    * ``tariff`` -- the base tariff's value formatted by ``prepare_tarrif``.

    In every other case the element is left unchanged.
    """
    tariff_info = getattr(segment, 'tariff_info', None)
    aux_tariffs = getattr(segment, 'aux_tariffs', None)

    # Both attributes are optional. The values below come from
    # ``aux_tariffs.base``, so a segment that has ``tariff_info`` but no
    # ``aux_tariffs`` is skipped instead of failing with AttributeError.
    if not tariff_info or aux_tariffs is None:
        return
    if len(tariff_info.places) != 1:
        return

    tariff = tariff_info.places[0].tariff
    base = aux_tariffs.base
    currency = base.currency or tariff.currency

    segment_element.attrib['currency'] = currency
    segment_element.attrib['tariff'] = prepare_tarrif(base.tariff, currency)
