import logging
from collections import namedtuple
from datetime import date, timedelta, datetime
from functools import partial
from typing import Generator, Any, Dict, Optional, Iterable

from travel.avia.library.python.lib_yt.cache import Station
from travel.avia.library.python.lib_yt.tables import safe_tables_for_date_range
from yt.wrapper import YtClient
from yt.wrapper.ypath import TablePath

log = logging.getLogger(__name__)

# YT directory of the redir-balance log; presumably what callers pass to
# RedirBalanceLog as ``log_dir`` — NOTE(review): confirm against callers.
REDIR_BALANCE_LOG_DIR = '//home/avia/logs/avia-redir-balance-by-day-log'

# One aggregated (from, to, marker) group, as yielded by
# RedirBalanceLog.get_redir_balance_log.
RedirLogRow = namedtuple(
    'RedirLogRow',
    ['from_id', 'to_id', 'marker', 'price_total', 'price_avg', 'count'],
)


def _settlement_or_none(stations_by_id: Dict, code: str) -> Optional[int]:
    if not code:
        return None
    if code[0] == 'c':
        return int(code[1:])
    if code[0] == 's':
        station = stations_by_id.get(code[1:])
        if not station:
            return None
        city_code = station['city_id']
        if not city_code:
            return None
        return int(city_code[1:])
    return None


def _click_price_mapper(
    stations_by_id: Dict, after: datetime, nv: str, record: dict[str, Any]
) -> Generator[dict[str, Any], Any, None]:
    if record['FILTER'] != 0:
        return
    if record['NATIONAL'] != nv:
        return
    if after > datetime.fromtimestamp(record['UNIXTIME']):
        return
    from_id: Optional[int] = _settlement_or_none(stations_by_id, record['FROMID'])
    if from_id is None:
        return
    to_id: Optional[int] = _settlement_or_none(stations_by_id, record['TOID'])
    if to_id is None:
        return

    yield {
        'from': from_id,
        'to': to_id,
        'marker': record['MARKER'],
        'price': record['PRICE'],
    }


def _click_price_reducer(key: dict, records: Iterable[dict[str, Any]]) -> Generator[dict[str, Any], Any, None]:
    counter = 0
    price = 0.0
    for record in records:
        price += record['price']
        counter += 1
    yield {
        'from': key['from'],
        'to': key['to'],
        'marker': key['marker'],
        'price_total': price,
        'price_avg': price / counter,
        'count': counter,
    }


class RedirBalanceLog:
    """Aggregate click prices from the redir-balance log stored on YT.

    Records from the tables under ``log_dir`` are filtered by national version
    and timestamp, grouped by (departure settlement, arrival settlement,
    marker), and reduced to a total price, average price and click count.
    """

    def __init__(self, ytc: YtClient, stations: Station, log_dir: str, recrods_since: datetime, national_version: str):
        """Store the query parameters; no YT calls are made here.

        :param ytc: YT client used to run the map-reduce and read its output.
        :param stations: station cache whose ``by_id`` mapping is passed to the mapper.
        :param log_dir: YT directory holding the redir-balance log tables.
        :param recrods_since: records older than this moment are dropped by the
            mapper (the misspelled name is kept for keyword-argument callers).
        :param national_version: only records whose NATIONAL column equals this are kept.
        """
        self.client, self.stations, self.log_dir = ytc, stations, log_dir
        self.records_since, self.national_version = recrods_since, national_version
        log.debug(
            'RedirBalanceLog.__init__(log_dir=%s, recrods_since=%s, national_version=%s)',
            log_dir, recrods_since, national_version,
        )

    def get_redir_balance_log(self) -> Generator[RedirLogRow, Any, None]:
        """Run the aggregation on YT and stream its result.

        Yields one ``RedirLogRow`` per (from, to, marker) group. The result is
        read from a temporary table whose ``TempTable`` context stays open across
        the yields, so it is only exited once iteration finishes or the generator
        is closed; the final row-count log line is emitted only on full iteration.
        """
        input_columns = ('FROMID', 'TOID', 'MARKER', 'PRICE', 'UNIXTIME', 'FILTER', 'NATIONAL')
        group_columns = ('from', 'to', 'marker')

        with self.client.TempTable() as temp:
            stations_by_id = self.stations.by_id
            log.info('Running redir-balance-log map-reduce: %s', temp)

            # NOTE(review): the table window is widened by one day on each side,
            # presumably to tolerate time-zone / rotation skew; the mapper then
            # applies the exact ``records_since`` cut-off per record.
            first_day = self.records_since.date() - timedelta(days=1)
            last_day = date.today() + timedelta(days=1)
            source_tables = [
                TablePath(table, columns=input_columns)
                for table in safe_tables_for_date_range(self.client, self.log_dir, first_day, last_day)
            ]

            self.client.run_map_reduce(
                mapper=partial(_click_price_mapper, stations_by_id, self.records_since, self.national_version),
                reducer=_click_price_reducer,
                source_table=source_tables,
                destination_table=temp,
                sort_by=list(group_columns),
                reduce_by=list(group_columns),
            )
            log.info('Done redir-balance-log map-reduce: %s', temp)

            count = 0
            for count, row in enumerate(self.client.read_table(temp), start=1):
                yield RedirLogRow(
                    row['from'], row['to'], row['marker'],
                    row['price_total'], row['price_avg'], row['count'],
                )
            log.info('Iterated through %d redir-balance-log rows', count)
