import logging
from collections import defaultdict
from functools import cached_property
from typing import Optional, Any

from pydantic import BaseModel, ValidationError, root_validator
from yt.wrapper import YtClient, TablePath

from travel.avia.ad_feed.ad_feed.entities import StationId, SettlementId

# Module-level logger; AirportInfoSupplier uses it to report station rows that fail AirportInfo validation.
logger = logging.getLogger(__name__)


class AirportInfo(BaseModel):
    """Airport record parsed from a stations-table row.

    Validation fails unless at least one of ``iata``, ``sirena`` or ``icao``
    holds a non-empty (truthy) value.
    """

    title: str
    iata: Optional[str]
    sirena: Optional[str]
    icao: Optional[str]
    city_id: Optional[str]

    @root_validator()
    def check_at_least_one_code(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Reject records in which every airport code is missing or empty.

        :param values: field values collected by pydantic for this record.
        :return: ``values`` unchanged when at least one code is present.
        :raises ValueError: when all three codes are falsy.
        """
        codes = (values.get('iata'), values.get('sirena'), values.get('icao'))
        if not any(codes):
            raise ValueError('At least one code must be filled')
        return values


class AirportInfoSupplier:
    """Read airport data from YT tables and build lookup structures from it.

    Both public properties are ``functools.cached_property`` values: each one
    is computed on first access and then cached on the instance.
    """

    def __init__(self, yt_client: YtClient, stations_table: str, station2settlement_table: str):
        """
        :param yt_client: client used to read both tables.
        :param stations_table: path of a table with the columns ``id``, ``title``,
            ``iata``, ``sirena``, ``icao``, ``t_type``, ``city_id`` and ``hidden``.
        :param station2settlement_table: path of a table with ``station_id`` and
            ``city_id`` columns; its links are merged with each airport's own ``city_id``.
        """
        self._yt_client = yt_client
        self._stations_table = stations_table
        self._station2settlement_table = station2settlement_table

    @cached_property
    def airports_info(self) -> dict[StationId, AirportInfo]:
        """Return visible plane stations keyed by numeric station id.

        Rows whose ``t_type`` is not ``'plane'`` or whose ``hidden`` flag is
        truthy are skipped. Rows that fail ``AirportInfo`` validation are
        logged with their station id and skipped.

        :raises Exception: if a kept row's ``id`` is rejected by ``convert_station_id``.
        """
        result = {}
        table = TablePath(
            self._stations_table,
            columns=('id', 'title', 'iata', 'sirena', 'icao', 't_type', 'city_id', 'hidden'),
        )
        for row in self._yt_client.read_table(table):
            if row['t_type'] != 'plane' or row['hidden']:
                continue
            station_id = convert_station_id(row['id'])
            try:
                result[station_id] = AirportInfo.parse_obj(row)
            except ValidationError:
                # Best effort: one malformed airport must not break the whole mapping.
                logger.exception('Skipping station %s: invalid airport info', row['id'])
        return result

    @cached_property
    def settlement_to_station_mapping(self) -> dict[SettlementId, set[StationId]]:
        """Return, for each settlement, the set of known airport station ids.

        The mapping combines two sources. The first is each airport's own
        ``city_id`` from ``airports_info``. The second is the extra
        station->settlement links in the station2settlement table, restricted
        to stations present in ``airports_info``.

        The returned object is a ``defaultdict(set)``; existing callers may rely
        on that, so it is kept.

        :raises Exception: if a station id or a relevant city id cannot be converted.
        """
        # Hoisted once: cached_property makes repeated access cheap but not free.
        airports = self.airports_info
        result = defaultdict(set)
        for station_id, info in airports.items():
            if info.city_id is not None:
                result[convert_city_id(info.city_id)].add(station_id)
        table = TablePath(self._station2settlement_table, columns=('station_id', 'city_id'))
        for row in self._yt_client.read_table(table):
            station_id = convert_station_id(row['station_id'])
            # Filter before parsing city_id so that a malformed link for an
            # unrelated (non-airport) station cannot abort the whole build.
            if station_id not in airports:
                continue
            result[convert_city_id(row['city_id'])].add(station_id)
        return result


def convert_city_id(value: str) -> SettlementId:
    """Convert a prefixed settlement id such as ``'c213'`` into ``SettlementId(213)``.

    :param value: settlement id string, expected to be ``'c'`` followed by an integer.
    :return: the numeric part wrapped in ``SettlementId``.
    :raises ValueError: if ``value`` lacks the ``'c'`` prefix or its remainder is
        not an integer. ``ValueError`` subclasses ``Exception``, so existing
        ``except Exception`` handlers keep working.
    """
    if not value.startswith('c'):
        raise ValueError(f'Expected settlement id but got {value}')
    try:
        number = int(value[1:])
    except ValueError as err:
        # Re-raise with the full offending value; int() only reports the suffix.
        raise ValueError(f'Expected settlement id but got {value}') from err
    return SettlementId(number)


def convert_station_id(value: str) -> StationId:
    """Convert a prefixed station id such as ``'s9600213'`` into ``StationId(9600213)``.

    :param value: station id string, expected to be ``'s'`` followed by an integer.
    :return: the numeric part wrapped in ``StationId``.
    :raises ValueError: if ``value`` lacks the ``'s'`` prefix or its remainder is
        not an integer. ``ValueError`` subclasses ``Exception``, so existing
        ``except Exception`` handlers keep working.
    """
    if not value.startswith('s'):
        raise ValueError(f'Expected station id but got {value}')
    try:
        number = int(value[1:])
    except ValueError as err:
        # Re-raise with the full offending value; int() only reports the suffix.
        raise ValueError(f'Expected station id but got {value}') from err
    return StationId(number)
