import logging
from asyncio import gather

from travel.rasp.pathfinder_maps.const import TTYPE
from travel.rasp.pathfinder_maps.services.joiners.base_joiner import BaseJoiner

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)


class MapsJoiner(BaseJoiner):
    """Joiner that connects two route parts via searches between stations.

    Stations near the end of the first part and near the start of the second
    part are paired per transport type, and every pair is searched through
    ``self._morda_backend_service``.
    """

    # Key: transport type whose stations are paired up in ``join``.
    # Value: transport types passed to the restricted search for such a pair.
    PATHFINDER_TTYPES = {
        TTYPE.train: (TTYPE.train, TTYPE.bus, TTYPE.suburban),
        TTYPE.bus: (TTYPE.bus, TTYPE.suburban, TTYPE.train)
    }

    async def join(self, splitted_rll, settlements, dtm, request_context):
        """Search for variants between the two route parts and build results.

        Args:
            splitted_rll: Sequence of at least two point sequences. The last
                point of the first and the first point of the second are used
                to look up stations.
            settlements: Indexable; when ``settlements[1]`` is not ``None`` it
                is used as the only destination point instead of a station
                lookup.
            dtm: Value sent as ``'date'`` in every search query and forwarded
                to the departure-time and variant-building helpers.
            request_context: Receives ``maps_first()`` and
                ``set_pathfinder_stations(...)`` and is forwarded to helpers.

        Returns:
            A list with the non-``None`` results of ``self._build_variants``,
            one candidate per group of filtered variants.
        """
        request_context.maps_first()
        origin_point, destination_point = splitted_rll[0][-1], splitted_rll[1][0]
        # Only the second element of the lookup result is needed here.
        _, first_points = await self._nearest_point_finder.get_stations_by_ttype(origin_point)

        if settlements[1] is None:
            _, second_points = await self._nearest_point_finder.get_stations_by_ttype(destination_point)
        else:
            second_points = {ttype: [settlements[1]] for ttype in (TTYPE.aero, TTYPE.train, TTYPE.bus)}
        request_context.set_pathfinder_stations(first_points, second_points)

        # Two searches per station pair, always appended in the same order:
        # first the one restricted to ``allowed_ttypes``, then the unrestricted one.
        searches = []
        for ttype, allowed_ttypes in self.PATHFINDER_TTYPES.items():
            for to_point in second_points[ttype]:
                for from_point in first_points[ttype]:
                    query = {
                        'from_type': from_point[0],
                        'from_id': from_point[1],
                        'to_type': to_point[0],
                        'to_id': to_point[1],
                        'date': dtm
                    }
                    searches.append(self._morda_backend_service.search(query, allowed_ttypes))
                    searches.append(self._morda_backend_service.search(query))
        search_results = await gather(*searches)

        variants = []
        # Even indexes hold restricted results, odd indexes the unrestricted ones.
        for restricted_variants, all_variants in zip(search_results[::2], search_results[1::2]):
            variants.append(restricted_variants)
            # Keep only unrestricted variants not already found by the
            # restricted search, without duplicates and in their original
            # order, so the output does not depend on set iteration order.
            seen = set(restricted_variants)
            extra_variants = []
            for variant in all_variants:
                if variant not in seen:
                    seen.add(variant)
                    extra_variants.append(variant)
            variants.append(extra_variants)

        variants = self._protobuf_builder.filter_variants(variants)

        departure_times = await self._get_departure_times(variants, splitted_rll, dtm, request_context)
        arrival_stations = await self._get_arrival_stations(variants, splitted_rll)
        built = await gather(*[
            self._build_variants(splitted_rll, ttype_variants, departure_times, arrival_stations, dtm, request_context)
            for ttype_variants in variants
        ])
        # Groups for which nothing could be built come back as None.
        return [item for item in built if item is not None]
