# -*- coding: utf-8 -*-
import logging
from datetime import date, timedelta, datetime
from collections import defaultdict, namedtuple
from typing import List, Set, Optional  # noqa

from travel.avia.avia_statistics.alternative_routes_prices_updater.lib.table import (  # noqa
    AlternativeRoutesPriceTable, AlternativeRoutePrice
)
from travel.avia.avia_statistics.lib.settlements_geo_index import SettlementsGeoIndex  # noqa
from travel.avia.avia_statistics.landing_routes import LandingRoute
from travel.avia.avia_statistics.services.price_index.client import Client as PriceIndexClient  # noqa
from travel.library.python.dicts.avia.near_direction_repository import NearDirectionRepository  # noqa

logger = logging.getLogger(__name__)

# One candidate pair: the landing route ``from_id -> to_id`` and a nearby
# destination ``alternative_to_id`` reachable from the same ``from_id``.
AlternativeRoute = namedtuple(
    'AlternativeRoute',
    ['from_id', 'to_id', 'alternative_to_id'],
)


class AlternativeRoutesPricesUpdater(object):
    """Stores min prices of "alternative" routes for landing routes.

    For every landing route ``from_id -> to_id`` the updater looks up
    settlements near ``to_id``. When ``from_id -> <nearby settlement>`` is
    itself a landing route of the same national version, its min price is
    requested from the price index and written to the alternative routes
    price table.
    """

    # Added to the current UTC time to compute each record's ``expires_at``.
    TTL_TIMEDELTA = timedelta(days=1, hours=6)

    def __init__(
        self,
        landing_routes,
        near_direction_repository,
        settlements_geo_index,
        alternative_routes_price_table,
        price_index_client,
        price_index_batch_size=100,
        price_index_timeout=5,
        default_distance=200,
    ):
        """
        :param Set[LandingRoute] landing_routes:
        :param NearDirectionRepository near_direction_repository:
        :param SettlementsGeoIndex settlements_geo_index:
        :param AlternativeRoutesPriceTable alternative_routes_price_table:
        :param PriceIndexClient price_index_client:
        :param int price_index_batch_size:
        :param int price_index_timeout:
        :param int default_distance:
        """
        self._landing_routes = landing_routes
        self._near_direction_repository = near_direction_repository
        self._settlements_geo_index = settlements_geo_index
        self._alternative_routes_price_table = alternative_routes_price_table
        self._price_index_client = price_index_client
        self._price_index_batch_size = price_index_batch_size
        self._price_index_timeout = price_index_timeout
        self._default_distance = default_distance

    def update(self, today, window_size=30):
        # type: (date, int) -> None
        """Recompute and store prices for all alternative routes.

        Prices are requested with ``forward_date = today + window_size // 2``.
        The price index receives ``window_size // 2`` as its window size; it
        presumably treats it as a radius around ``forward_date``, giving a
        window centred in ``[today, today + window_size]`` (TODO confirm
        against the price index API).

        Directions are grouped by national version and sent to the price
        index in chunks of at most ``price_index_batch_size`` entries.

        :param date today: reference date of the update.
        :param int window_size: full width of the search window, in days.
        """
        self._alternative_routes_price_table.create_if_doesnt_exist()
        window_size //= 2
        request_date = today + timedelta(days=window_size)
        # A non-positive batch size would make the chunking loop below spin
        # forever, so clamp it.
        batch_size = max(1, self._price_index_batch_size)
        batch_by_national_version = defaultdict(list)
        total_processed = 0
        total_stored = 0
        for route in self._landing_routes:
            batch = batch_by_national_version[route.national_version]
            batch.extend(self._find_alternative_routes(route))
            # One landing route can contribute many alternatives at once.
            # Split the accumulated batch so no single price index request
            # exceeds ``batch_size`` directions.
            while len(batch) >= batch_size:
                chunk = batch[:batch_size]
                # Mutate in place so the dict keeps pointing at the remainder.
                del batch[:batch_size]
                total_stored += self._process_batch(
                    request_date, window_size, route.national_version, chunk,
                )
                total_processed += len(chunk)
                logger.info('processed: %s', total_processed)
        # Flush the remainders; each is shorter than ``batch_size``.
        for national_version, batch in batch_by_national_version.items():
            if batch:
                total_stored += self._process_batch(
                    request_date, window_size, national_version, batch,
                )
                total_processed += len(batch)
                logger.info('processed: %s', total_processed)
        logger.info(
            '%s of %s alternative routes prices were stored into YDB',
            total_stored,
            total_processed,
        )

    def _find_alternative_routes(self, route):
        # type: (LandingRoute) -> List[AlternativeRoute]
        """Return nearby alternatives of ``route`` that are landing routes.

        Candidate destinations come from the settlements geo index around
        ``route.to_id``. The search distance is obtained from the
        near-direction repository, with ``self._default_distance`` passed as
        its default. Only candidates for which
        ``LandingRoute(from_id, candidate, national_version)`` is a known
        landing route are kept.
        """
        distance = self._near_direction_repository.get_default_distance(
            route.from_id,
            route.to_id,
            self._default_distance,
        )
        nearest = self._settlements_geo_index.get_nearest(route.to_id, distance)
        return [
            AlternativeRoute(route.from_id, route.to_id, settlement.Id)
            for settlement in nearest
            if LandingRoute(route.from_id, settlement.Id, route.national_version) in self._landing_routes
        ]

    def _process_batch(self, request_date, window_size, national_version, batch):
        # type: (date, int, str, List[AlternativeRoute]) -> int
        """Fetch prices for ``batch`` and store the ones that were found.

        Routes whose direction has no response record, or whose record has
        an empty ``min_price``, are skipped.

        :return: number of records passed to the table.
        """
        prices = self._get_prices(request_date, window_size, national_version, batch)
        # One expiry for the whole batch instead of a utcnow() call per route.
        expires_at = datetime.utcnow() + self.TTL_TIMEDELTA
        records = []  # type: List[AlternativeRoutePrice]
        for route in batch:
            price = prices.get((route.from_id, route.alternative_to_id))
            if not price or not price.get('min_price'):
                continue
            records.append(AlternativeRoutePrice(
                from_id=route.from_id,
                to_id=route.to_id,
                alternative_to_id=route.alternative_to_id,
                national_version=national_version,
                price=price['min_price']['value'],
                currency=price['min_price']['currency'],
                date=price['forward_date'],
                expires_at=expires_at,
            ))
        # Nothing found for this batch: avoid a no-op write to the table.
        if records:
            self._alternative_routes_price_table.replace_batch(records)
        return len(records)

    def _get_prices(self, request_date, window_size, national_version, batch):
        # type: (date, int, str, List[AlternativeRoute]) -> dict
        """Request min prices for the alternative directions of ``batch``.

        :return: price index records keyed by ``(from_id, to_id)``, or an
            empty dict when the price index returned nothing.
        """
        # Routes A->B and A->C can share the same nearby settlement D, which
        # produces the direction A->D twice. Request each direction once; the
        # response is keyed by direction anyway.
        directions = []
        seen = set()
        for route in batch:
            key = (route.from_id, route.alternative_to_id)
            if key not in seen:
                seen.add(key)
                directions.append({'from_id': key[0], 'to_id': key[1]})
        request = {
            'forward_date': request_date.isoformat(),
            'window_size': window_size,
            'results_per_direction': 1,
            'directions': directions,
        }
        prices = self._price_index_client.top_directions_by_date_window(
            national_version,
            request,
            self._price_index_timeout,
        )
        if not prices:
            logger.error(
                'no min prices. forward_date = %s, window_size = %s, directions = %r',
                request_date,
                window_size,
                request['directions'],
            )
            return {}
        return {(r['from_id'], r['to_id']): r for r in prices}
