# -*- coding: utf-8 -*-
import logging
from datetime import date, datetime, timedelta

from travel.avia.avia_statistics.popular_months_by_route_updater.lib.table import (
    PopularMonthWithPriceByRoute, PopularMonthWithPriceTable
)
from travel.avia.avia_statistics.popular_months_by_route_updater.lib.collector import PopularMonthsCollector
from travel.avia.avia_statistics.lib.consts import CURRENCY_BY_NATIONAL_VERSION

# Module-level logger; batch progress of the updater below is reported through it.
logger = logging.getLogger(name=__name__)


class PopularMonthsWithPriceUpdater(object):
    TTL_TIMEDELTA = timedelta(days=1, hours=6)

    def __init__(
        self,
        popular_months_collector,
        popular_month_with_price_table,
        batch_size=1000,
    ):
        # type: (PopularMonthsCollector, PopularMonthWithPriceTable, int) -> None
        self._popular_months_collector = popular_months_collector
        self._popular_month_with_price_table = popular_month_with_price_table
        self._batch_size = batch_size

    def update_popular_months(self, today):
        # type: (date) -> None
        self._popular_month_with_price_table.create_if_doesnt_exist()
        batch = []
        total_processed = 0
        logger.info('start collecting popular months')
        for route_pop_month_price in self._popular_months_collector.collect_popular_months_to_yt_table(today):
            batch.append(route_pop_month_price)
            if len(batch) == self._batch_size:
                self._process_batch(batch)
                total_processed += len(batch)
                logger.info('processed: %s', total_processed)
                batch = []
        if batch:
            self._process_batch(batch)
            total_processed += len(batch)
            logger.info('processed: %s', total_processed)
        logger.info('all %s popular months were stored into YDB', total_processed)

    def _process_batch(self, batch):
        today = date.today()

        def map_to_record(route_with_popular_month_price):
            (
                from_id,
                to_id,
                national_version,
                median_price,
                year,
                month,
            ) = route_with_popular_month_price

            if (year, month) < (today.year, today.month):
                year += 1
            currency = CURRENCY_BY_NATIONAL_VERSION.get(national_version)
            expires_at = datetime.utcnow() + self.TTL_TIMEDELTA
            return PopularMonthWithPriceByRoute(
                from_id=from_id,
                to_id=to_id,
                national_version=national_version,
                year=year,
                month=month,
                price=median_price,
                currency=currency,
                expires_at=expires_at,
            )

        month_with_price_records = list(filter(None, map(map_to_record, batch)))
        self._popular_month_with_price_table.replace_batch(month_with_price_records)
