# -*- coding: utf-8 -*-
from typing import Set

from travel.avia.library.python.django_namedtuples.queryset import ModelInterface

from travel.avia.library.python.common.models.transport import TransportType

from travel.avia.ticket_daemon.ticket_daemon.api.models_utils.country import get_country_by_id
from travel.avia.ticket_daemon.ticket_daemon.api.models_utils.settlement import get_settlement_by_id
from travel.avia.ticket_daemon.ticket_daemon.api.models_utils.station import (
    get_airport_ids_by_settlement_id, get_settlement_ids_by_airport_id
)


# Discriminator values returned by ``get_point_type()``:
# stations report POINT_TYPE_STATION, settlements POINT_TYPE_SETTLEMENT.
POINT_TYPE_STATION, POINT_TYPE_SETTLEMENT = 1, 2


class PointInterface(ModelInterface):
    """Common interface shared by point models (stations and settlements).

    Every method here is abstract and raises ``NotImplementedError``;
    the concrete subclasses provide the actual behaviour.
    """

    def get_point_type(self):
        """Return the ``POINT_TYPE_*`` discriminator for this point."""
        raise NotImplementedError()

    def get_related_country(self):
        """Return the country associated with this point."""
        raise NotImplementedError()

    def get_related_settlement(self):
        # type: () -> SettlementInterface
        """Return the settlement associated with this point."""
        raise NotImplementedError()

    def get_related_settlement_ids(self):
        # type: () -> Set[int]
        """Return the ids of all settlements associated with this point."""
        raise NotImplementedError()

    def get_allowed_airports_ids(self):
        """Return the ids of the airports allowed for this point."""
        raise NotImplementedError()


class SettlementInterface(PointInterface):
    """Point implementation for settlements.

    A settlement is its own related settlement. Its related settlement ids
    are its own id plus the settlement ids linked to each of its allowed
    airports.
    """

    _fields = ('id', 'country_id')

    def get_point_type(self):
        """Return ``POINT_TYPE_SETTLEMENT``."""
        return POINT_TYPE_SETTLEMENT

    def get_related_country(self):
        """Return the settlement's country, or ``None`` if ``country_id`` is empty."""
        if self.country_id:
            return get_country_by_id(self.country_id)
        return None

    def get_related_settlement(self):
        # type: () -> SettlementInterface
        """Return ``self``: a settlement is its own related settlement."""
        return self

    def get_related_settlement_ids(self):
        # type: () -> Set[int]
        """Return this settlement's id plus every settlement id linked to
        one of its allowed airports.

        ``get_settlement_ids_by_airport_id`` may return a falsy value (e.g.
        ``None``); ``StationInterface`` guards against that too. Such results
        are treated as empty instead of letting ``set.update(None)`` raise
        ``TypeError``.
        """
        ids = {self.id}
        for airport_id in self.get_allowed_airports_ids() or ():
            ids.update(get_settlement_ids_by_airport_id(airport_id) or ())
        return ids

    def get_allowed_airports_ids(self):
        """Return the airport ids attached to this settlement."""
        return get_airport_ids_by_settlement_id(self.id)


class StationInterface(PointInterface):
    """Point implementation for stations.

    Only stations whose ``t_type_id`` is ``TransportType.PLANE_ID`` are
    treated as airports by ``get_allowed_airports_ids``.
    """

    _fields = ('id', 'country_id', 'settlement_id', 't_type_id')

    def get_point_type(self):
        """Return ``POINT_TYPE_STATION``."""
        return POINT_TYPE_STATION

    def get_related_country(self):
        """Return the station's own country.

        Falls back to the country of the related settlement. Returns ``None``
        when neither is available.
        """
        if self.country_id:
            return get_country_by_id(self.country_id)
        settlement = self.get_related_settlement()
        if settlement:
            return settlement.get_related_country()
        return None

    def get_related_settlement(self):
        # type: () -> SettlementInterface
        """Return the settlement this station belongs to, or ``None``.

        Uses ``settlement_id`` when set. Otherwise it uses the first id
        yielded by ``get_settlement_ids_by_airport_id``; the iteration order
        of that collection decides which one is picked when there are
        several.
        """
        settlement_id = self.settlement_id or next(
            iter(get_settlement_ids_by_airport_id(self.id) or ()), None
        )
        if settlement_id is None:
            # No candidate id at all: don't ask the lookup for id ``None``.
            return None
        return get_settlement_by_id(settlement_id)

    def get_related_settlement_ids(self):
        # type: () -> Set[int]
        """Return the settlement ids linked to this station (possibly empty)."""
        ids = get_settlement_ids_by_airport_id(self.id)
        # Always build a new set, so mutating the result never touches
        # whatever object the lookup returned.
        return set(ids) if ids else set()

    def get_allowed_airports_ids(self):
        """Return ``{self.id}`` for an airport station.

        Raises:
            AssertionError: if the station is not a plane (airport) station.
                An explicit raise is used instead of ``assert`` so the check
                still runs under ``python -O``. The exception type is
                unchanged for existing callers.
        """
        if self.t_type_id != TransportType.PLANE_ID:
            raise AssertionError("Point is not an airport")
        return {self.id}
