# coding: utf8

from collections import defaultdict

from django.db.models import Q

from common.models.currency import Price
from common.models.tariffs import AeroexTariff
from common.models.geo import Station
from common.models.transport import TransportType
from common.models_abstract.schedule import ExpressType


class SegmentsTariffsProcessor(object):
    def __init__(self, segments):
        self._tariffs_by_ids = {}
        self._segments_by_stations = defaultdict(list)
        for segment in segments:
            self._add_segment(segment)

    def _add_segment(self, segment):
        if segment.t_type.code == 'suburban':
            stations = segment.station_from, segment.station_to
            self._segments_by_stations[stations].append(segment)

    def process_tariffs(self):
        """
        Формирование списка тарифов для списка сегментов
        """
        # Обработка идет отдельно для каждой пары станций
        for stations, segments in self._segments_by_stations.items():
            tariffs_by_code = {}
            replaced_tariffs_by_code = {}
            used_group_ids = set()  # Используемые группы тарифов

            # Получение тарифов из базы
            tariffs = get_tariffs_by_stations(*stations)
            for tariff in tariffs.values():
                # Добавляем замещающий тип тарифа (replace_tariff_type), если есть
                if tariff.replace_tariff_type:
                    segment_tariff = SegmentTariff(tariff, tariff.replace_tariff_type)
                    tariffs_by_code[segment_tariff.type.code] = segment_tariff

                # Добавляем тип тарифа (type), если был замещен, то добавляем в replaced_tariffs_by_code
                segment_tariff = SegmentTariff(tariff, tariff.type)
                tariffs_dict = replaced_tariffs_by_code if tariff.replace_tariff_type else tariffs_by_code
                tariffs_dict[segment_tariff.type.code] = segment_tariff

            # Формирование базовых тарифов для сегментов
            for segment in segments:
                if segment.thread.tariff_type:
                    # базовый тариф указан в нитке
                    type_code = segment.thread.tariff_type.code
                else:
                    type_code = 'express' if segment.thread.express_type in [ExpressType.AEROEXPRESS, ExpressType.EXPRESS] else 'etrain'

                for tariffs_by_code_dict in tariffs_by_code, replaced_tariffs_by_code:
                    if type_code in tariffs_by_code_dict:
                        segment.base_tariff = tariffs_by_code_dict[type_code]

                # Формируем список групп базовых тарифов сегментов
                if hasattr(segment, 'base_tariff'):
                    used_group_ids = used_group_ids | segment.base_tariff.group_ids

            # Добавление в итоговый список тарифов
            for segment_tariff in sorted(tariffs_by_code.values(), key=lambda t: t.type.order):
                if segment_tariff.has_one_of_groups(used_group_ids):
                    self._tariffs_by_ids[segment_tariff.tariff.id] = segment_tariff

            # Формирование списков индексов тарифов для сегментов
            for segment in segments:
                segment.tariffs_ids = []
                if hasattr(segment, 'base_tariff'):
                    for segment_tariff in tariffs_by_code.values():
                        if segment.base_tariff.has_one_of_groups(segment_tariff.group_ids):
                            segment.tariffs_ids.append(segment_tariff.tariff.id)
                segment.tariffs_ids.sort()

    def get_tariffs_data(self):
        """
        :return: Список тарифов в формате JSON
        """
        return [tariff.get_data() for tariff in self._tariffs_by_ids.values()]


class SegmentTariff(object):
    """
    Тариф. Задается тарифом из таблицы aeroextariff и типом тарифа из tarifftype
    """
    def __init__(self, tariff, tariff_type):
        self.tariff = tariff
        self.type = tariff_type
        self.group_ids = {group.id for group in self.type.tariff_groups.all()}

    def has_one_of_groups(self, group_ids):
        for group_id in group_ids:
            if group_id in self.group_ids:
                return True
        return False

    def get_data(self):
        """
        Выдача для JSON
        """
        return {
            'price': self.get_price_data(),
            'id': self.tariff.id,
            'title': self.type.L_title(),
            'code': self.type.code,
            'description': self.type.L_description(),
            'order': self.type.order,
            'url': self.type.link,
            'category': self.type.category,
            'is_main': self.type.is_main
        }

    def get_price_data(self):
        """
        Выдача цены для JSON
        """
        price = Price(self.tariff.tariff, self.tariff.currency)
        return {'currency': price.currency,
                'value': round(self.tariff.tariff, 2)}


def get_segment_tariff(segment):
    """
    Формирование базового тарифа сегмента
    """
    if hasattr(segment, 'base_tariff'):
        return segment.base_tariff.get_price_data()
    return {}


def get_tariffs_by_stations(station_from, station_to):
    # Ручные тарифы имеют приоритет над предрасчитанными.
    # Тариф с наибольшим id верный.
    tariffs_qs = AeroexTariff.objects.filter(
        Q(station_from=station_from, station_to=station_to) |
        Q(station_to=station_from, station_from=station_to, reverse=True)
    ).order_by('-precalc', 'id')
    return {tariff.type.code: tariff for tariff in tariffs_qs}


def get_suburban_tariffs(segments):
    """
    https://st.yandex-team.ru/RASPFRONT-7275
    Тарифы для электричек
    """
    processor = SegmentsTariffsProcessor(segments)
    processor.process_tariffs()
    return processor.get_tariffs_data()


class ThreadTariffSegment(object):
    # Сегмент для тарифов формы нитки
    def __init__(self, thread, station_from_id, station_to_id):
        self.thread = thread
        self.station_from = Station.objects.get(id=station_from_id)
        self.station_to = Station.objects.get(id=station_to_id)
        self.t_type = TransportType.objects.get(id=TransportType.SUBURBAN_ID)


def get_thread_suburban_tariffs(thread, station_from_id, station_to_id):
    """
    https://st.yandex-team.ru/RASPFRONT-7275
    Полный список тарифов электричек для формы нитки
    :return: Возвращает список тарифов и цену базового тарифа
    """
    segment = ThreadTariffSegment(thread, station_from_id, station_to_id)
    return get_suburban_tariffs([segment]), get_segment_tariff(segment)
