import logging
from asyncio import gather

from travel.rasp.pathfinder_maps.const import TTYPE, WAIT_TIMES
from travel.rasp.pathfinder_maps.protos.builders import build_wait_section, build_pathfinder_protobuf
from travel.rasp.pathfinder_maps.protos.joiners import join_protos
from travel.rasp.pathfinder_maps.protos.utils import (
    get_arrival_time_from_geo_object, get_departure_time_from_geo_object, inject_wait_section_before_start
)
from travel.rasp.pathfinder_maps.utils import RoutePoint


# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)


class BaseJoiner:
    """Base class for joining a pathfinder variant with maps routes on both of its ends.

    A joined route is made of three protobuf parts glued together by ``join_protos``:
    a maps route from the first group of request points to the variant's departure
    station, the pathfinder protobuf built from the variant (starting with a wait
    section), and a maps route from the variant's arrival station to the last group
    of request points.

    ``splitted_rll`` is expected to be a two-element sequence of point lists: the
    points before and after the pathfinder part (NOTE(review): inferred from the
    ``[0]`` / ``[1]`` / ``[-1]`` indexing below -- confirm against callers).

    Subclasses must implement :meth:`join`.
    """

    def __init__(self, maps_client, morda_backend_service, nearest_point_finder, protobuf_builder):
        # morda_backend_service and nearest_point_finder are not used in this class;
        # presumably they are consumed by subclasses.
        self._maps_client = maps_client
        self._morda_backend_service = morda_backend_service
        self._nearest_point_finder = nearest_point_finder
        self._protobuf_builder = protobuf_builder

    async def join(self, splitted_rll, settlements, dtm, request_context):
        """Build joined routes for a request; must be implemented by subclasses."""
        raise NotImplementedError()

    @staticmethod
    def _station_point(station):
        """Return the RoutePoint of a station dict holding 'longitude' and 'latitude' keys."""
        return RoutePoint(lon=station['longitude'], lat=station['latitude'])

    @staticmethod
    def _departure_ttype(variant):
        """Return the TTYPE of the variant's first route (taken from its ``thread_info[2]``)."""
        return TTYPE(variant.routes[0].thread_info[2])

    @staticmethod
    def _has_geo_objects(maps_answer):
        """Return True if the maps answer contains at least one geo object."""
        return len(maps_answer.reply.geo_object) > 0

    async def _build_variants(self, splitted_rll, variants, departure_times, arrival_stations, dtm, request_context):
        """Build joined routes for the best candidate variants.

        Variants are tried by earliest arrival, ties broken by earliest departure.
        The first variant that can be built is kept whatever its number of routes.
        After that only single-route variants are considered, and the search stops
        as soon as a single-route variant has been built, so the result holds one
        or two entries.

        Note: ``variants`` is sorted in place.

        Args:
            splitted_rll: request points split around the pathfinder part.
            variants: candidate variants (objects exposing ``routes``).
            departure_times: RoutePoint -> earliest allowed departure from that station
                (as returned by :meth:`_get_departure_times`); falsy values skip the station.
            arrival_stations: container of arrival RoutePoints reachable by maps
                (as returned by :meth:`_get_arrival_stations`).
            dtm: request departure time.
            request_context: receives failure reasons via ``set_fail_reason``.

        Returns:
            A list of ``(joined_protobuf, variant)`` tuples, or None if nothing was built.
        """
        if not variants:
            return None

        variants.sort(key=lambda variant: (variant.routes[-1].arrival_datetime, variant.routes[0].departure_datetime))
        # Avoid stringifying every variant unless debug logging is actually enabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug('Trying to build: %s', [str(variant) for variant in variants])

        result = []
        for variant in variants:
            is_single_route = len(variant.routes) == 1
            # Once something has been built, only a single-route variant may still be added.
            if result and not is_single_route:
                continue

            departure_station = self._station_point(variant.routes[0].departure_station)
            pathfinder_departure_time = departure_times.get(departure_station)
            if not pathfinder_departure_time:
                continue

            if self._station_point(variant.routes[-1].arrival_station) not in arrival_stations:
                continue

            built_variant = await self._build_route(
                variant, splitted_rll, departure_station, pathfinder_departure_time,
                self._departure_ttype(variant), dtm, request_context,
            )
            if not built_variant:
                continue

            inject_wait_section_before_start(built_variant, dtm)
            result.append((built_variant, variant))
            if is_single_route:
                break

        return result or None

    async def _build_route(self, variant, splitted_rll, station_coord, pathfinder_departure_time,
                           ttype, dtm, request_context):
        """Join the first maps route, the variant's pathfinder part and the last maps route.

        Returns the joined protobuf, or None if either the pathfinder segment or the
        last maps route could not be built (a fail reason is then recorded on
        ``request_context``).
        """
        first_maps_points = splitted_rll[0] + [station_coord]

        pathfinder_segment = await self._get_middle_pathfinder_segment(
            variant, pathfinder_departure_time, first_maps_points, dtm, request_context
        )
        if pathfinder_segment is None:
            return None

        pathfinder_protobufs, first_maps_answer = pathfinder_segment

        last_maps_points = [self._station_point(variant.routes[-1].arrival_station)] + splitted_rll[1]
        last_maps_answer = await self._get_last_maps_answer(last_maps_points, variant, ttype, request_context)
        if not last_maps_answer:
            return None

        return join_protos([first_maps_answer, pathfinder_protobufs, last_maps_answer])

    async def _get_last_maps_answer(self, points, variant, ttype, request_context):
        """Request the maps route through ``points`` departing when the variant arrives.

        Returns the maps answer, or None (recording a fail reason) if it has no geo objects.
        """
        last_maps_answer = await self._maps_client.route(points, dtm=variant.routes[-1].arrival_datetime)
        if not self._has_geo_objects(last_maps_answer):
            request_context.set_fail_reason('Empty last_maps_answer for {}: {}'.format(
                ttype.name, ' -> '.join(x.as_string for x in points)
            ))
            return None
        return last_maps_answer

    async def _get_middle_pathfinder_segment(self, variant, pathfinder_departure_time, points, dtm, request_context):
        """Wrap :meth:`_get_pathfinder_segment`, recording a fail reason when it returns None."""
        pathfinder_segment = await self._get_pathfinder_segment(variant, pathfinder_departure_time, points, dtm)
        if pathfinder_segment is None:
            # NOTE(review): only every other route's thread_id is reported -- presumably the
            # skipped routes are transfers; confirm.
            request_context.set_fail_reason('No pathfinder_segment for {}'.format(
                ' -> '.join(route.thread_id for route in variant.routes[::2])
            ))
            return None
        return pathfinder_segment

    async def _get_pathfinder_departure_time(self, points, dtm, ttype, request_context):
        """Return the earliest time the pathfinder part may depart from the last of ``points``.

        This is the earliest maps arrival time (departing at ``dtm``) plus ``WAIT_TIMES[ttype]``.
        Returns None, recording a fail reason, if maps returned no geo objects.
        """
        first_maps_answer = await self._maps_client.route(points, dtm=dtm)
        if not self._has_geo_objects(first_maps_answer):
            request_context.set_fail_reason('Empty first_maps_answer for {}: {}'.format(
                ttype.name, ' -> '.join(x.as_string for x in points)
            ))
            return None
        earliest_segment = min(first_maps_answer.reply.geo_object, key=get_arrival_time_from_geo_object)
        return get_arrival_time_from_geo_object(earliest_segment) + WAIT_TIMES[ttype]

    async def _get_pathfinder_segment(self, variant, pathfinder_departure_time, maps_coord, dtm):
        """Build the pathfinder protobuf for ``variant`` plus the maps route leading to it.

        Returns ``(pathfinder_protobuf, first_maps_answer)``, or None when:
        the variant departs no later than ``pathfinder_departure_time``; the protobuf
        builder returns None; maps returns no geo objects; or the maps route would
        have to depart before ``dtm``.
        """
        variant_dt = variant.routes[0].departure_datetime
        if variant_dt <= pathfinder_departure_time:
            return None

        pathfinder_protobuf = await self._protobuf_builder.build_pathfinder_protobuf(variant)
        if pathfinder_protobuf is None:
            return None
        points, sections = pathfinder_protobuf

        wait_time = WAIT_TIMES[self._departure_ttype(variant)]
        # `atm` presumably asks maps for a route arriving by that time -- i.e. wait_time
        # before the variant departs.
        first_maps_answer = await self._maps_client.route(maps_coord, atm=variant_dt - wait_time)
        # Guard against an empty answer instead of failing with IndexError below.
        if not self._has_geo_objects(first_maps_answer):
            return None
        first_maps_segment = first_maps_answer.reply.geo_object[0]
        if get_departure_time_from_geo_object(first_maps_segment) < dtm:
            return None

        # Fill the gap between the maps arrival and the variant's departure with a wait section.
        first_maps_segment_arrival = get_arrival_time_from_geo_object(first_maps_segment)
        wait_section = build_wait_section((variant_dt - first_maps_segment_arrival).total_seconds())
        pathfinder_protobuf = build_pathfinder_protobuf(points, [wait_section] + sections)
        return pathfinder_protobuf, first_maps_answer

    async def _get_departure_times(self, variants, splitted_rll, dtm, request_context):
        """Compute the earliest allowed pathfinder departure time for every departure station.

        Args:
            variants: iterable of variant groups (one group per ttype, presumably).

        Returns:
            dict RoutePoint -> departure time, or None where maps found no route.
        """
        # A station may appear with several ttypes, but the result is keyed by station only.
        # An insertion-ordered dict (first-seen ttype wins) keeps the choice deterministic,
        # unlike iterating a set, and avoids redundant maps requests.
        station_ttypes = {}
        for ttype_variants in variants:
            for variant in ttype_variants:
                station = self._station_point(variant.routes[0].departure_station)
                station_ttypes.setdefault(station, self._departure_ttype(variant))

        stations = list(station_ttypes)
        # Independent requests: run them concurrently, like _get_arrival_stations does.
        departure_times = await gather(*[
            self._get_pathfinder_departure_time(
                splitted_rll[0] + [station], dtm, station_ttypes[station], request_context
            )
            for station in stations
        ])
        return dict(zip(stations, departure_times))

    async def _get_arrival_stations(self, variants, splitted_rll):
        """Return the set of arrival stations from which maps finds a route to the request end.

        Each station's route is requested at the arrival time of the last variant
        (in iteration order) arriving at that station.
        """
        arrival_times = {}
        for ttype_variants in variants:
            for variant in ttype_variants:
                last_route = variant.routes[-1]
                arrival_times[self._station_point(last_route.arrival_station)] = last_route.arrival_datetime

        stations = list(arrival_times)
        maps_answers = await gather(*[
            self._maps_client.route([station] + splitted_rll[-1], dtm=arrival_times[station])
            for station in stations
        ])
        return {
            station for station, maps_answer in zip(stations, maps_answers)
            if self._has_geo_objects(maps_answer)
        }
