import logging
from collections import defaultdict
from typing import Optional, Iterator

import itertools
from pydantic import BaseModel, Field, ValidationError

from travel.avia.ad_feed.ad_feed.airport_blacklist import AirportBlacklist
from travel.avia.ad_feed.ad_feed.click_price import ClickPriceCounter
from travel.avia.ad_feed.ad_feed.direction_flights import FlightsCounter
from travel.avia.ad_feed.ad_feed.direction_type import DirectionTypeResolver
from travel.avia.ad_feed.ad_feed.entities import Direction
from travel.avia.ad_feed.ad_feed.entities import SettlementId, StationId
from travel.avia.ad_feed.ad_feed.environment import Environment
from travel.avia.ad_feed.ad_feed.feed_generator.abstract import IFeedGenerator
from travel.avia.ad_feed.ad_feed.feed_generator.const import DEFAULT_IMAGES
from travel.avia.ad_feed.ad_feed.min_price import MinPriceGetter
from travel.avia.ad_feed.ad_feed.supplier import AirportInfoSupplier
from travel.avia.ad_feed.ad_feed.top_directions import TopDirections
from travel.avia.library.python.lib_yt.cache import SettlementBigImage
from travel.avia.library.python.shared_dicts.cache.settlement_cache import SettlementCache

log = logging.getLogger(__name__)


class DirectionFeedRow(BaseModel):
    """One row of the direction ad feed (departure settlement -> arrival settlement).

    Every field declares a ``yt_type`` extra; presumably it is read elsewhere
    to derive the YT output table schema — TODO confirm. Field order likely
    determines column order, so keep it stable.

    NOTE(review): the ``Optional[...]`` fields have no explicit default. Under
    pydantic v1 such fields default to ``None``; under pydantic v2 they would
    be required — confirm which pydantic version this runs on.
    """

    departure_settlement_id: SettlementId = Field(yt_type='int64')
    # Default (nominative) title of the departure settlement.
    departure_settlement_title: str = Field(yt_type='string')
    # Title used after "from"; DirectionFeedGenerator fills it with the Russian
    # genitive form, falling back to the default title when that is empty.
    departure_settlement_title_from: str = Field(yt_type='string')
    arrival_settlement_id: SettlementId = Field(yt_type='int64')
    # Default (nominative) title of the arrival settlement.
    arrival_settlement_title: str = Field(yt_type='string')
    # Title used after "to"; DirectionFeedGenerator fills it with the Russian
    # accusative form, falling back to the default title when that is empty.
    arrival_settlement_title_to: str = Field(yt_type='string')
    # Dates, price, currency and URLs are copied from the source min-price
    # record, so rows derived for neighbouring settlements share them.
    forward_date: str = Field(yt_type='string')
    backward_date: Optional[str] = Field(yt_type='string')
    price: float = Field(yt_type='double')
    currency: str = Field(yt_type='string')
    search_url: str = Field(yt_type='string')
    search_url_no_date: str = Field(yt_type='string')
    route_url: Optional[str] = Field(yt_type='string')
    # Average click price of the source direction, or None when unknown.
    click_price: Optional[float] = Field(yt_type='double')
    # Flights count of the source direction, or None when unknown.
    flights_count: Optional[int] = Field(yt_type='uint64')
    # String form of the DirectionTypeResolver result for the source direction.
    direction_type: str = Field(yt_type='string')
    # Value taken from the top-directions lookup for the source direction, or None.
    popularity: Optional[int] = Field(yt_type='uint64')
    departure_settlement_image: str = Field(yt_type='string')
    arrival_settlement_image: str = Field(yt_type='string')
    # DirectionFeedGenerator sets this True only for the row whose settlement
    # pair is the source min-price direction itself; rows for neighbouring
    # settlements (sharing an airport) get False.
    has_airport: bool = Field(yt_type='boolean')


# Output table path of the direction feed for each environment. Only TESTING
# and PRODUCTION have entries; other Environment members are not mapped here.
OUTPUT_TABLE_FOR_ENVIRONMENT = {
    Environment.TESTING: '//home/avia/testing/data/ad-feed/ru/feed',
    Environment.PRODUCTION: '//home/avia/data/ad-feed/ru/feed',
}


