import itertools
import logging
from dataclasses import dataclass, fields
from typing import Generator, Any, Iterable, Iterator, TypeVar, Callable

from yt.wrapper import TablePath
from yt.wrapper import YtClient

from travel.avia.ad_feed.ad_feed.environment import Environment

# Module-level logger.
log = logging.getLogger(__name__)

# YT path template of the min-price ad-feed table for each environment.
# The '{nv}' placeholder must be filled by the caller via str.format(nv=...);
# NOTE(review): presumably a national version (e.g. 'ru') — confirm against usages.
MIN_PRICE_TABLE_FOR_ENVIRONMENT = {
    Environment.TESTING: '//home/avia/testing/data/price-index/{nv}/ad-feed',
    Environment.PRODUCTION: '//home/avia/data/price-index/{nv}/ad-feed',
}


@dataclass(frozen=True)
class MinPriceRow:
    """One row of the min-price ad-feed table.

    Field names double as YT column names: MIN_PRICE_ROW_COLUMNS is derived from
    them and MinPriceGetter builds instances with ``MinPriceRow(**row)``, so every
    field must be present as a column in the source table (and field order is the
    positional constructor order).
    """

    # NOTE(review): declared as str, while MSK_POINT_KEY / SPB_POINT_KEY are ints;
    # confirm the actual YT column type before comparing ids.
    departure_settlement_id: str
    departure_settlement_title: str
    # presumably the title in a "from <city>" grammatical form — confirm
    departure_settlement_title_from: str
    arrival_settlement_id: str
    arrival_settlement_title: str
    # presumably the title in a "to <city>" grammatical form — confirm
    arrival_settlement_title_to: str
    # NOTE(review): date string format is not visible here (presumably ISO) — confirm.
    forward_date: str
    backward_date: str
    # Price of the offer; compared numerically when picking the cheapest row.
    price: float
    currency: str
    search_url: str
    search_url_no_date: str
    route_url: str


@dataclass(frozen=True)
class MinPriceSettlementRow(MinPriceRow):
    """A MinPriceRow extended with per-departure-city minimum prices.

    Built by MinPriceGetter.iterate_min_settlement_prices from the output of the
    _reduce_by_arrival reducer, which emits one row per arrival_settlement_id group.
    """

    # Cheapest price among rows departing from MSK_POINT_KEY; 0.0 when there are none.
    price_departure_msk: float
    # Cheapest price among rows departing from SPB_POINT_KEY; 0.0 when there are none.
    price_departure_spb: float


# Column names of MinPriceRow in declaration order. Used as the column projection
# when reading the source table, so only the columns MinPriceRow accepts are read.
MIN_PRICE_ROW_COLUMNS = [f.name for f in fields(MinPriceRow)]

# Departure settlement ids whose minimum prices _reduce_by_arrival breaks out
# separately. NOTE(review): presumably Moscow (213) and Saint Petersburg (2),
# matching the *_MSK / *_SPB names — confirm.
MSK_POINT_KEY = 213
SPB_POINT_KEY = 2


# Signature of a YT reducer over raw min-price rows: (key, rows) -> output rows.
MinPriceReducer = Callable[[Any, Iterable[dict[str, Any]]], Iterator[dict[str, Any]]]


# Generic type variable for pass-through helpers such as _same.
T = TypeVar('T')


def _same(value: T) -> Iterator[T]:
    """Identity mapper: emit the single incoming value unchanged."""
    yield from (value,)


def _reduce_by_arrival(_: Any, records: Iterable[dict[str, Any]]) -> Iterator[dict[str, Any]]:
    """Collapse one arrival-settlement group into its cheapest row.

    The emitted row is a copy of the cheapest input row (the first one wins on
    ties), extended with two extra columns:

    * ``price_departure_msk`` -- cheapest price among rows whose
      ``departure_settlement_id`` equals MSK_POINT_KEY, or ``0.0`` if none;
    * ``price_departure_spb`` -- the same for SPB_POINT_KEY.

    Prices are passed through as-is; they are expected to be already converted
    to rubles upstream, so no currency conversion happens here.

    :param _: reduce key (unused; the grouping column is present in every row).
    :param records: rows of a single ``arrival_settlement_id`` group.
    :return: an iterator with exactly one row, or nothing for an empty group.
    """
    # Compare ids as strings. The schema declares departure_settlement_id as str
    # while the point-key constants are ints, so a plain ``==`` never matches
    # string ids and both per-city prices would silently stay 0.0.
    msk_key = str(MSK_POINT_KEY)
    spb_key = str(SPB_POINT_KEY)

    cheapest = None
    cheapest_msk = None
    cheapest_spb = None
    # One pass over the group. The previous itertools.tee(..., 3) buffered the
    # entire group anyway, because min() exhausted one branch first.
    for row in records:
        price = row['price']
        if cheapest is None or price < cheapest['price']:
            cheapest = row
        departure = str(row['departure_settlement_id'])
        if departure == msk_key:
            if cheapest_msk is None or price < cheapest_msk:
                cheapest_msk = price
        elif departure == spb_key:
            if cheapest_spb is None or price < cheapest_spb:
                cheapest_spb = price

    if cheapest is None:
        return

    # Copy instead of mutating the caller's row in place.
    result = dict(cheapest)
    result['price_departure_msk'] = cheapest_msk if cheapest_msk is not None else 0.0
    result['price_departure_spb'] = cheapest_spb if cheapest_spb is not None else 0.0
    yield result


class MinPriceGetter:
    """Reads min-price rows from a YT table through the given client.

    :param client: YT client used for reads and for the map-reduce operation.
    :param table: path of the source table holding MinPriceRow columns.
    """

    def __init__(self, client: YtClient, table: str):
        self.client, self.table = client, table
        log.info('MinPrice.__init__(table=%s)', self.table)

    def iterate_min_prices(self) -> Generator[MinPriceRow, Any, None]:
        """Stream the source table as MinPriceRow objects.

        Only MIN_PRICE_ROW_COLUMNS are read. The number of rows is logged once
        the table has been fully consumed.
        """
        projected = TablePath(self.table, columns=MIN_PRICE_ROW_COLUMNS)
        seen = 0
        for seen, raw in enumerate(self.client.read_table(projected), start=1):
            yield MinPriceRow(**raw)
        log.info('Iterated through %d min prices', seen)

    def iterate_min_settlement_prices(self) -> Iterator[MinPriceSettlementRow]:
        """Stream one MinPriceSettlementRow per arrival settlement.

        On first iteration this runs a map-reduce (identity mapper,
        _reduce_by_arrival reducer, grouped by ``arrival_settlement_id``) into a
        temporary table. It then reads that table back. The temporary table is
        scoped to the ``with self.client.TempTable()`` block; presumably the
        client drops it on exit.
        """
        with self.client.TempTable() as reduced_table:
            log.info('Running top stations map_reduce: %s', reduced_table)
            self.client.run_map_reduce(
                source_table=[TablePath(self.table, columns=MIN_PRICE_ROW_COLUMNS)],
                destination_table=reduced_table,
                mapper=_same,
                reducer=_reduce_by_arrival,
                sort_by=['arrival_settlement_id'],
                reduce_by=['arrival_settlement_id'],
            )
            log.info('Top stations collected')
            yield from (MinPriceSettlementRow(**raw) for raw in self.client.read_table(reduced_table))
