# -*- coding: utf-8 -*-
import logging
import json
from datetime import date, timedelta
from collections import defaultdict
from itertools import chain

from typing import List

import numpy as np

from travel.avia.avia_statistics.updaters.city_to.city_route_crosslinks_updater.lib.table import (  # noqa
    CityRouteCrosslinksTable, CityRouteCrosslink
)
from travel.avia.avia_statistics.updaters.city_to.city_route_crosslinks_updater.lib.crosslinks_provider import (
    CrosslinksProvider, RoutesDict
)
from travel.avia.avia_statistics.landing_cities import LandingCity  # noqa
from travel.avia.avia_statistics.landing_routes import LandingRoute
from travel.avia.avia_statistics.services.price_index.client import Client as PriceIndexClient  # noqa

logger = logging.getLogger(__name__)


class CityRouteCrosslinksUpdater(object):
    """Regenerate to-city route crosslinks, attach min prices and store them.

    For every landing city the ``CrosslinksProvider`` yields an ordered list of
    routes. Each route is enriched with the price-index min price (if any) and
    written to ``city_route_crosslinks_table`` together with its position in
    that city's list.
    """

    def __init__(
        self,
        yt_client,
        landing_cities,
        landing_routes,
        route_weights,
        settlement_repository,
        city_route_crosslinks_table,
        avia_backend_client,
        price_index_client,
        price_index_batch_size=100,
        price_index_timeout=5,
        ydb_batch_size=1000,
    ):
        """
        :type yt_client: yt.wrapper.YtClient
        :type landing_cities: List[LandingCity]
        :type landing_routes: List[LandingRoute]
        :type route_weights: Dict[LandingRoute, RouteWeight]
        :type settlement_repository: travel.library.python.dicts.avia.settlement_repository.SettlementRepository
        :type city_route_crosslinks_table: CityRouteCrosslinksTable
        :type avia_backend_client: BackendClient
        :type price_index_client: PriceIndexClient
        :param int price_index_batch_size: max number of directions sent in a
            single price-index request (per national version).
        :param int price_index_timeout: timeout forwarded to the price-index
            client on every request.
        :param int ydb_batch_size: number of rows written per
            ``replace_batch`` call on the crosslinks table.
        """
        self._yt_client = yt_client
        self._landing_cities = landing_cities
        self._landing_routes = landing_routes
        self._route_weights = route_weights
        self._settlement_repository = settlement_repository
        self._city_route_crosslinks_table = city_route_crosslinks_table
        self._price_index_client = price_index_client
        self._price_index_batch_size = price_index_batch_size
        self._price_index_timeout = price_index_timeout
        self._avia_backend_client = avia_backend_client
        self._ydb_batch_size = ydb_batch_size

    def update(self, today, window_size=30, save_crosslinks_to_file=False):
        # type: (date, int, bool) -> None
        """Generate crosslinks for all landing cities, price them and persist.

        :param today: date the price search window is anchored to.
        :param window_size: width of the price search window in days; see
            ``_get_price_by_route`` for how it is turned into a request.
        :param save_crosslinks_to_file: additionally dump the generated
            crosslinks as JSON to ``to_city_route_crosslinks.json`` in the
            current working directory.
        """
        # Fixed global numpy seed; presumably so that any randomness used by
        # the crosslinks provider is reproducible between runs — confirm.
        np.random.seed(1)

        crosslinks_generator = CrosslinksProvider(
            self._landing_cities,
            self._landing_routes,
            self._route_weights,
            self._settlement_repository,
            self._avia_backend_client,
            self._yt_client,
        )
        crosslinks = crosslinks_generator.generate()

        if save_crosslinks_to_file:
            f_name = 'to_city_route_crosslinks.json'
            with open(f_name, 'w') as f:
                json.dump(RoutesDict.as_json(crosslinks), f)
            logger.info('Crosslinks saved into %s file', f_name)
        self._city_route_crosslinks_table.create_if_doesnt_exist()

        routes = chain.from_iterable(crosslinks.values())
        price_by_route = self._get_price_by_route(routes, today, window_size)

        batch = []
        crosslinks_processed = 0
        for city in crosslinks:
            for position, crosslink in enumerate(crosslinks[city]):
                price_value, currency, forward_date = self._unpack_price(price_by_route.get(crosslink))

                batch.append(CityRouteCrosslink(
                    city.to_id,
                    city.national_version,
                    position,
                    crosslink.from_id,
                    crosslink.to_id,
                    price_value,
                    currency,
                    forward_date,
                ))
                crosslinks_processed += 1
                # '>=' (not '==') so a non-positive batch size still flushes
                # instead of buffering everything until the end.
                if len(batch) >= self._ydb_batch_size:
                    self._city_route_crosslinks_table.replace_batch(batch)
                    batch = []
                if crosslinks_processed % 1000 == 0:
                    logger.info('Processed %d crosslinks', crosslinks_processed)
        if batch:
            self._city_route_crosslinks_table.replace_batch(batch)
        logger.info('All %d to_city landing route crosslinks have been processed', crosslinks_processed)

    @staticmethod
    def _unpack_price(price):
        """Split a price-index entry into ``(value, currency, forward_date)``.

        Returns ``(None, None, None)`` when ``price`` is missing/falsy or has
        no truthy ``'min_price'``.
        """
        if not price or not price.get('min_price'):
            return None, None, None
        min_price = price['min_price']
        return min_price['value'], min_price['currency'], price['forward_date']

    def _process_batch(self, request_date, window_size, national_version, batch, price_by_route):
        """Request prices for ``batch`` and store them into ``price_by_route``.

        Routes missing from the price-index response are stored as ``None``.
        """
        prices = self._get_prices(request_date, window_size, national_version, batch)
        for route in batch:
            price_by_route[route] = prices.get((route.from_id, route.to_id))

    def _get_price_by_route(self, landing_routes, today, window_size):
        """Fetch price-index entries for every distinct route in ``landing_routes``.

        Routes are grouped by ``national_version`` and requested in batches of
        ``price_index_batch_size``. A route that occurs several times in the
        input (e.g. the same route crosslinked from several cities) is
        requested only once.

        The request date is ``today`` shifted by ``window_size // 2`` days and
        the window sent to the price index is ``window_size // 2``; presumably
        the price index treats the window as symmetric around the date, so the
        searched range covers roughly ``[today, today + window_size]`` —
        confirm against the price-index API.

        :param landing_routes: iterable of routes (hashable, with ``from_id``,
            ``to_id`` and ``national_version`` attributes).
        :param today: anchor date.
        :param window_size: full window width in days.
        :return: dict mapping each distinct route to its price-index entry,
            or to ``None`` when the price index returned nothing for it.
        """
        logger.info('Getting price for landing cities')
        half_window = window_size // 2
        request_date = today + timedelta(days=half_window)
        batch_by_national_version = defaultdict(list)
        price_by_route = {}
        seen_routes = set()
        total_processed = 0
        for route in landing_routes:
            # Skip duplicates: the result is keyed by route, so re-requesting
            # the same direction only costs extra price-index calls.
            if route in seen_routes:
                continue
            seen_routes.add(route)

            national_version = route.national_version
            batch = batch_by_national_version[national_version]
            batch.append(route)
            if len(batch) >= self._price_index_batch_size:
                self._process_batch(request_date, half_window, national_version, batch, price_by_route)
                total_processed += len(batch)
                logger.info('processed: %s', total_processed)
                batch_by_national_version[national_version] = []
        # Flush the partially filled batches left for each national version.
        for national_version, batch in batch_by_national_version.items():
            if batch:
                self._process_batch(request_date, half_window, national_version, batch, price_by_route)
                total_processed += len(batch)
                logger.info('processed: %s', total_processed)
        logger.info(
            '%d with prices out of %d routes',
            sum(bool(p and p.get('min_price')) for p in price_by_route.values()),
            len(price_by_route),
        )
        return price_by_route

    def _get_prices(self, request_date, window_size, national_version, batch):
        # type: (date, int, str, List[LandingRoute]) -> dict
        """Query the price index for the top direction of every route in ``batch``.

        :return: dict keyed by ``(from_id, to_id)`` with the raw response
            entries; empty dict (and an error log) when the price index
            returned nothing.
        """
        request = {
            'forward_date': request_date.isoformat(),
            'window_size': window_size,
            'results_per_direction': 1,
            'directions': [{'from_id': r.from_id, 'to_id': r.to_id} for r in batch],
        }
        prices = self._price_index_client.top_directions_by_date_window(
            national_version,
            request,
            self._price_index_timeout,
        )
        if not prices:
            logger.error(
                'no min prices. forward_date = %s, window_size = %s, directions = %r',
                request_date,
                window_size,
                request['directions'],
            )
            return {}
        return {(r['from_id'], r['to_id']): r for r in prices}