def _get_airport_to_city_mapping(
    city_to_airports: dict[SettlementId, set[StationId]]
) -> dict[StationId, set[SettlementId]]:
    """Invert a settlement -> airports mapping into airport -> settlements.

    An airport listed under several settlements maps back to all of them.
    The returned object is a ``defaultdict(set)``: looking up an airport that
    is not present yields (and stores) an empty set instead of raising.
    """
    # Flatten the nested mapping into (airport, settlement) edges first.
    edges = (
        (station_id, settlement_id)
        for settlement_id, station_ids in city_to_airports.items()
        for station_id in station_ids
    )
    inverted = defaultdict(set)
    for station_id, settlement_id in edges:
        inverted[station_id].add(settlement_id)
    return inverted


def _get_neighbours(
    city_id: SettlementId,
    airport_to_cities: dict[StationId, set[SettlementId]],
    city_to_airports: dict[SettlementId, set[StationId]],
) -> set[SettlementId]:
    """Return every settlement sharing at least one airport with ``city_id``.

    The result includes ``city_id`` itself whenever ``airport_to_cities`` is
    the inversion of ``city_to_airports`` and the city has any airports. A
    city with no airports yields an empty set. When ``city_to_airports`` is a
    plain dict, a ``city_id`` that is not one of its keys raises ``KeyError``.
    A fresh set is always returned; the input sets are never aliased.
    """
    served_by = city_to_airports[city_id]
    return set().union(*(airport_to_cities[station] for station in served_by))


