import logging
from asyncio import gather

from travel.rasp.pathfinder_maps.const import TTYPE
from travel.rasp.pathfinder_maps.protos.joiners import join_results
from travel.rasp.pathfinder_maps.protos.utils import (add_link, get_arrival_time_from_geo_object, get_earliest_route,
                                                      get_fastest_route, get_empty_result, get_travel_time_from_route)
from travel.rasp.pathfinder_maps.services.polling_services.polling_cache import PollingCache

logger = logging.getLogger(__name__)


class PathJoiner:
    """Build a route through ``rll`` points, falling back to combined routes when maps cannot.

    ``join`` first asks the maps client for a route. Only when that reply contains no
    geo objects does it split the points by settlement, build combined routes with the
    maps or pathfinder joiner, filter and order them, and merge them into one response.
    """

    def __init__(
        self,
        maps_client,
        polling_cache: PollingCache,
        nearest_point_finder,
        rasp_search_link,
        maps_joiner,
        pathfinder_joiner,
        travel_time_coef
    ):
        """Store collaborators.

        :param maps_client: object whose ``route(rll, dtm=...)`` coroutine returns a maps answer.
        :param polling_cache: cache whose ``get_from_cache_or_init(route)`` is awaited for every
            route segment with a real thread id (see ``init_polling``).
        :param nearest_point_finder: resolves points to settlements
            (``get_settlement_from_geobase`` / ``get_settlement``).
        :param rasp_search_link: passed to ``add_link`` when the pathfinder joiner was used.
        :param maps_joiner: joiner used when the first part's settlement is unknown (``None``).
        :param pathfinder_joiner: joiner used otherwise.
        :param travel_time_coef: multiplier applied to the fastest ground route's travel time
            (in seconds) to get the cutoff for the remaining ground routes.
        """
        self._maps_client = maps_client
        self._polling_cache = polling_cache
        self._nearest_point_finder = nearest_point_finder
        self._rasp_search_link = rasp_search_link
        self._maps_joiner = maps_joiner
        self._pathfinder_joiner = pathfinder_joiner
        self._travel_time_coef = travel_time_coef

    async def _split_rll(self, rll):
        """Group consecutive points of ``rll`` that resolve to the same settlement.

        Every point is resolved concurrently via ``get_settlement_from_geobase``.
        Consecutive points with equal settlements (``None`` included) end up in one group.

        :returns: ``(groups_of_points, settlement_per_group)``; both lists have equal length.
            Returns ``([], [])`` for an empty ``rll``.
        """
        if not rll:
            return [], []

        settlements = await gather(
            *(self._nearest_point_finder.get_settlement_from_geobase(point) for point in rll)
        )
        last_settlement = settlements[0]
        rll_part = [rll[0]]
        res_points, res_settlements = [], []

        for settlement, point in zip(settlements[1:], rll[1:]):
            if last_settlement != settlement:
                # Settlement changed: close the current group and start a new one.
                res_points.append(rll_part)
                res_settlements.append(last_settlement)

                last_settlement = settlement
                rll_part = []
            rll_part.append(point)

        # Close the final group. Use last_settlement: the loop variable is unbound
        # when rll holds a single point.
        res_points.append(rll_part)
        res_settlements.append(last_settlement)

        return res_points, res_settlements

    async def _get_routes(self, splitted_rll, settlements, dtm, request_context):
        """Build candidate routes for the split points.

        If the first part has no settlement, the maps joiner is used. Otherwise, if the
        second part has no settlement, ``settlements[1]`` is replaced IN PLACE by the
        settlement nearest to that part's first point (``join`` later reads it for the
        link), and the pathfinder joiner is used.

        :returns: ``(routes_by_ttype, need_link)``. ``need_link`` is True only when the
            pathfinder joiner produced the routes.
        """
        if settlements[0] is None:
            return await self._maps_joiner.join(splitted_rll, settlements, dtm, request_context), False
        elif settlements[1] is None:
            settlements[1] = await self._nearest_point_finder.get_settlement(splitted_rll[1][0])
        return await self._pathfinder_joiner.join(splitted_rll, settlements, dtm, request_context), True

    def _filter_ground_routes(self, routes, earliest_ground_route, fastest_ground_route):
        """Keep the earliest and fastest ground routes plus any route that is fast enough.

        The fastest route is added only when its variant differs from the earliest one's.
        Every other route is kept when its travel time (seconds) is strictly below the
        fastest route's travel time multiplied by ``travel_time_coef``.
        """
        result = [earliest_ground_route]
        if earliest_ground_route[1] != fastest_ground_route[1]:
            result.append(fastest_ground_route)

        max_travel_time = get_travel_time_from_route(fastest_ground_route[0]).total_seconds() * self._travel_time_coef

        for route in routes:
            # Identity check: these two are already in ``result``.
            if route is earliest_ground_route or route is fastest_ground_route:
                continue
            route_travel_time = get_travel_time_from_route(route[0]).total_seconds()
            if route_travel_time < max_travel_time:
                result.append(route)
            else:
                logger.debug('{} filtered'.format(str(route[1])))
        return result

    def _filter_plane_routes(self, routes, earliest_ground_route, fastest_ground_route):
        """Select plane routes worth showing next to the ground routes.

        1. At most one "early" plane route: the first one, in input order, whose last geo
           object arrives strictly earlier than the earliest ground route's, AND whose
           count of even-indexed ``variant.routes`` entries plus 3 is still below the
           earliest ground route's count.
        2. Every other plane route whose travel time is strictly below the fastest ground
           route's travel time. Unlike the ground filter, no ``travel_time_coef`` is applied.
        """
        result = []

        earliest_plane_route = None
        earliest_ground_arrival = get_arrival_time_from_geo_object(earliest_ground_route[0].reply.geo_object[-1])
        # NOTE(review): routes[::2] presumably skips transfer entries between segments — confirm.
        earliest_ground_segments_len = len(earliest_ground_route[1].routes[::2])

        for route in routes:
            route_arrival = get_arrival_time_from_geo_object(route[0].reply.geo_object[-1])

            if route_arrival < earliest_ground_arrival and len(route[1].routes[::2]) + 3 < earliest_ground_segments_len:
                earliest_plane_route = route
                result.append(route)
                break

        max_travel_time = get_travel_time_from_route(fastest_ground_route[0]).total_seconds()
        for route in routes:
            # Routes are tuples, never None, so this is False when no early route was chosen.
            if route is earliest_plane_route:
                continue
            route_travel_time = get_travel_time_from_route(route[0]).total_seconds()
            if route_travel_time < max_travel_time:
                result.append(route)
            else:
                logger.debug('{} filtered'.format(str(route[1])))

        return result

    def _filter_routes(self, routes):
        """Deduplicate by variant, split into plane/ground routes and filter each group.

        A route counts as a plane route when any even-indexed entry of ``variant.routes``
        has ``TTYPE.aero`` as its ``thread_info[2]``. The first occurrence of each variant
        wins. If there are no ground routes, the plane routes are returned unfiltered.

        :param routes: iterable of ``(response, variant)`` pairs.
        :returns: filtered ground routes followed by filtered plane routes.
        """
        ground_routes, plane_routes = [], []
        ground_variants, plane_variants = set(), set()
        for route, variant in routes:
            is_plane = any(TTYPE(segment.thread_info[2]) is TTYPE.aero for segment in variant.routes[::2])
            if is_plane:
                if variant in plane_variants:
                    continue
                plane_routes.append((route, variant))
                plane_variants.add(variant)
            else:
                if variant in ground_variants:
                    continue
                ground_routes.append((route, variant))
                ground_variants.add(variant)

        if not ground_routes:
            return plane_routes

        earliest_ground_route = get_earliest_route(ground_routes)
        fastest_ground_route = get_fastest_route(ground_routes)
        ground_routes = self._filter_ground_routes(ground_routes, earliest_ground_route, fastest_ground_route)
        plane_routes = self._filter_plane_routes(plane_routes, earliest_ground_route, fastest_ground_route)

        return ground_routes + plane_routes

    def _range_routes(self, routes):
        """Order routes by alternately taking the earliest and then the fastest remaining one.

        "Earliest" uses ``get_arrival_time_from_geo_object`` on ``reply.geo_object[0]``;
        "fastest" uses ``get_travel_time_from_route``. Sorting is stable, so ties keep the
        order left by the previous sort. Works on a copy; the input list is not modified.
        """
        remaining = list(routes)
        res = []
        while remaining:
            # NOTE(review): _filter_plane_routes reads arrival from geo_object[-1];
            # confirm geo_object[0] is intended here.
            remaining.sort(key=lambda x: get_arrival_time_from_geo_object(x[0].reply.geo_object[0]))
            res.append(remaining.pop(0))

            if not remaining:
                break
            remaining.sort(key=lambda x: get_travel_time_from_route(x[0]))
            res.append(remaining.pop(0))

        return res

    async def init_polling(self, variants):
        """Concurrently warm the polling cache for every route of every variant.

        Routes whose ``thread_id`` is the string ``'NULL'`` are skipped.
        """
        routes_to_poll = [
            route
            for variant in variants
            for route in variant.routes
            if route.thread_id != 'NULL'
        ]
        logger.debug(f'Polling keys: {", ".join(route.polling_key for route in routes_to_poll)}')
        await gather(*[self._polling_cache.get_from_cache_or_init(route) for route in routes_to_poll])

    async def join(self, rll, dtm, request_context):
        """Return a maps route for ``rll``, or a combined route when maps has none.

        Flow:
        * If the maps answer has any geo objects, mark the context and return it as is.
        * Otherwise split ``rll`` by settlement; anything other than exactly two parts
          yields an empty result.
        * Build, filter and order candidate routes. If none remain, mark the context as
          failed and return an empty result.
        * Register each distinct variant in the context, warm the polling cache, merge the
          responses, and add a search link when the pathfinder joiner was used.
        """
        maps_answer = await self._maps_client.route(rll, dtm=dtm)
        if len(maps_answer.reply.geo_object):
            request_context.set_maps_answer()
            return maps_answer

        splitted_rll, settlements = await self._split_rll(rll)
        if len(splitted_rll) != 2:  # only one transfer between settlements is supported (exactly two parts)
            return get_empty_result()

        routes_by_ttype, need_link = await self._get_routes(splitted_rll, settlements, dtm, request_context)

        result = [route for routes in routes_by_ttype for route in routes]

        logger.debug('Builded variants: {}'.format([str(x[1]) for x in result]))
        result = self._filter_routes(result)
        logger.debug('After filter: {}'.format([str(x[1]) for x in result]))
        result = self._range_routes(result)
        logger.debug('After ranging: {}'.format([str(x[1]) for x in result]))

        if not result:
            request_context.set_failed()
            return get_empty_result()

        for variant in {x[1] for x in result}:
            request_context.add_variant(variant)

        await self.init_polling([x[1] for x in result])

        response = join_results([x[0] for x in result])
        if need_link:
            # settlements[1] may have been filled in by _get_routes.
            # NOTE(review): index 1 of a settlement presumably is its id — confirm.
            add_link(response, self._rasp_search_link, settlements[0][1], settlements[1][1], dtm)
        return response
