# -*- coding: utf-8 -*-
import logging
from datetime import date, timedelta, datetime
from collections import defaultdict
from typing import List, Set, Optional

from travel.avia.avia_statistics.return_ticket_prices_updater.lib.table import (
    ReturnTicketPrice, ReturnTicketPriceTable
)
from travel.avia.avia_statistics.landing_routes import LandingRoute
from travel.avia.avia_statistics.services.price_index.client import Client as PriceIndexClient

logger = logging.getLogger(__name__)


class ReturnTicketPricesUpdater(object):
    """Fetch return-ticket prices from the price index and store them in a table.

    For every landing route whose reverse route is also a landing route, the
    price of the reverse direction (``to_id -> from_id``) is requested from the
    price index. It is then written to ``return_ticket_price_table`` as a record
    for the forward route (``from_id -> to_id``).
    """

    # Lifetime of a stored price; added to "now" to build each record's ``expires_at``.
    TTL_TIMEDELTA = timedelta(days=1, hours=6)

    def __init__(
        self,
        landing_routes,
        return_ticket_price_table,
        price_index_client,
        price_index_batch_size=100,
        price_index_timeout=5,
    ):
        # type: (Set[LandingRoute], ReturnTicketPriceTable, PriceIndexClient, int, int) -> None
        """
        :param landing_routes: routes to compute return-ticket prices for; also
            used to check whether a route's reverse route exists.
        :param return_ticket_price_table: destination table for the records.
        :param price_index_client: client used to query prices.
        :param price_index_batch_size: number of routes that triggers a price
            index request for one national version.
        :param price_index_timeout: timeout passed to every price index request.
        """
        self._landing_routes = landing_routes
        self._return_ticket_price_table = return_ticket_price_table
        self._price_index_client = price_index_client
        self._price_index_batch_size = price_index_batch_size
        self._price_index_timeout = price_index_timeout

    def update_return_ticket_prices(self, today, window_size=30):
        # type: (date, int) -> None
        """Refresh return-ticket prices for all eligible landing routes.

        Routes are grouped by national version and sent to the price index in
        batches of ``price_index_batch_size``. Leftover partial batches are
        flushed at the end.

        :param today: reference date; prices are requested for
            ``today + window_size // 2``.
        :param window_size: window width in days; halved before being passed
            to the price index.
        """
        self._return_ticket_price_table.create_if_doesnt_exist()
        # NOTE(review): presumably the price index treats ``window_size`` as a
        # half-width around ``forward_date``, so the request covers roughly
        # [today, today + window_size] -- confirm against the price index API.
        half_window = window_size // 2
        request_date = today + timedelta(days=half_window)
        batch_by_national_version = defaultdict(list)
        total_processed = 0
        total_stored = 0
        for route in self._landing_routes:
            # Only routes whose reverse is also a landing route get a return price.
            reverse_route = LandingRoute(route.to_id, route.from_id, route.national_version)
            if reverse_route not in self._landing_routes:
                continue
            batch = batch_by_national_version[route.national_version]
            batch.append(route)
            if len(batch) == self._price_index_batch_size:
                total_stored += self._process_batch(
                    request_date, half_window, route.national_version, batch,
                )
                total_processed += len(batch)
                logger.info('processed: %s', total_processed)
                batch_by_national_version[route.national_version] = []
        # Flush the partially filled batch left over for each national version.
        for national_version, batch in batch_by_national_version.items():
            if batch:
                total_stored += self._process_batch(
                    request_date, half_window, national_version, batch,
                )
                total_processed += len(batch)
                logger.info('processed: %s', total_processed)
        # Routes without a price are not written, so report both counters.
        logger.info(
            'processed %s routes, %s return ticket prices were stored into YDB',
            total_processed,
            total_stored,
        )

    def _process_batch(self, request_date, window_size, national_version, batch):
        # type: (date, int, str, List[LandingRoute]) -> int
        """Fetch prices for ``batch`` and store the resulting records.

        Routes with no price entry, or whose entry has an empty ``min_price``,
        are skipped. The table is not touched when nothing is left to write.

        :return: number of records passed to ``replace_batch``.
        """
        prices = self._get_prices(request_date, window_size, national_version, batch)
        # One expiry per batch: records fetched together expire together.
        expires_at = datetime.utcnow() + self.TTL_TIMEDELTA

        records = []
        for route in batch:
            # The price index was queried for the reverse direction (to -> from).
            price = prices.get((route.to_id, route.from_id))
            if not price or not price.get('min_price'):
                continue
            records.append(ReturnTicketPrice(
                from_id=route.from_id,
                to_id=route.to_id,
                national_version=national_version,
                price=price['min_price']['value'],
                currency=price['min_price']['currency'],
                date=price['forward_date'],
                expires_at=expires_at,
            ))

        if records:
            self._return_ticket_price_table.replace_batch(records)
        return len(records)

    def _get_prices(self, request_date, window_size, national_version, batch):
        # type: (date, int, str, List[LandingRoute]) -> dict
        """Query the price index for the reverse direction of each route in ``batch``.

        :return: response entries keyed by their own ``(from_id, to_id)``.
            The caller looks them up as ``(route.to_id, route.from_id)``.
            Returns an empty dict (and logs an error) if the price index
            returned nothing.
        """
        request = {
            'forward_date': request_date.isoformat(),
            'window_size': window_size,
            'results_per_direction': 1,
            # Directions are reversed: we want the price of the way back.
            'directions': [{'from_id': r.to_id, 'to_id': r.from_id} for r in batch],
        }
        prices = self._price_index_client.top_directions_by_date_window(
            national_version,
            request,
            self._price_index_timeout,
        )
        if not prices:
            logger.error(
                'no min prices. forward_date = %s, window_size = %s, directions = %r',
                request_date,
                window_size,
                request['directions'],
            )
            return {}
        return {(r['from_id'], r['to_id']): r for r in prices}
