# -*- coding: utf-8 -*-
import logging
from typing import List

from yt.wrapper import YtClient

from travel.avia.library.python.backend_client import BackendClient
from travel.avia.avia_statistics.updaters.city_to.nearest_cities_updater.lib.table import (
    CityToNearestCities, CityToNearestCitiesTable
)
from travel.avia.avia_statistics.landing_cities import LandingCity

logger = logging.getLogger(__name__)


class CityToNearestCitiesUpdater(object):
    """Store, for every landing city, the ids of nearby landing cities into YDB.

    For each landing city the avia backend is asked for nearby settlements
    (``nearest_settlements`` with ``max_distance=default_max_distance_km``).
    Only settlements that are themselves landing cities (by ``to_id``) are
    kept. The resulting records are written to ``city_to_nearest_cities_table``
    via ``replace_batch`` in batches of ``ydb_batch_size`` records.
    """

    # Emit a progress log line after every this many landing cities are queried.
    _PROGRESS_LOG_EVERY = 50

    def __init__(
            self,
            yt_client: YtClient,
            landing_cities: List[LandingCity],
            city_to_nearest_cities_table: CityToNearestCitiesTable,
            avia_backend_client: BackendClient,
            ydb_batch_size: int = 1000,
            default_max_distance_km: int = 200,
    ):
        """
        :param yt_client: YT client; stored but not used by the code of this class.
        :param landing_cities: landing cities to compute nearest cities for.
        :param city_to_nearest_cities_table: destination table for the records.
        :param avia_backend_client: client used to query nearest settlements.
        :param ydb_batch_size: number of records per ``replace_batch`` call.
            A non-positive value makes all records go out in a single final batch.
        :param default_max_distance_km: search radius forwarded to the backend
            as ``max_distance``.
        """
        self._yt_client = yt_client
        self._landing_cities = landing_cities
        self._city_to_nearest_cities_table = city_to_nearest_cities_table
        self._avia_backend_client = avia_backend_client
        self._ydb_batch_size = ydb_batch_size
        self._default_max_distance_km = default_max_distance_km

    def update(self) -> None:
        """Create the table if needed, compute all records and write them in batches."""
        self._city_to_nearest_cities_table.create_if_doesnt_exist()

        nearest_cities = self._get_nearest_cities()

        batch = []
        total_processed = 0
        logger.info('start writing nearest cities for city-to-landing')
        for record in nearest_cities:
            batch.append(record)
            # `==` (not `>=`) is intentional: with a non-positive batch size this
            # never fires and everything is written by the trailing flush below.
            if len(batch) == self._ydb_batch_size:
                total_processed = self._flush_batch(batch, total_processed)
                batch = []
        if batch:
            total_processed = self._flush_batch(batch, total_processed)
        logger.info('all %s nearest cities for city-to-landing were stored into YDB', total_processed)

    def _flush_batch(self, batch: List[CityToNearestCities], total_processed: int) -> int:
        """Write ``batch``, log the running total and return the updated total."""
        self._process_batch(batch)
        total_processed += len(batch)
        logger.info('processed: %s', total_processed)
        return total_processed

    def _process_batch(self, batch: List[CityToNearestCities]) -> None:
        """Write one batch of records into the destination table."""
        self._city_to_nearest_cities_table.replace_batch(batch)

    def _get_nearest_cities(self) -> List[CityToNearestCities]:
        """Query the backend for every landing city and build one record per city.

        ``nearest_city_ids`` of each record is a comma-separated string of the
        nearby settlement ids that are also landing-city ``to_id`` values, in
        the order returned by the backend.
        """
        # Keyed by the LandingCity object itself (so it must be hashable);
        # entries that compare equal collapse into a single record.
        nearest_city_ids_by_city_to = {}

        size = len(self._landing_cities)
        all_city_to_ids = {c.to_id for c in self._landing_cities}
        # `done` counts cities already queried, so progress lines are accurate
        # and the final "size/size" line is always emitted.
        for done, city in enumerate(self._landing_cities, start=1):
            nearest_city_ids = self._avia_backend_client.nearest_settlements(
                city.to_id,
                max_distance=self._default_max_distance_km,
            )
            nearest_city_ids_by_city_to[city] = [id_ for id_ in nearest_city_ids if id_ in all_city_to_ids]
            if done % self._PROGRESS_LOG_EVERY == 0 or done == size:
                logger.info('done get nearest cities for %s/%s cities', done, size)

        return [
            CityToNearestCities(
                to_id=city.to_id,
                national_version=city.national_version,
                nearest_city_ids=','.join(map(str, nearest_ids)),
            )
            for city, nearest_ids in nearest_city_ids_by_city_to.items()
        ]
