import logging
from collections import namedtuple
from datetime import timedelta, date, datetime
from functools import partial, cached_property
from typing import Optional, Dict, Any, Iterable, Iterator, Generator

from yt.wrapper import TablePath
from yt.wrapper import YtClient

from travel.avia.ad_feed.ad_feed.entities import Direction
from travel.avia.ad_feed.ad_feed.environment import Environment
from travel.avia.library.python.lib_yt.cache import Station
from travel.avia.library.python.lib_yt.tables import safe_tables_for_date_range

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Path of the avia users search-log directory.
# NOTE(review): not referenced in the visible part of this file; presumably
# passed as `search_log_dir` to the classes below — confirm against callers.
AVIA_USER_SEARCH_LOG_DIR = '//home/avia/logs/avia-users-search-log'

# Top-directions table path for each environment.
# NOTE(review): presumably the `output_table` handed to TopDirectionsDumper —
# confirm against callers.
OUTPUT_TABLE_FOR_ENVIRONMENT = {
    Environment.TESTING: '//home/avia/testing/data/ad-feed/top-directions',
    Environment.PRODUCTION: '//home/avia/data/ad-feed/top-directions',
}


def _settlement_or_none(stations_by_id: Dict, code: str) -> Optional[int]:
    if not code:
        return None
    if code[0] == 'c':
        return int(code[1:])
    if code[0] == 's':
        station = stations_by_id.get(code[1:])
        if not station:
            return None
        city_code = station['city_id']
        if not city_code:
            return None
        return int(city_code[1:])
    return None


def _top_directions_mapper(
    stations_by_id: dict[str, Any], nv: str, row: dict[str, Any]
) -> Generator[dict[str, Any], Any, None]:
    if row['national_version'] != nv:
        return

    from_id: Optional[int] = _settlement_or_none(stations_by_id, row['fromId'])
    if from_id is None:
        return
    to_id: Optional[int] = _settlement_or_none(stations_by_id, row['toId'])
    if to_id is None:
        return
    yield {
        'fromId': from_id,
        'toId': to_id,
    }


def _top_directions_reducer(key: dict[str, Any], records: Iterable[Any]) -> Generator[dict[str, Any], Any, None]:
    yield {
        'fromId': key['fromId'],
        'toId': key['toId'],
        'count': sum(1 for _ in records),
    }


# Row shape whose field names match the columns declared in the schema of
# TopDirectionsDumper.dump (fromId, toId, count).
TopDirectionRow = namedtuple('TopDirectionRow', 'fromId toId count')


class TopDirections:
    """Counts searches per settlement-to-settlement direction.

    The counts come from a map-reduce over the search-log tables and are
    memoised on the instance after the first run.
    """

    def __init__(
        self, client: YtClient, stations: Station, search_log_dir: str, records_since: datetime, national_version: str
    ):
        """Store the query parameters; nothing is computed here.

        Args:
            client: Client used for the temp table, the map-reduce and the
                final table read.
            stations: Station cache; its ``by_id`` mapping is handed to the mapper.
            search_log_dir: Directory whose tables are used as map-reduce input.
            records_since: Its date part is the start of the table date range.
            national_version: Only rows with this ``national_version`` are counted.
        """
        self.client: YtClient = client
        self.stations = stations
        self.search_log_dir = search_log_dir
        self.records_since = records_since
        self.national_version = national_version

        # None means "not computed yet". An empty dict is a valid computed
        # result and must not trigger another map-reduce run.
        self._cache: Optional[Dict[Direction, int]] = None

    def get_top_directions(self) -> Dict[Direction, int]:
        """Return the search count per direction, computing it on first use.

        Returns:
            Mapping from direction to the number of matching search-log rows.
            The same dict object is returned on every call.
        """
        if self._cache is None:
            self._cache = self._get_top_directions()
            log.info('Populated top directions cache with %d items', len(self._cache))
        return self._cache

    def _get_top_directions(self) -> Dict[Direction, int]:
        """Run the map-reduce and read its result into a dict.

        Input is every table that ``safe_tables_for_date_range`` returns for
        the range from ``records_since.date()`` to tomorrow's date, restricted
        to the three columns the mapper reads. The result lands in a temp
        table that only exists inside this call.
        """
        with self.client.TempTable() as temp:
            stations_by_id = self.stations.by_id
            log.info('Running top directions map-reduce: %s', temp)
            self.client.run_map_reduce(
                mapper=partial(_top_directions_mapper, stations_by_id, self.national_version),
                reducer=_top_directions_reducer,
                source_table=[
                    TablePath(tbl, columns=('fromId', 'toId', 'national_version'))
                    for tbl in safe_tables_for_date_range(
                        self.client,
                        self.search_log_dir,
                        self.records_since.date(),
                        date.today() + timedelta(days=1),
                    )
                ],
                destination_table=temp,
                sort_by=['fromId', 'toId'],
                reduce_by=['fromId', 'toId'],
            )
            log.info('Done top directions map-reduce')
            # Read inside the `with` block, while the temp table is still available.
            return {
                Direction(from_id=row['fromId'], to_id=row['toId']): row['count']
                for row in self.client.read_table(temp)
            }


