# -*- coding: utf-8 -*-
import logging
import json
import time
from datetime import date, timedelta
from collections import defaultdict
from typing import List, Optional

import numpy as np

from yt.wrapper import YtClient  # noqa

from travel.avia.avia_statistics.route_crosslinks_updater.lib.table import RouteCrosslinksTable, RouteCrosslink  # noqa
from travel.avia.avia_statistics.route_crosslinks_updater.lib.crosslinks_provider import CrosslinksProvider, RoutesDict
from travel.avia.avia_statistics.lib.settlements_geo_index import SettlementsGeoIndex  # noqa
from travel.avia.avia_statistics.landing_routes import LandingRoute
from travel.avia.avia_statistics.services.price_index.client import Client as PriceIndexClient  # noqa
from travel.library.python.dicts.avia.near_direction_repository import NearDirectionRepository  # noqa

# Module-level logger, named after this module; used for progress and error
# reporting by RouteCrosslinksUpdater.
logger = logging.getLogger(__name__)


class RouteCrosslinksUpdater(object):
    def __init__(
        self,
        yt_client,
        landing_routes,
        route_weights,
        near_direction_repository,
        settlements_geo_index,
        route_crosslinks_table,
        price_index_client,
        price_index_batch_size=100,
        price_index_timeout=5,
        default_distance=200,
        ydb_batch_size=1000,
        throttle_time_ms=0,
    ):
        """
        Remember the collaborators and tuning parameters of the updater.

        The constructor only assigns attributes; all work happens in ``update``.

        :param YtClient yt_client: forwarded to ``CrosslinksProvider``.
        :param List[LandingRoute] landing_routes: routes whose prices are
            requested and which are forwarded to ``CrosslinksProvider``.
        :param Dict[LandingRoute, RouteWeight] route_weights: forwarded to
            ``CrosslinksProvider``.
        :param NearDirectionRepository near_direction_repository: forwarded to
            ``CrosslinksProvider``.
        :param SettlementsGeoIndex settlements_geo_index: forwarded to
            ``CrosslinksProvider``.
        :param RouteCrosslinksTable route_crosslinks_table: destination table;
            created if missing and filled through ``replace_batch``.
        :param PriceIndexClient price_index_client: client queried through
            ``top_directions_by_date_window``.
        :param int price_index_batch_size: number of routes (per national
            version) that triggers one price index request.
        :param int price_index_timeout: timeout value passed as-is to the price
            index client.
        :param int default_distance: forwarded to ``CrosslinksProvider``.
        :param int ydb_batch_size: number of rows accumulated before one
            ``replace_batch`` call.
        :param int throttle_time_ms: pause, in milliseconds, made after every
            full price index batch; a falsy value disables the pause.
        """
        # Inputs of the crosslinks provider.
        self._yt_client = yt_client
        self._landing_routes = landing_routes
        self._route_weights = route_weights
        self._near_direction_repository = near_direction_repository
        self._settlements_geo_index = settlements_geo_index
        self._default_distance = default_distance

        # Price index access.
        self._price_index_client = price_index_client
        self._price_index_batch_size = price_index_batch_size
        self._price_index_timeout = price_index_timeout
        self._throttle_time_ms = throttle_time_ms

        # Output table.
        self._route_crosslinks_table = route_crosslinks_table
        self._ydb_batch_size = ydb_batch_size

    def update(
        self,
        today,
        window_size=30,
        save_crosslinks_to_file=False,
        yt_crosslinks_table='//home/avia/avia-statistics/landing-route-crosslinks',
    ):
        # type: (date, int, bool, Optional[str]) -> None
        """
        Rebuild the route crosslinks table.

        Steps, in order: seed the global numpy RNG, fetch prices for the
        landing routes, obtain crosslinks (from YT or generated), optionally
        dump them to ``crosslinks.json``, make sure the table exists and write
        one ``RouteCrosslink`` row per (route, crosslink) pair in batches.

        :param date today: base date passed to ``_get_price_by_route``.
        :param int window_size: date window passed to ``_get_price_by_route``.
        :param bool save_crosslinks_to_file: when true, the crosslinks are also
            written as JSON to ``crosslinks.json`` in the working directory.
        :param Optional[str] yt_crosslinks_table: YT table to read crosslinks
            from; ``None`` makes the provider generate them instead.
        """
        # NOTE(review): presumably seeds the RNG so that crosslink generation
        # is reproducible between runs -- confirm against CrosslinksProvider.
        np.random.seed(1)

        prices = self._get_price_by_route(today, window_size)

        provider = CrosslinksProvider(
            self._landing_routes,
            self._route_weights,
            self._default_distance,
            self._near_direction_repository,
            self._settlements_geo_index,
            self._yt_client,
        )
        if yt_crosslinks_table is None:
            crosslinks = provider.generate()
        else:
            logger.info('Reading crosslinks from YT table')
            crosslinks = provider.read_from_yt(yt_crosslinks_table)

        if save_crosslinks_to_file:
            with open('crosslinks.json', 'w') as dump_file:
                json.dump(RoutesDict.as_json(crosslinks), dump_file)

        table = self._route_crosslinks_table
        table.create_if_doesnt_exist()

        pending_rows = []
        total = 0
        for route in crosslinks:
            for position, crosslink in enumerate(crosslinks[route]):
                # Price fields stay empty unless the price index returned a
                # non-empty ``min_price`` for this crosslink.
                price = prices.get(crosslink)
                min_price = price.get('min_price') if price else None
                if min_price:
                    price_value = min_price['value']
                    currency = min_price['currency']
                    forward_date = price['forward_date']
                else:
                    price_value = currency = forward_date = None

                pending_rows.append(RouteCrosslink(
                    route.from_id,
                    route.to_id,
                    route.national_version,
                    position,
                    crosslink.from_id,
                    crosslink.to_id,
                    price_value,
                    currency,
                    forward_date,
                ))
                total += 1

                # Flush as soon as a full batch has been collected.
                if len(pending_rows) == self._ydb_batch_size:
                    table.replace_batch(pending_rows)
                    pending_rows = []
                if total % 1000 == 0:
                    logger.info('Processed %d crosslinks', total)

        # Write whatever is left after the last full batch.
        if pending_rows:
            table.replace_batch(pending_rows)
        logger.info('All %d crosslinks have been processed', total)

    def _process_batch(self, request_date, window_size, national_version, batch, price_by_route):
        """
        Fetch prices for ``batch`` and record them in ``price_by_route``.

        Every route of the batch gets an entry: the price index row found for
        its ``(from_id, to_id)`` pair, or ``None`` when no row came back.
        ``price_by_route`` is modified in place; nothing is returned.
        """
        found = self._get_prices(request_date, window_size, national_version, batch)
        for item in batch:
            direction = (item.from_id, item.to_id)
            price_by_route[item] = found.get(direction)

    def _get_price_by_route(self, today, window_size):
        """
        Query the price index for all landing routes.

        ``window_size`` is halved; the request date is ``today`` plus that
        half window, and the half window is what gets sent to the price index.
        Routes are grouped by national version and requested as soon as a
        group reaches ``price_index_batch_size``; incomplete groups are
        requested at the end.

        :return: dict mapping every landing route to its price index row, or
            to ``None`` when no row was returned for the route's direction.
        """
        logger.info('Getting price for landing routes')
        half_window = window_size // 2
        request_date = today + timedelta(days=half_window)

        price_by_route = {}
        pending = defaultdict(list)
        processed_count = 0

        for route in self._landing_routes:
            version = route.national_version
            bucket = pending[version]
            bucket.append(route)
            if len(bucket) < self._price_index_batch_size:
                continue

            self._process_batch(request_date, half_window, version, bucket, price_by_route)
            # NOTE(review): the pause is made only here, after full batches;
            # the leftover batches below are sent without throttling.
            if self._throttle_time_ms:
                time.sleep(self._throttle_time_ms / 1000.0)
            processed_count += len(bucket)
            logger.info('processed: %s', processed_count)
            pending[version] = []

        # Flush the groups that never reached the batch size.
        for version, leftover in pending.items():
            if not leftover:
                continue
            self._process_batch(request_date, half_window, version, leftover, price_by_route)
            processed_count += len(leftover)
            logger.info('processed: %s', processed_count)

        priced_count = 0
        for price in price_by_route.values():
            if price and price.get('min_price'):
                priced_count += 1
        logger.info(
            '%d with prices out of %d routes',
            priced_count,
            len(price_by_route),
        )
        return price_by_route

    def _get_prices(self, request_date, window_size, national_version, batch):
        # type: (date, int, str, List[LandingRoute]) -> dict
        """
        Request one price index result per direction of ``batch``.

        :param date request_date: sent as ``forward_date`` in ISO format.
        :param int window_size: sent as ``window_size``.
        :param str national_version: passed to the price index client.
        :param List[LandingRoute] batch: routes whose ``from_id``/``to_id``
            form the requested directions.
        :return: dict mapping ``(from_id, to_id)`` to the returned row; an
            empty dict (after logging an error) when the client returns
            nothing.
        """
        request = {
            'forward_date': request_date.isoformat(),
            'window_size': window_size,
            'results_per_direction': 1,
        }
        request['directions'] = [
            {'from_id': route.from_id, 'to_id': route.to_id} for route in batch
        ]

        response = self._price_index_client.top_directions_by_date_window(
            national_version,
            request,
            self._price_index_timeout,
        )
        if response:
            rows_by_direction = {}
            for row in response:
                rows_by_direction[(row['from_id'], row['to_id'])] = row
            return rows_by_direction

        logger.error(
            'no min prices. forward_date = %s, window_size = %s, directions = %r',
            request_date,
            window_size,
            request['directions'],
        )
        return {}
