from __future__ import absolute_import

from collections import defaultdict

from typing import Dict, Iterable, List, Set  # noqa

from travel.avia.library.python.common.models.geo import Station2Settlement
from travel.avia.library.python.common.models.transport import TransportType

from travel.avia.backend.repository.station import StationRepository, station_repository, StationModel  # noqa
from travel.avia.backend.repository.settlement import SettlementRepository, settlement_repository, SettlementModel  # noqa


class GeoRelationsRepository(object):
    """In-memory index of settlement <-> airport links.

    Also indexes, per country, the settlements that have at least one airport.
    All indexes are empty until ``pre_cache`` is called. Lookups read the
    dictionaries built by the most recent ``pre_cache`` run.
    """

    def __init__(self, settlement_repository, station_repository):
        # type: (SettlementRepository, StationRepository) -> None
        self._settlement_repository = settlement_repository
        self._station_repository = station_repository

        # settlement id -> set of airport (station) ids
        self._settlement_to_airport = {}
        # airport (station) id -> set of settlement ids
        self._airport_to_settlement = {}
        # country id -> set of ids of settlements that have at least one airport
        self._country_id_to_settlement_ids = {}

    def pre_cache(self):
        # type: () -> None
        """Rebuild all indexes.

        Links are collected from two sources:

        * ``Station2Settlement`` rows, kept only when both the settlement and
          the station (looked up with ``TransportType.PLANE_ID``) are known to
          the repositories;
        * the ``settlement_id`` of every plane station, when it is set.

        The new indexes are built in local variables and assigned to the
        instance only at the end.
        """
        settlement_to_airport = defaultdict(set)
        airport_to_settlement = defaultdict(set)

        for raw in Station2Settlement.objects.values('station_id', 'settlement_id'):
            airport_id = raw['station_id']
            settlement_id = raw['settlement_id']

            if self._settlement_repository.get(settlement_id) is None:
                continue

            if self._station_repository.get(airport_id, transport_type=TransportType.PLANE_ID) is None:
                continue

            settlement_to_airport[settlement_id].add(airport_id)
            airport_to_settlement[airport_id].add(settlement_id)

        # These ids are not checked against the settlement repository, so the
        # country grouping below must tolerate unknown settlements.
        for airport in self._station_repository.get_all(transport_type=TransportType.PLANE_ID):
            if airport.settlement_id is not None:
                settlement_to_airport[airport.settlement_id].add(airport.pk)
                airport_to_settlement[airport.pk].add(airport.settlement_id)

        country_id_to_settlement_ids = self._group_settlements_by_country(settlement_to_airport)

        self._settlement_to_airport = dict(settlement_to_airport)
        self._airport_to_settlement = dict(airport_to_settlement)
        self._country_id_to_settlement_ids = country_id_to_settlement_ids

    def _group_settlements_by_country(self, settlement_ids):
        # type: (Iterable[int]) -> Dict[int, Set[int]]
        """Map country id -> ids of the given settlements belonging to it.

        Settlements unknown to the settlement repository, and settlements whose
        ``country_id`` is None, are skipped.
        """
        by_country = defaultdict(set)
        for settlement_id in settlement_ids:
            settlement = self._settlement_repository.get(settlement_id)
            # Skip unknown settlements (would dereference None) and
            # country-less ones (would create a ``None`` country key).
            if settlement is None or settlement.country_id is None:
                continue
            by_country[settlement.country_id].add(settlement.pk)
        return dict(by_country)

    def get_airport_ids_for(self, settlement_id):
        # type: (int) -> Set[int]
        """Return ids of airports linked to the settlement, or an empty set.

        When the settlement is indexed, the cached set itself is returned, so
        callers must not mutate it.
        """
        return self._settlement_to_airport.get(settlement_id, set())

    def get_airports_for(self, settlement_id):
        # type: (int) -> List[StationModel]
        """Return station models of the airports linked to the settlement."""
        return [
            self._station_repository.get(airport_id, transport_type=TransportType.PLANE_ID)
            for airport_id in self.get_airport_ids_for(settlement_id)
        ]

    def get_settlement_ids_for(self, airport_id):
        # type: (int) -> Set[int]
        """Return ids of settlements linked to the airport, or an empty set.

        When the airport is indexed, the cached set itself is returned, so
        callers must not mutate it.
        """
        return self._airport_to_settlement.get(airport_id, set())

    def get_settlements_ids_with_airport_from(self, country_id):
        # type: (int) -> Set[int]
        """Return ids of settlements in the country that have an airport.

        Returns an empty set when none are indexed. When the country is
        indexed, the cached set itself is returned, so callers must not
        mutate it.
        """
        return self._country_id_to_settlement_ids.get(country_id, set())


# Module-level instance wired to the imported default settlement and station
# repositories. Its indexes stay empty until ``pre_cache()`` is called on it.
geo_relations_repository = GeoRelationsRepository(
    station_repository=station_repository,
    settlement_repository=settlement_repository,
)
