# coding: utf-8

from collections import defaultdict
from itertools import groupby, chain
from operator import attrgetter

from common.models.geo import StationMajority
from common.models.tariffs import TariffType, AeroexTariff, TariffGroup
from common.models.transport import TransportType
from travel.rasp.morda_backend.morda_backend.tariffs.train.base.utils import make_suburban_express_keys


def get_suburban_tariffs(point_from, point_to):
    """Collect suburban tariffs and tariff groups between two points.

    For every ``(station_from_id, station_to_id)`` direction yielded by
    ``iter_sorted_tariffs_by_direction``:

    * each tariff of category ``TariffType.USUAL_CATEGORY`` becomes one entry of
      ``tariffs``: the tariff itself under ``'classes'``, its key from
      ``make_suburban_tariff_key`` under ``'key'``, and, under
      ``'suburban_categories'``, the direction's tariffs (per category) that share
      at least one group with it;
    * each ``TariffGroup`` referenced by those usual tariffs becomes one entry of
      ``groups`` (ordered by group id within the direction), carrying the
      direction's tariffs of that group under ``'categories'``.

    :return: ``(tariffs, groups)`` -- two lists of dicts.
    """
    tariffs = []
    groups = []

    for direction, direction_tariffs in iter_sorted_tariffs_by_direction(point_from, point_to):
        station_from_id, station_to_id = direction
        by_category = split_tariffs_by_categories(direction_tariffs)
        usual_tariffs = by_category.get(TariffType.USUAL_CATEGORY, ())

        # Single pass: emit one entry per usual tariff and collect their groups.
        usual_group_ids = set()
        for usual_tariff in usual_tariffs:
            tariffs.append({
                'classes': {'suburban': usual_tariff},
                'key': make_suburban_tariff_key(station_from_id, station_to_id, usual_tariff.type_id),
                'suburban_categories': leave_only_tariffs_of_selected_groups(
                    by_category, selected_group_ids=usual_tariff.group_ids
                ),
            })
            usual_group_ids |= usual_tariff.group_ids

        group_rows = TariffGroup.objects.filter(id__in=usual_group_ids).order_by('id')
        groups.extend(
            {
                'station_from_id': station_from_id,
                'station_to_id': station_to_id,
                'id': group.id,
                'title': group.L_title(),
                'categories': leave_only_tariffs_of_selected_groups(by_category, selected_group_ids={group.id}),
            }
            for group in group_rows
        )

    return tariffs, groups


def iter_sorted_tariffs_by_direction(point_from, point_to):
    """Yield ``((station_from_id, station_to_id), tariffs)`` for every station pair with tariffs.

    Forward tariffs (stored from -> to) and tariffs stored to -> from with
    ``reverse=True`` are merged under the same from -> to key. Within a direction,
    only one tariff per ``type_id`` is kept: the one with the smallest ``precalc``
    value, so manual tariffs win over precalculated ones. Each kept tariff gets a
    ``group_ids`` set attribute (ids of its type's tariff groups). The yielded list
    is sorted by ``type.order``, with ties kept in ``type_id`` order.
    """
    from_ids = get_suburban_tariff_station_ids_from_point(point_from)
    to_ids = get_suburban_tariff_station_ids_from_point(point_to)
    forward = list(build_aeroextariff_queryset(from_ids, to_ids, search_backward=False))
    backward = list(build_aeroextariff_queryset(to_ids, from_ids, search_backward=True))

    grouped = defaultdict(list)
    for tariff in forward:
        grouped[tariff.station_from_id, tariff.station_to_id].append(tariff)
    for tariff in backward:
        # Backward tariffs are stored to -> from; flip the key so it reads from -> to.
        grouped[tariff.station_to_id, tariff.station_from_id].append(tariff)

    by_type_id = attrgetter('type_id')
    for direction, candidates in grouped.iteritems():
        # Exclude precalculated tariffs when manual ones exist: min() keeps the first
        # tariff with the lowest ``precalc`` (presumably a bool, so False beats True).
        chosen = []
        for _type_id, same_type in groupby(sorted(candidates, key=by_type_id), key=by_type_id):
            chosen.append(min(same_type, key=attrgetter('precalc')))

        for tariff in chosen:
            tariff.group_ids = set(group.id for group in tariff.type.tariff_groups.all())

        yield direction, sorted(chosen, key=attrgetter('type.order'))


def build_aeroextariff_queryset(station_ids_from, station_ids_to, search_backward):
    """Return an ``AeroexTariff`` queryset between two sets of station ids.

    :param station_ids_from: ids matched against ``station_from``.
    :param station_ids_to: ids matched against ``station_to``.
    :param search_backward: when true, additionally require ``reverse=True``.
    :return: lazy queryset with ``type``, ``type__tariff_groups`` and
        ``replace_tariff_type`` prefetched.
    """
    related = ('type', 'type__tariff_groups', 'replace_tariff_type')
    queryset = AeroexTariff.objects.filter(station_from__in=station_ids_from, station_to__in=station_ids_to)
    if not search_backward:
        return queryset.prefetch_related(*related)
    return queryset.filter(reverse=True).prefetch_related(*related)


