# -*- coding: utf-8 -*-
import logging
from typing import List, Set, Optional, Dict  # noqa
from collections import defaultdict

import six

from travel.avia.avia_statistics.route_infos_updater.lib.table import (  # noqa
    RouteInfosTable, RouteInfo
)
from travel.avia.avia_statistics.lib.settlements_geo_index import SettlementsGeoIndex  # noqa
from travel.avia.avia_statistics.landing_routes import LandingRoute
from travel.avia.library.python.shared_flights_client.client import SharedFlightsClient  # noqa
from travel.library.python.dicts.avia.station_repository import StationRepository  # noqa
from travel.avia.library.python.references.station_to_settlement import StationToSettlementCache  # noqa
from travel.avia.library.python.references.station import StationCache  # noqa

# Module-level logger; RouteInfosUpdater.update() reports batch progress through it.
logger = logging.getLogger(__name__)


class RouteInfosUpdater(object):
    """
    Builds a RouteInfo for every landing route and writes them to the
    route infos table in batches.

    For each route the updater collects the airports attached to the
    departure and arrival settlements, asks the settlements geo index for the
    distance and asks the shared flights client for the minimal duration.
    """

    # Transport type id that marks a station as a plane station.
    _PLANE_TTYPE_ID = 2
    # Stations whose MajorityId is greater than this value are ignored.
    # NOTE(review): presumably a lower MajorityId means a more significant
    # station - confirm against the station dictionary.
    _MAX_MAJORITY_ID = 2

    def __init__(
        self,
        landing_routes,
        station_repository,
        station_to_settlement_cache,
        station_cache,
        settlements_geo_index,
        route_infos_table,
        shared_flights_client,
        batch_size=1000,
    ):
        """
        :param Set[LandingRoute] landing_routes: routes to build infos for
        :param StationRepository station_repository: source of station records
        :param StationToSettlementCache station_to_settlement_cache:
            fallback station -> settlement mapping
        :param StationCache station_cache: station id -> station code lookup
        :param SettlementsGeoIndex settlements_geo_index: distance provider
        :param RouteInfosTable route_infos_table: destination table
        :param SharedFlightsClient shared_flights_client: duration provider
        :param int batch_size: number of routes passed to one table write
        """
        self._landing_routes = landing_routes
        self._station_repository = station_repository
        self._station_to_settlement_cache = station_to_settlement_cache
        self._station_cache = station_cache
        self._settlements_geo_index = settlements_geo_index
        self._route_infos_table = route_infos_table
        self._shared_flights_client = shared_flights_client
        self._batch_size = batch_size

    def update(self):
        """
        Create the table if needed and store route infos for all landing
        routes, ``batch_size`` routes at a time.

        Progress is logged after every batch.
        """
        self._route_infos_table.create_if_doesnt_exist()
        station_ids_by_settlement_id = self._build_station_ids_by_settlement_id()

        batch = []
        total_processed = 0
        for route in self._landing_routes:
            batch.append(route)
            # ``>=`` rather than ``==`` so a batch can never silently outgrow
            # the configured size.
            if len(batch) >= self._batch_size:
                self._process_batch(batch, station_ids_by_settlement_id)
                total_processed += len(batch)
                logger.info('processed: %s', total_processed)
                batch = []
        if batch:
            # Flush the incomplete tail.
            self._process_batch(batch, station_ids_by_settlement_id)
            total_processed += len(batch)
            logger.info('processed: %s', total_processed)
        logger.info('all %s route infos were stored into YDB', total_processed)

    def _build_station_ids_by_settlement_id(self):
        # type: () -> Dict[int, List[int]]
        """
        Group ids of suitable plane stations by the settlement they belong to.

        A station is taken into account when it is not hidden, has the plane
        transport type, has a non-zero MajorityId not greater than
        ``_MAX_MAJORITY_ID`` and its settlement id can be resolved.
        """
        settlement_id_by_station_id = {}
        for station in self._station_repository.itervalues():
            if not self._is_suitable_airport(station):
                continue
            # Resolve the settlement once per station and only for stations
            # that passed the cheap attribute checks.
            settlement_id = self._get_settlement_id(station.Id)
            if settlement_id is not None:
                settlement_id_by_station_id[station.Id] = settlement_id

        station_ids_by_settlement_id = defaultdict(list)
        for station_id, settlement_id in six.iteritems(settlement_id_by_station_id):
            station_ids_by_settlement_id[settlement_id].append(station_id)
        return station_ids_by_settlement_id

    def _is_suitable_airport(self, station):
        """Tell whether a station record passes the attribute filters."""
        return bool(
            station.MajorityId and
            station.MajorityId <= self._MAX_MAJORITY_ID and
            station.TTypeId == self._PLANE_TTYPE_ID and
            not station.Hidden
        )

    def _process_batch(self, batch, station_ids_by_settlement_id):
        # type: (List[LandingRoute], Dict[int, List[int]]) -> None
        """
        Convert a batch of routes into route infos and write them to the table.

        Falsy conversion results are dropped before the write.
        """
        candidates = (
            self._build_route_info(route, station_ids_by_settlement_id)
            for route in batch
        )
        route_infos = [info for info in candidates if info]
        self._route_infos_table.replace_batch(route_infos)

    def _build_route_info(self, route, station_ids_by_settlement_id):
        # type: (LandingRoute, Dict[int, List[int]]) -> Optional[RouteInfo]
        """Assemble the RouteInfo for a single landing route."""
        # ``.get`` (not indexing) so that lookups of unknown settlements do
        # not insert empty entries into the defaultdict.
        from_airports = station_ids_by_settlement_id.get(route.from_id, [])
        to_airports = station_ids_by_settlement_id.get(route.to_id, [])
        return RouteInfo(
            from_id=route.from_id,
            to_id=route.to_id,
            distance=self._settlements_geo_index.get_distance(route.from_id, route.to_id),
            duration=self._get_duration(from_airports, to_airports),
            from_airports=from_airports,
            to_airports=to_airports,
        )

    def _get_settlement_id(self, staion_id):
        """
        Resolve the settlement id of a station.

        The SettlementId of the station record wins; when it is falsy the
        station-to-settlement cache is consulted with
        ``raise_on_unknown=False``.
        """
        own_settlement_id = self._station_repository.get(staion_id).SettlementId
        if own_settlement_id:
            return own_settlement_id
        return self._station_to_settlement_cache.settlement_id_by_id(
            staion_id, raise_on_unknown=False
        )

    def _get_station_codes(self, station_ids):
        # type: (List[int]) -> List[str]
        """
        Return the truthy station codes of the given station ids.

        A real list is returned on purpose: on Python 3 ``filter`` yields a
        lazy iterator that is always truthy, which would defeat the emptiness
        checks in the caller.
        """
        codes = (
            self._station_cache.station_code_by_id(station_id, raise_on_unknown=False)
            for station_id in station_ids
        )
        return [code for code in codes if code]

    def _get_duration(self, from_airports, to_airports):
        # type: (List[int], List[int]) -> Optional[int]
        """
        Find the minimal flight duration between two groups of airports.

        :param List[int] from_airports: departure station ids
        :param List[int] to_airports: arrival station ids
        :return: 0 when either side has no station codes (the client is not
            called then), None when no segment carries a truthy
            ``min_duration``, otherwise the smallest ``min_duration``
        """
        from_codes = self._get_station_codes(from_airports)
        to_codes = self._get_station_codes(to_airports)
        if not from_codes or not to_codes:
            return 0
        segments_info = self._shared_flights_client.flight_p2p_segment_info(from_codes, to_codes)
        durations = [s['min_duration'] for s in segments_info if s['min_duration']]
        if not durations:
            return None
        return min(durations)