class TopDirectionsDumper:
    """Writes the counts produced by ``TopDirections`` to an output table."""

    def __init__(self, client: YtClient, top_directions: TopDirections, output_table: str):
        """Store the collaborators; nothing is written here.

        Args:
            client: Client used for the transaction and the table write.
            top_directions: Source of the direction -> count mapping.
            output_table: Path of the table to write.
        """
        self.client = client
        self.top_directions = top_directions
        self.output_table = output_table

    def dump(self) -> None:
        """Write one ``fromId``/``toId``/``count`` row per direction.

        The write happens inside a transaction on ``self.client``; the success
        message is logged only after the transaction block has exited without
        an error.
        """
        with self.client.Transaction():
            # get_top_directions() returns a mapping, so iterate its items:
            # iterating the mapping itself would yield only the Direction keys
            # and lose the counts. Rows are built with exactly the column
            # names declared in the schema below.
            # NOTE(review): assumes Direction exposes `from_id` / `to_id`
            # attributes matching its constructor keywords — confirm.
            rows = (
                {'fromId': direction.from_id, 'toId': direction.to_id, 'count': count}
                for direction, count in self.top_directions.get_top_directions().items()
            )
            self.client.write_table(
                TablePath(
                    self.output_table,
                    schema=[
                        {'name': 'fromId', 'type': 'int64'},
                        {'name': 'toId', 'type': 'int64'},
                        {'name': 'count', 'type': 'uint64'},
                    ],
                ),
                rows,
            )
        log.info('Top directions dumped to %s', self.output_table)


def _top_stations_mapper(stations_by_id: Dict, nv: str, row: dict[str, Any]) -> Iterator[dict[str, int]]:
    if row['national_version'] != nv:
        return

    to_id: Optional[int] = _settlement_or_none(stations_by_id, row['toId'])
    if to_id is None:
        return
    yield {'toId': to_id}


def _top_stations_reducer(key: dict[str, Any], records: Iterable[Any]) -> Iterator[dict[str, Any]]:
    yield {
        'toId': key['toId'],
        'count': sum(1 for _ in records),
    }


class StationsPopularityGetter:
    """Counts search-log rows per destination id.

    Destinations are resolved by ``_top_stations_mapper``, so the keys of the
    result are whatever ids that mapper emits as ``toId``.
    """

    def __init__(
        self, client: YtClient, stations: Station, search_log_dir: str, records_since: datetime, national_version: str
    ):
        """Store the query parameters; nothing is computed here.

        Args:
            client: Client used for the temp table, the map-reduce and the
                final table read.
            stations: Station cache; its ``by_id`` mapping is handed to the mapper.
            search_log_dir: Directory whose tables are used as map-reduce input.
            records_since: Its date part is the start of the table date range.
            national_version: Only rows with this ``national_version`` are counted.
        """
        self.client: YtClient = client
        self.stations = stations
        self.search_log_dir = search_log_dir
        self.records_since = records_since
        self.national_version = national_version

    @cached_property
    def stations_popularity(self) -> dict[int, int]:
        """Destination id -> number of matching search-log rows.

        Computed by a map-reduce on first access; ``cached_property`` stores
        the result on the instance for subsequent accesses.
        """
        with self.client.TempTable() as result_table:
            by_id = self.stations.by_id
            log.info('Running top stations map-reduce: %s', result_table)

            row_mapper = partial(_top_stations_mapper, by_id, self.national_version)
            range_start = self.records_since.date()
            range_end = date.today() + timedelta(days=1)
            log_tables = safe_tables_for_date_range(self.client, self.search_log_dir, range_start, range_end)
            # Only the columns the mapper reads are requested from each table.
            inputs = [TablePath(tbl, columns=('toId', 'national_version')) for tbl in log_tables]

            self.client.run_map_reduce(
                mapper=row_mapper,
                reducer=_top_stations_reducer,
                source_table=inputs,
                destination_table=result_table,
                sort_by=['toId'],
                reduce_by=['toId'],
            )
            log.info('Top stations collected')

            # Read inside the `with` block, while the temp table is still available.
            popularity = {}
            for row in self.client.read_table(result_table):
                popularity[row['toId']] = row['count']
            return popularity
