# -*- coding: utf-8 -*-
import logging
from typing import List, Set, Optional, Dict  # noqa

from travel.avia.avia_statistics.routes_updater.lib.table import RoutesTable, Route  # noqa
from travel.avia.avia_statistics.landing_routes import LandingRoute

logger = logging.getLogger(__name__)


class RoutesUpdater(object):
    """Write landing routes into a routes table in fixed-size batches."""

    def __init__(
        self,
        landing_routes,
        routes_table,
        batch_size=1000,
    ):
        """
        :param Set[LandingRoute] landing_routes: routes to store; only iterated,
            so any iterable of LandingRoute works.
        :param RoutesTable routes_table: destination table; must provide
            ``create_if_doesnt_exist()`` and ``replace_batch(routes)``.
        :param int batch_size: number of routes sent per ``replace_batch`` call.
            A non-positive value never triggers a mid-loop flush, so all routes
            end up in a single final batch.
        """
        self._landing_routes = landing_routes
        self._routes_table = routes_table
        self._batch_size = batch_size

    def update(self):
        """Ensure the table exists, then store every landing route in batches.

        Logs the running total after each batch and the grand total at the end.
        """
        self._routes_table.create_if_doesnt_exist()
        batch = []
        total_processed = 0

        for route in self._landing_routes:
            batch.append(route)
            if len(batch) == self._batch_size:
                total_processed = self._flush(batch, total_processed)
                batch = []
        # Remainder that did not fill a complete batch.
        if batch:
            total_processed = self._flush(batch, total_processed)
        logger.info('all %s routes were stored into YDB', total_processed)

    def _flush(self, batch, total_processed):
        # type: (List[LandingRoute], int) -> int
        """Store one batch, log the new running total and return it."""
        self._process_batch(batch)
        total_processed += len(batch)
        logger.info('processed: %s', total_processed)
        return total_processed

    def _process_batch(self, batch):
        # type: (List[LandingRoute]) -> None
        """Convert landing routes to table rows and replace them in the table."""

        def map_to_route(route):
            # type: (LandingRoute) -> Route
            return Route(
                from_id=route.from_id,
                to_id=route.to_id,
                national_version=route.national_version,
            )

        # map_to_route never returns None as written; the falsy filter is kept
        # so that behavior stays identical should the mapping ever skip rows.
        routes = list(filter(None, map(map_to_route, batch)))
        self._routes_table.replace_batch(routes)
