from yt.wrapper import TablePath, YtClient

from travel.avia.ad_feed.ad_feed.entities import SettlementId, StationId
from travel.avia.ad_feed.ad_feed.supplier.airport_info import convert_city_id, convert_station_id


class AirportBlacklist:
    """Blacklist of settlements and plane stations, loaded from YT tables.

    The blacklist table lists regions by ``region_type`` ('country' or 'city')
    and ``region_id``. Every station row with ``t_type == 'plane'`` that is not
    hidden and lies in a blacklisted country or city is added to the station
    blacklist. Its ``city_id`` (when non-empty) is added to the settlement
    blacklist.
    """

    def __init__(self, yt_client: YtClient, blacklist_table: str, stations_table: str):
        """Store the YT client and table paths and load the blacklist immediately.

        :param yt_client: client used to read both tables.
        :param blacklist_table: path of the table with ``region_type`` / ``region_id`` rows.
        :param stations_table: path of the stations table
            (``id``, ``city_id``, ``country_id``, ``t_type``, ``hidden`` columns are read).
        :raises ValueError: if the blacklist table contains an unsupported region type.
        """
        self.yt_client = yt_client
        self.blacklist_table = blacklist_table
        self.stations_table = stations_table
        self.settlement_blacklist: set[SettlementId] = set()
        self.station_blacklist: set[StationId] = set()
        self._update_blacklist()

    def contains_settlement(self, settlement_id: SettlementId) -> bool:
        """Return True if ``settlement_id`` is in the settlement blacklist."""
        return settlement_id in self.settlement_blacklist

    def contains_station(self, station_id: StationId) -> bool:
        """Return True if ``station_id`` is in the station blacklist."""
        return station_id in self.station_blacklist

    def _read_blacklisted_regions(self) -> tuple[set[str], set[str]]:
        """Read the blacklist table and return ``(country_ids, city_ids)``.

        :raises ValueError: if a row has a ``region_type`` other than 'country' or 'city'.
        """
        # NOTE(review): region ids are compared directly with the stations table's
        # country_id / city_id values below; the value types of both tables must match.
        blacklisted_countries: set[str] = set()
        blacklisted_cities: set[str] = set()

        for row in self.yt_client.read_table(TablePath(self.blacklist_table, columns=('region_type', 'region_id'))):
            region_type = row['region_type']
            if region_type == 'country':
                blacklisted_countries.add(row['region_id'])
            elif region_type == 'city':
                blacklisted_cities.add(row['region_id'])
            else:
                raise ValueError(f'Region type is not supported: {row}')

        return blacklisted_countries, blacklisted_cities

    def _update_blacklist(self) -> None:
        """(Re)load the settlement and station blacklists from YT.

        The new sets are built locally. They are assigned to the instance only
        after both tables have been read. A repeated call therefore replaces the
        previous contents instead of accumulating into them, and a failure
        partway through leaves the previous blacklist untouched.

        :raises ValueError: if the blacklist table contains an unsupported region type.
        """
        blacklisted_countries, blacklisted_cities = self._read_blacklisted_regions()

        settlement_blacklist: set[SettlementId] = set()
        station_blacklist: set[StationId] = set()

        for row in self.yt_client.read_table(TablePath(
            self.stations_table,
            columns=('id', 'city_id', 'country_id', 't_type', 'hidden'),
        )):
            # Only visible plane stations participate in the blacklist.
            if row['t_type'] != 'plane' or row['hidden']:
                continue
            if row['country_id'] not in blacklisted_countries and row['city_id'] not in blacklisted_cities:
                continue

            city_id = row['city_id']
            # Stations without a city still get blacklisted; only their settlement is skipped.
            if city_id:
                settlement_blacklist.add(convert_city_id(city_id))
            station_blacklist.add(convert_station_id(row['id']))

        self.settlement_blacklist = settlement_blacklist
        self.station_blacklist = station_blacklist
