import logging
from collections import defaultdict
from functools import cached_property
from typing import Dict, Optional

from travel.avia.ad_feed.ad_feed.entities import Direction
from travel.avia.ad_feed.ad_feed.environment import Environment
from travel.avia.library.python.lib_yt.cache import Station
from travel.avia.library.python.shared_flights_client.client import SharedFlightsClient

# Module-level logger used by FlightsCounter below.
log = logging.getLogger(__name__)

# Output table path for the direction-flights ad feed, keyed by deployment environment.
# NOTE(review): these look like YT (Cypress-style) paths; not referenced in this
# module's visible code, presumably read by the job that writes the feed — confirm.
OUTPUT_TABLE_FOR_ENVIRONMENT = {
    Environment.TESTING: '//home/avia/testing/data/ad-feed/direction-flights',
    Environment.PRODUCTION: '//home/avia/data/ad-feed/direction-flights',
}


class FlightsCounter:
    """Aggregate shared-flights point-to-point flight counts by settlement.

    Flight records come from ``SharedFlightsClient.flight_p2p_summary()`` and
    reference stations by id. Each station is mapped to its settlement through
    ``stations`` before the ``totalFlightsCount`` values are summed.
    """

    def __init__(self, shared_flights_client: SharedFlightsClient, stations: Station):
        """
        :param shared_flights_client: client used to fetch the p2p flight summary.
        :param stations: station lookup; must expose a ``by_id`` mapping whose
            values support ``['city_id']`` item access.
        """
        self.client = shared_flights_client
        self.stations = stations

    @cached_property
    def _flights_summary(self):
        """Fetch the p2p flight summary once per counter instance.

        Both aggregations below used to call the service independently. That
        doubled the request and could sum two different snapshots of the data.
        Caching here makes them share one response. The raw summary stays in
        memory for the lifetime of this instance.
        """
        log.info('Getting summary from shared flights')
        return self.client.flight_p2p_summary()

    @cached_property
    def flights_by_direction(self) -> dict[Direction, int]:
        """Total flights count per (departure settlement, arrival settlement) direction.

        Records are skipped when either station cannot be mapped to a settlement.

        :return: plain dict mapping ``Direction`` to the summed ``totalFlightsCount``.
        """
        stations_summary = self._flights_summary

        direction_frequency: dict[Direction, int] = {}

        log.info('Squashing stations to settlements')
        for record in stations_summary['flights']:
            station_from = record['departureStation']
            station_to = record['arrivalStation']
            total_flights_count = record['totalFlightsCount']

            settlement_from = self._station_id_to_settlement_id(station_from)
            settlement_to = self._station_id_to_settlement_id(station_to)

            # The helper signals "no settlement" with None. Compare explicitly so
            # that a settlement id of 0 is not dropped by a truthiness check, which
            # keeps this consistent with flights_by_arrival_settlement.
            if settlement_from is None or settlement_to is None:
                continue

            direction = Direction(settlement_from, settlement_to)

            direction_frequency[direction] = direction_frequency.get(direction, 0) + total_flights_count
        return direction_frequency

    def _station_id_to_settlement_id(self, station_id: int) -> Optional[int]:
        """Map a station id to its settlement id.

        :param station_id: id to look up in ``self.stations.by_id``.
        :return: the integer settlement id, or None when the station is unknown
            or has an empty ``city_id``.
        """
        station = self.stations.by_id.get(station_id)
        if not station:
            return None
        city_key = station['city_id']
        if not city_key:
            return None
        # NOTE(review): assumes city_id is a one-character prefix followed by
        # digits (e.g. 'c213'); anything else raises ValueError here — confirm.
        return int(city_key[1:])

    @cached_property
    def flights_by_arrival_settlement(self) -> dict[int, int]:
        """Total flights count per arrival settlement.

        Only the arrival station needs to map to a settlement. The departure
        side is not inspected.

        :return: a ``defaultdict(int)``. Reading a missing settlement yields 0,
            and that read also inserts the key into the cached mapping.
        """
        stations_summary = self._flights_summary

        station_frequency: dict[int, int] = defaultdict(int)

        log.info('Squashing stations to settlements')
        for record in stations_summary['flights']:
            station_to = record['arrivalStation']
            total_flights_count = record['totalFlightsCount']

            settlement_to = self._station_id_to_settlement_id(station_to)
            if settlement_to is None:
                continue

            station_frequency[settlement_to] += total_flights_count
        return station_frequency
