import logging
from typing import Iterator, cast

import itertools
from pydantic import Field

from travel.avia.ad_feed.ad_feed.airport_blacklist import AirportBlacklist
from travel.avia.ad_feed.ad_feed.feed_generator.abstract import IFeedGenerator
from travel.avia.ad_feed.ad_feed.feed_generator.direction import DirectionFeedRow
from travel.avia.ad_feed.ad_feed.supplier.airport_info import AirportInfo, AirportInfoSupplier

logger = logging.getLogger(__name__)


class StationsAndSettlementsRow(DirectionFeedRow):
    """Direction feed row refined down to a concrete departure/arrival airport pair.

    Inherits every field of ``DirectionFeedRow`` and appends the code and title
    of the departure and arrival airports. Instances are built by
    ``StationsAndSettlementsFeedGenerator.generate_feed``.
    """

    # NOTE(review): `yt_type` is a custom extra passed to pydantic's Field;
    # presumably consumed by the feed's YT table-schema builder — confirm there.
    # Airport codes are filled via `_get_code` (first truthy of IATA, Sirena, ICAO).
    departure_airport_code: str = Field(yt_type='string')
    departure_airport_title: str = Field(yt_type='string')
    arrival_airport_code: str = Field(yt_type='string')
    arrival_airport_title: str = Field(yt_type='string')


def _get_code(info: AirportInfo) -> str:
    """Return the preferred airport code of *info*.

    Codes are tried in priority order — IATA, then Sirena, then ICAO — and the
    first truthy one is returned. Later attributes are read only when the
    earlier ones are falsy. If IATA and Sirena are both falsy, the ICAO value
    is returned as-is, even when it is ``''`` or ``None``.

    NOTE(review): ``cast`` performs no runtime check, so a ``None`` result is
    possible despite the ``str`` annotation when an airport has no codes at all.
    """
    code = info.iata
    if not code:
        code = info.sirena
    if not code:
        code = info.icao
    return cast(str, code)


class StationsAndSettlementsFeedGenerator(IFeedGenerator[StationsAndSettlementsRow]):
    """Expand settlement-to-settlement direction rows into airport-to-airport rows.

    For every row of the wrapped directional generator whose departure and
    arrival settlements are both absent from the blacklist, one
    ``StationsAndSettlementsRow`` is yielded per (departure station, arrival
    station) pair of those settlements. Each such row carries all fields of the
    source row plus the code and title of both airports.
    """

    def __init__(
        self,
        directional_generator: IFeedGenerator[DirectionFeedRow],
        airport_info_supplier: AirportInfoSupplier,
        airport_blacklist: AirportBlacklist,
    ):
        """Store the collaborators used by ``generate_feed``.

        :param directional_generator: source of settlement-level direction rows.
        :param airport_info_supplier: provides the settlement -> station ids
            mapping and per-station airport info (codes and title).
        :param airport_blacklist: rows whose departure or arrival settlement it
            contains are skipped.
        """
        self._directional_generator = directional_generator
        self._airport_info_supplier = airport_info_supplier
        self._airport_blacklist = airport_blacklist

    def generate_feed(self) -> Iterator[StationsAndSettlementsRow]:
        """Yield one airport-level row per station pair of each allowed direction.

        Rows with a blacklisted departure or arrival settlement are logged and
        skipped. Settlement and station lookups are plain subscripts, so an id
        missing from the supplier's data raises unless those containers
        provide defaults. NOTE(review): this behaviour is not visible here, so confirm it.
        """
        settlement_to_station_mapping = self._airport_info_supplier.settlement_to_station_mapping
        airports_info = self._airport_info_supplier.airports_info
        for row in self._directional_generator.generate_feed():
            # Lazy %-args avoid formatting when INFO is disabled; `%r` renders
            # exactly what the `{expr=}` f-string form produced.
            if self._airport_blacklist.contains_settlement(row.departure_settlement_id):
                logger.info(
                    'Skipping row by row.departure_settlement_id=%r',
                    row.departure_settlement_id,
                )
                continue
            if self._airport_blacklist.contains_settlement(row.arrival_settlement_id):
                logger.info(
                    'Skipping row by row.arrival_settlement_id=%r',
                    row.arrival_settlement_id,
                )
                continue

            departure_stations_ids = settlement_to_station_mapping[row.departure_settlement_id]
            arrival_stations_ids = settlement_to_station_mapping[row.arrival_settlement_id]
            # The source row's fields are the same for every station pair, so
            # serialise them once per row rather than once per pair.
            row_fields = row.dict()
            for departure_station_id, arrival_station_id in itertools.product(
                departure_stations_ids, arrival_stations_ids
            ):
                departure_station_info = airports_info[departure_station_id]
                arrival_station_info = airports_info[arrival_station_id]
                yield StationsAndSettlementsRow(
                    **row_fields,
                    departure_airport_code=_get_code(departure_station_info),
                    departure_airport_title=departure_station_info.title,
                    arrival_airport_code=_get_code(arrival_station_info),
                    arrival_airport_title=arrival_station_info.title,
                )
