# -*- coding: utf-8 -*-
import logging
from typing import Set
from datetime import date, datetime, timedelta

from travel.avia.avia_statistics.min_prices_by_airline_updater.lib.table import (
    MinPricesByAirlineTable, MinPriceByAirline,
)
from travel.avia.avia_statistics.min_prices_by_airline_updater.lib.collector import MinPricesByAirlineCollector
from travel.avia.avia_statistics.landing_routes import LandingRoute
from travel.avia.avia_statistics.lib.consts import CURRENCY_BY_NATIONAL_VERSION

logger = logging.getLogger(__name__)


class MinPricesByAirlineUpdater(object):
    """Refresh the min-prices-by-airline table from collected price rows.

    Rows yielded by ``MinPricesByAirlineCollector.collect()`` are filtered down to
    landing routes whose currency matches the national version's currency.
    Each kept row is stamped with an expiration time and written to
    ``MinPricesByAirlineTable`` in batches. Afterwards, records whose departure
    date is today or earlier are deleted via ``delete_old``.
    """

    # Each written record expires 30 hours (1 day + 6 hours) after it is built.
    TTL_TIMEDELTA = timedelta(days=1, hours=6)

    def __init__(self, landing_routes, min_prices_by_airline_collector, min_prices_by_airline_table, batch_size):
        # type: (Set[LandingRoute], MinPricesByAirlineCollector, MinPricesByAirlineTable, int) -> None
        """
        :param landing_routes: routes whose prices should be stored. Membership is
            tested with ``(from_id, to_id, national_version)`` tuples, so presumably
            ``LandingRoute`` is tuple-like with exactly that field order — confirm.
        :param min_prices_by_airline_collector: source of the price rows.
        :param min_prices_by_airline_table: destination table for the records.
        :param batch_size: number of records passed to each ``replace_batch`` call.
            Values below 1 behave like 1, i.e. every record is flushed on its own.
        """
        self._landing_routes = landing_routes
        self._min_prices_by_airline_collector = min_prices_by_airline_collector
        self._min_prices_by_airline_table = min_prices_by_airline_table
        self._batch_size = batch_size

    def update(self):
        # type: () -> None
        """Store fresh min prices for landing routes, then purge outdated records.

        Steps:
          1. Create the table if it does not exist yet.
          2. Stream rows from the collector. Skip rows whose route is not a landing
             route, and rows whose currency differs from
             ``CURRENCY_BY_NATIONAL_VERSION[nv]``; unknown national versions are
             skipped as well.
          3. Write the remaining rows with ``replace_batch`` in chunks of
             ``batch_size``. A final, possibly shorter, chunk is flushed at the end.
          4. Call ``delete_old`` with today's date.
        """
        self._min_prices_by_airline_table.create_if_doesnt_exist()

        processed = 0
        batch = []
        for row in self._min_prices_by_airline_collector.collect():
            (
                from_id,
                to_id,
                nv,
                airline_id,
                min_price,
                departure_date,
                min_price_with_transfers,
                departure_date_with_transfers,
                currency,
            ) = row
            if (from_id, to_id, nv) not in self._landing_routes or currency != CURRENCY_BY_NATIONAL_VERSION.get(nv):
                continue
            # The expiration is computed per row, so each record lives TTL_TIMEDELTA
            # from the moment it was built rather than from the start of the run.
            expires_at = datetime.utcnow() + self.TTL_TIMEDELTA
            min_price_by_airline = MinPriceByAirline(
                from_id,
                to_id,
                nv,
                airline_id,
                min_price,
                departure_date,
                min_price_with_transfers,
                departure_date_with_transfers,
                currency,
                expires_at,
            )
            batch.append(min_price_by_airline)
            # `>=` (not `==`) so a non-positive batch_size still flushes instead of
            # silently buffering the entire input in memory.
            if len(batch) >= self._batch_size:
                processed = self._flush(batch, processed)
                batch = []
        if batch:
            processed = self._flush(batch, processed)
        logger.info('All routes were stored into YDB')

        # Evaluate the cut-off once so the logged date and the date used for
        # deletion cannot diverge (e.g. when the run crosses midnight).
        # NOTE(review): this is the local date while expires_at above is UTC — confirm intended.
        today = date.today()
        logger.info('Removing records with departure date equal or older than %s', today.strftime('%Y-%m-%d'))
        self._min_prices_by_airline_table.delete_old(today)

    def _flush(self, batch, processed):
        # type: (list, int) -> int
        """Write *batch* to the table, log the new running total and return it."""
        self._min_prices_by_airline_table.replace_batch(batch)
        processed += len(batch)
        logger.info('Processed: %d routes', processed)
        return processed
