# coding: utf-8
from __future__ import unicode_literals, absolute_import, division, print_function

from collections import defaultdict

from common.models.transport import TransportType
from common.models.geo import Settlement, Station, StationMajority, SettlementRelatedStations, Station2Settlement
from common.utils.settlement import get_main_stations
from travel.rasp.library.python.sitemap.models.common import CommonSitemap


class SettlementTransportSitemap(CommonSitemap):
    """Sitemap of per-settlement transport pages, ``/<t_type_code>/<settlement_slug>``.

    A (transport type, settlement) pair is listed only when
    ``get_main_stations(settlement, t_type.id)`` returns more than one station.
    """

    def _count_stations(self):
        """Count candidate stations per transport type and settlement.

        Returns a nested mapping ``{t_type_id: {settlement_id: count}}`` built from
        three sources:

        * the station's own ``settlement``;
        * ``Station2Settlement`` links;
        * ``SettlementRelatedStations`` links.

        Only stations with ``hidden=False`` and
        ``majority_id <= StationMajority.IN_TABLO_ID`` are counted. The same
        filter is applied in SQL for all three sources, so stations with a NULL
        majority are excluded everywhere.

        A station may be counted more than once for one settlement (e.g. via
        its own settlement and via a link). The count is only used as a cheap
        prefilter before the slower ``get_main_stations`` call.
        """
        counts = defaultdict(lambda: defaultdict(int))
        in_tablo = StationMajority.IN_TABLO_ID

        # values_list() fetches only the two ids needed. The station filter runs
        # in the database (a single JOIN for the link tables) instead of loading
        # every linked Station through prefetch_related.
        sources = (
            Station.objects
            .filter(hidden=False, majority__id__lte=in_tablo)
            .values_list('t_type', 'settlement'),
            Station2Settlement.objects
            .filter(station__hidden=False, station__majority__id__lte=in_tablo)
            .values_list('station__t_type', 'settlement'),
            SettlementRelatedStations.objects
            .filter(station__hidden=False, station__majority__id__lte=in_tablo)
            .values_list('station__t_type', 'settlement'),
        )
        for queryset in sources:
            for t_type_id, settlement_id in queryset:
                counts[t_type_id][settlement_id] += 1
        return counts

    def items(self):
        """Return the sitemap items as ``(t_type_code, settlement_slug)`` tuples.

        For performance, station counts are first gathered with a few bulk
        queries. The slower ``get_main_stations`` logic then runs only for
        settlements that have more than one counted station of the relevant
        transport type.
        """
        stations_counts = self._count_stations()

        t_type_ids = [TransportType.PLANE_ID, TransportType.TRAIN_ID, TransportType.SUBURBAN_ID, TransportType.BUS_ID]
        t_types = list(TransportType.objects.filter(id__in=t_type_ids))

        # Hoist the per-type count lookup out of the settlement loop. Suburban
        # pages are prefiltered with TRAIN_ID station counts (presumably
        # suburban services use train-type stations — confirm). .get() avoids
        # inserting empty entries into the nested defaultdicts.
        type_counts = [
            (
                t_type,
                stations_counts.get(
                    TransportType.TRAIN_ID if t_type.id == TransportType.SUBURBAN_ID else t_type.id,
                    {},
                ),
            )
            for t_type in t_types
        ]

        result = []
        for settlement in Settlement.objects.all():
            for t_type, counts_by_settlement in type_counts:
                if counts_by_settlement.get(settlement.id, 0) <= 1:
                    continue
                if len(get_main_stations(settlement, t_type.id)) > 1:
                    result.append((t_type.code, settlement.slug))
        return result

    def location(self, item):
        """Build the page path ``/<t_type_code>/<settlement_slug>`` for an item from ``items()``."""
        t_type_code, settlement_slug = item
        return '/{}/{}'.format(t_type_code, settlement_slug)
