import datetime
import logging
from typing import Iterator, Any, Iterable

import more_itertools
from yt.wrapper import YtClient, TablePath

from travel.avia.avia_statistics.updaters.recent_direction_popularity.lib.table import (
    DirectionPopularityTable,
    DirectionPopularity,
)
from travel.avia.library.python.lib_yt.tables import safe_tables_for_date_range

logger = logging.getLogger(__name__)


def _mapper(row: dict[str, Any]) -> Iterator[dict[str, int]]:
    if row['FILTER'] != 0:
        return
    departure = row['_REST']['settlement_from_id']
    arrival = row['_REST']['settlement_to_id']
    if departure is None or arrival is None:
        return
    yield {'from_id': departure, 'to_id': arrival}


def _reducer(key: dict[str, int], records: Iterable[Any]) -> Iterator[dict[str, Any]]:
    yield {'from_id': key['from_id'], 'to_id': key['to_id'], 'redir_number': sum(1 for _ in records)}


def get_popularity(
    date_from: datetime.date,
    date_to: datetime.date,
    yt_client: YtClient,
) -> dict[tuple[int, int], int]:
    """Count redirects per (departure, arrival) settlement pair over a date range.

    Runs a YT map-reduce (``_mapper`` / ``_reducer``) over the daily tables of
    ``//home/avia/logs/avia-redir-balance-by-day-log`` that
    ``safe_tables_for_date_range`` returns for ``date_from``..``date_to``.
    Whether ``date_to`` is inclusive is decided by that helper (TODO confirm).
    Only the ``_REST`` and ``FILTER`` columns are read.

    :param date_from: first day of the range.
    :param date_to: last day of the range (see inclusivity note above).
    :param yt_client: client used to run the operation and read its output.
    :return: mapping ``(from_id, to_id) -> redir_number``, the number of log
        rows with ``FILTER == 0`` and both settlement ids present.
    """
    source_tables = [
        TablePath(tbl, columns=('_REST', 'FILTER'))
        for tbl in safe_tables_for_date_range(
            yt_client,
            '//home/avia/logs/avia-redir-balance-by-day-log',
            date_from,
            date_to,
        )
    ]
    with yt_client.TempTable() as temp:
        yt_client.run_map_reduce(
            mapper=_mapper,
            reducer=_reducer,
            source_table=source_tables,
            destination_table=temp,
            reduce_by=['from_id', 'to_id'],
        )
        # Fully materialize the result before the temp table is removed on context exit.
        return {(row['from_id'], row['to_id']): row['redir_number'] for row in yt_client.read_table(temp)}


class Updater:
    """Recompute direction popularity from YT redirect logs and write it to storage."""

    # Number of DirectionPopularity rows passed to a single replace_batch call.
    _BATCH_SIZE = 1000

    def __init__(self, yt_client: YtClient, storage: DirectionPopularityTable):
        self._yt_client = yt_client
        self._storage = storage

    def update(self, days: int) -> None:
        """Recount redirects for a recent window ending yesterday and upsert them into storage.

        The window runs from ``today - (days + 1)`` to ``today - 1``.
        NOTE(review): if ``safe_tables_for_date_range`` treats ``date_to`` as
        inclusive, this covers ``days + 1`` daily tables rather than ``days``.
        Confirm this is intended.

        :param days: size of the look-back window (see note above).
        """
        # Read the clock once so both bounds agree even if the run crosses midnight.
        today = datetime.date.today()
        redir_number = get_popularity(
            date_from=today - datetime.timedelta(days=days + 1),
            date_to=today - datetime.timedelta(days=1),
            yt_client=self._yt_client,
        )
        for batch in more_itertools.chunked(redir_number.items(), self._BATCH_SIZE):
            self._storage.replace_batch(
                [
                    DirectionPopularity(settlement_from_id=from_id, settlement_to_id=to_id, redir_number=number)
                    for (from_id, to_id), number in batch
                ]
            )
        logger.info("%s in table", self._storage.count())