class DirectionFeedGenerator(IFeedGenerator[DirectionFeedRow]):
    """Generate direction feed rows from min prices.

    Every min-price record (departure settlement -> arrival settlement) is
    expanded to all settlement pairs whose members share at least one airport
    with the record's endpoints, according to the airport supplier's
    ``settlement_to_station_mapping``. For each resulting (departure, arrival)
    pair only the cheapest row is kept.
    """

    def __init__(
        self,
        min_prices: MinPriceGetter,
        click_price: ClickPriceCounter,
        direction_flights: FlightsCounter,
        direction_type: DirectionTypeResolver,
        top_directions: TopDirections,
        images: SettlementBigImage,
        airport_info_supplier: AirportInfoSupplier,
        settlement_info_supplier: SettlementCache,
        airport_blacklist: AirportBlacklist,
    ):
        self.min_prices = min_prices
        self.click_price = click_price
        self.direction_flights = direction_flights
        self.direction_type = direction_type
        self.top_directions = top_directions
        self.images = images
        self._airport_info_supplier = airport_info_supplier
        self._settlement_info_supplier = settlement_info_supplier
        self._airport_blacklist = airport_blacklist

    def _without_blacklisted(
        self, settlement_ids: set[SettlementId], role: str
    ) -> list[SettlementId]:
        """Return ``settlement_ids`` minus blacklisted ones, logging each skip once.

        ``role`` is the label used in the log line (``'departure_id'`` or
        ``'arrival_id'``). Iteration order is preserved, so the pairs produced
        by ``itertools.product`` over the filtered lists come in the same
        relative order as over the unfiltered sets. This keeps tie-breaking
        on equal prices unchanged.
        """
        allowed = []
        for settlement_id in settlement_ids:
            if self._airport_blacklist.contains_settlement(settlement_id):
                # Same text as f'Skipping row by {departure_id=}' / {arrival_id=}.
                log.info('Skipping row by %s=%r', role, settlement_id)
                continue
            allowed.append(settlement_id)
        return allowed

    def generate_feed(self) -> Iterator[DirectionFeedRow]:
        """Yield the cheapest feed row for every reachable settlement pair.

        Click price, flights count, popularity, direction type and images are
        looked up for the source min-price direction. They are shared by
        every neighbour row derived from it. Titles come from the neighbour
        settlements themselves.

        The following are skipped:
        - blacklisted settlements;
        - settlements missing from the settlement cache (logged as errors);
        - rows failing model validation (logged with traceback).
        """
        log.info('Fetching average click price')
        click_prices = self.click_price.avg_click_price_by_direction
        log.info('Fetching direction flights count')
        direction_flights = self.direction_flights.flights_by_direction
        log.info('Fetching top directions')
        top_directions = self.top_directions.get_top_directions()
        # Read the supplier mapping once instead of once per min-price row.
        city_to_airports = self._airport_info_supplier.settlement_to_station_mapping
        airport_to_city_mapping = _get_airport_to_city_mapping(city_to_airports)
        self._settlement_info_supplier.populate()

        log.info('Iterating through min prices')
        row_by_direction: dict[tuple[SettlementId, SettlementId], DirectionFeedRow] = {}
        for min_price_row in self.min_prices.iterate_min_prices():
            from_id = min_price_row.departure_settlement_id
            to_id = min_price_row.arrival_settlement_id
            d = Direction(from_id, to_id)
            # Normalise the raw ids once and use the normalised values for the
            # neighbour lookup, the direction type AND has_airport. Comparing
            # neighbour ids against the raw ids would silently give False
            # whenever the raw ids are not already ints.
            from_settlement = SettlementId(int(from_id))
            to_settlement = SettlementId(int(to_id))

            # Filter blacklisted settlements once per side rather than once per pair.
            departure_ids = self._without_blacklisted(
                _get_neighbours(
                    from_settlement,
                    city_to_airports=city_to_airports,
                    airport_to_cities=airport_to_city_mapping,
                ),
                'departure_id',
            )
            arrival_ids = self._without_blacklisted(
                _get_neighbours(
                    to_settlement,
                    city_to_airports=city_to_airports,
                    airport_to_cities=airport_to_city_mapping,
                ),
                'arrival_id',
            )

            for departure_id, arrival_id in itertools.product(departure_ids, arrival_ids):
                departure_city = self._settlement_info_supplier.get_settlement_by_id(departure_id)
                arrival_city = self._settlement_info_supplier.get_settlement_by_id(arrival_id)
                if arrival_city is None:
                    log.error(
                        f"No city for id {arrival_id},"
                        f" direction {min_price_row.departure_settlement_title} - {min_price_row.arrival_settlement_title}"  # type: ignore
                    )
                    continue
                if departure_city is None:
                    log.error(
                        f"No city for id {departure_id}, "
                        f"direction {min_price_row.departure_settlement_title} - {min_price_row.arrival_settlement_title}"  # type: ignore
                    )
                    continue

                try:
                    row = DirectionFeedRow(
                        departure_settlement_id=departure_id,
                        departure_settlement_title=departure_city.TitleDefault,  # type: ignore
                        departure_settlement_title_from=departure_city.Title.Ru.Genitive or departure_city.TitleDefault,  # type: ignore
                        arrival_settlement_id=arrival_id,
                        arrival_settlement_title=arrival_city.TitleDefault,  # type: ignore
                        arrival_settlement_title_to=arrival_city.Title.Ru.Accusative or arrival_city.TitleDefault,  # type: ignore
                        forward_date=min_price_row.forward_date,  # type: ignore
                        backward_date=min_price_row.backward_date,  # type: ignore
                        price=min_price_row.price,  # type: ignore
                        currency=min_price_row.currency,  # type: ignore
                        search_url=min_price_row.search_url,  # type: ignore
                        search_url_no_date=min_price_row.search_url_no_date,  # type: ignore
                        route_url=min_price_row.route_url,  # type: ignore
                        click_price=click_prices.get(d).price_avg if d in click_prices else None,  # type: ignore
                        flights_count=direction_flights.get(d),
                        direction_type=str(self.direction_type.for_settlement_pair(
                            from_settlement, to_settlement
                        )),
                        popularity=top_directions.get(d),
                        # NOTE(review): images use the source direction's
                        # settlements, not the neighbour ones used for titles —
                        # confirm this is intended.
                        departure_settlement_image=self.images.get_marketing_square(
                            min_price_row.departure_settlement_id,  # type: ignore
                            DEFAULT_IMAGES,
                        ),
                        arrival_settlement_image=self.images.get_marketing_square(
                            min_price_row.arrival_settlement_id,  # type: ignore
                            DEFAULT_IMAGES,
                        ),
                        # True only for the source direction itself.
                        has_airport=(departure_id == from_settlement and arrival_id == to_settlement),
                    )
                except ValidationError as e:
                    log.exception(e)
                    continue

                # Keep the cheapest row per pair; on equal prices the first seen wins.
                key = (departure_id, arrival_id)
                best = row_by_direction.get(key)
                if best is None or row.price < best.price:
                    row_by_direction[key] = row

        yield from row_by_direction.values()