def get_suburban_tariff_station_ids_from_point(point):
    """
    Return the ids of stations of ``point`` that may have suburban tariffs.

    A station point yields just its own id. For other points, only visible
    (``hidden=False``) train and plane stations with majority up to
    ``StationMajority.STATION_ID`` are kept, both directly attached
    (``station_set``) and linked via ``station2settlement_set``. Other station
    types are skipped because we do not have suburban tariffs between them yet.
    The result may contain the same id twice if a station appears in both sources.
    TODO: without this optimization tariff lookup becomes very slow, so it must not
    be removed without an adequate replacement. We need a fast way to determine
    which stations have tariffs between them.
    """
    if point.is_station:
        return [point.id]

    allowed_t_type_ids = (TransportType.TRAIN_ID, TransportType.PLANE_ID)
    own_station_ids = point.station_set.filter(
        majority__lte=StationMajority.STATION_ID,
        hidden=False,
        t_type__id__in=allowed_t_type_ids,
    ).values_list('id', flat=True)
    linked_station_ids = point.station2settlement_set.filter(
        station__majority__lte=StationMajority.STATION_ID,
        station__hidden=False,
        station__t_type__id__in=allowed_t_type_ids,
    ).values_list('station', flat=True)

    # Evaluated in the same order as before: own stations first, then linked ones.
    return list(own_station_ids) + list(linked_station_ids)


def split_tariffs_by_categories(tariffs):
    """Group tariffs by ``tariff.type.category``, preserving input order within each group.

    :return: ``defaultdict(list)``; indexing an absent category returns (and inserts)
        an empty list.
    """
    grouped = defaultdict(list)
    for item in tariffs:
        category = item.type.category
        grouped[category].append(item)
    return grouped


def leave_only_tariffs_of_selected_groups(tariffs_by_category, selected_group_ids):
    """Keep, per category, only tariffs that belong to at least one selected group.

    :param tariffs_by_category: mapping category -> iterable of tariffs; each tariff
        must expose a ``group_ids`` set.
    :param selected_group_ids: any iterable of tariff group ids (set, frozenset,
        list, tuple, generator...). Previously only sets worked, because of ``&``.
    :return: a new plain dict with the same category keys; a category with no
        matching tariffs maps to an empty list. Input order is preserved.
    """
    # Normalize once so any iterable (including a one-shot generator) is supported.
    selected = frozenset(selected_group_ids)
    # ``.items()`` works on both Python 2 and 3; isdisjoint avoids a temp set per tariff.
    return {
        category: [t for t in possible_tariffs if not t.group_ids.isdisjoint(selected)]
        for category, possible_tariffs in tariffs_by_category.items()
    }


def make_suburban_tariff_key(station_from_id, station_to_id, tariff_type_id):
    """Return the key ``'suburban <from> <to> <type>'`` that identifies a suburban tariff."""
    template = 'suburban {from_id} {to_id} {type_id}'
    return template.format(from_id=station_from_id, to_id=station_to_id, type_id=tariff_type_id)


def make_suburban_segment_keys(rthread_segment):
    """Build suburban tariff keys for a thread segment.

    Adapter over ``make_suburban_thread_keys``: reads the station ids, thread and
    departure from the segment. ``train_purchase_numbers`` is optional on the
    segment and is passed as ``None`` when absent.
    """
    from_id = rthread_segment.station_from.id
    to_id = rthread_segment.station_to.id
    thread = rthread_segment.thread
    departure = rthread_segment.departure
    purchase_numbers = getattr(rthread_segment, 'train_purchase_numbers', None)
    return make_suburban_thread_keys(from_id, to_id, thread, departure, purchase_numbers)


def make_suburban_thread_keys(station_from_id, station_to_id, thread, departure_dt, train_purchase_numbers):
    """Build the tariff keys a suburban thread should be matched against.

    Threads whose ``t_type_id`` is not ``TransportType.SUBURBAN_ID`` get no keys.
    Otherwise the list starts with the suburban tariff key for the thread's tariff
    type. When ``train_purchase_numbers`` is truthy, the keys from
    ``make_suburban_express_keys(departure_dt, train_purchase_numbers)`` are
    appended after it.

    :return: list of keys (possibly empty).
    """
    if thread.t_type_id != TransportType.SUBURBAN_ID:
        return []

    result = [make_suburban_tariff_key(station_from_id, station_to_id, thread.get_tariff_type_id())]
    if not train_purchase_numbers:
        return result

    result.extend(make_suburban_express_keys(departure_dt, train_purchase_numbers))
    return result
