# -*- coding: utf-8 -*-
import logging
from datetime import date, datetime, timedelta
from typing import Optional, Tuple

from travel.avia.avia_statistics.median_prices_updater.lib.table import (
    MedianPrices, MedianPricesTable
)
from travel.avia.avia_statistics.median_prices_updater.lib.collector import MedianPricesCollector

logger = logging.getLogger(__name__)


class MedianPricesUpdater(object):
    """Moves collected median prices into the median prices table in batches.

    Rows yielded by the collector are grouped into lists of ``batch_size``
    items; each group is converted to ``MedianPrices`` records and written
    with a single ``replace_batch`` call on the table object.
    """

    # Added to the current UTC time to obtain each record's ``expires_at``.
    TTL_TIMEDELTA = timedelta(days=1, hours=6)

    def __init__(
        self,
        median_prices_table,
        median_prices_collector,
        batch_size=2000,
    ):
        # type: (MedianPricesTable, MedianPricesCollector, int) -> None
        """Store the collaborators and the batching threshold.

        :param median_prices_table: destination table wrapper.
        :param median_prices_collector: source of median price rows.
        :param batch_size: number of rows written per ``replace_batch`` call.
        """
        self._batch_size = batch_size
        self._median_prices_collector = median_prices_collector
        self._median_prices_table = median_prices_table

    def update(self, today):
        # type: (date) -> None
        """Collect median prices for ``today`` and write them to the table.

        The table is created first if it does not exist. Rows are flushed
        every time exactly ``batch_size`` of them have accumulated; whatever
        remains once the collector is exhausted is flushed as a final,
        smaller batch.
        """
        self._median_prices_table.create_if_doesnt_exist()

        stored = 0
        pending = []
        for row in self._median_prices_collector.collect_median_prices(today):
            pending.append(row)
            if len(pending) != self._batch_size:
                continue
            stored = self._flush_batch(pending, stored)
            # Rebind instead of clearing so the flushed list is left intact.
            pending = []

        if pending:
            stored = self._flush_batch(pending, stored)

        logger.info('all %s median prices were stored to YDB', stored)

    def _flush_batch(self, rows, stored_so_far):
        """Write ``rows``, log the running total and return that new total."""
        self._process_batch(rows)
        new_total = stored_so_far + len(rows)
        logger.info('processed: %s', new_total)
        return new_total

    def _process_batch(self, batch):
        """Convert raw rows to ``MedianPrices`` and replace them in the table."""
        records = [self._to_median_prices(raw_row) for raw_row in batch]
        self._median_prices_table.replace_batch(records)

    def _to_median_prices(self, row):
        # type: (Tuple) -> Optional[MedianPrices]
        """Build one ``MedianPrices`` record from an 8-item collector row.

        The row is unpacked positionally as: from_id, to_id,
        national_version, year_median_price, month_median_price, year,
        month, currency.
        """
        (
            origin_id,
            destination_id,
            nv,
            yearly_price,
            monthly_price,
            price_year,
            price_month,
            price_currency,
        ) = row
        return MedianPrices(
            from_id=origin_id,
            to_id=destination_id,
            national_version=nv,
            year=price_year,
            month=price_month,
            month_median_price=monthly_price,
            year_median_price=yearly_price,
            currency=price_currency,
            # Evaluated per record: current UTC time plus the class TTL.
            expires_at=datetime.utcnow() + self.TTL_TIMEDELTA,
        )
