import logging

import numpy as np
from typing import List

from travel.rasp.pathfinder_maps.const import TTYPE, WAIT_TIMES
from travel.rasp.pathfinder_maps.logs import log_missing_geometry
from travel.rasp.pathfinder_maps.maps_protos.common2.geo_object_pb2 import GeoObject
from travel.rasp.pathfinder_maps.models.route import Route
from travel.rasp.pathfinder_maps.models.variant import Variant
from travel.rasp.pathfinder_maps.protos.builders import (
    build_pathfinder_section, build_railway_stop, build_railway_polyline, build_stop_protobuf, build_wait_section
)
from travel.rasp.pathfinder_maps.protos.utils import get_arrival_time_from_geo_object
from travel.rasp.pathfinder_maps.utils import RoutePoint


# Module-level logger for the warnings emitted while validating variants and building protobufs.
log = logging.getLogger(__name__)


class ProtobufBuilder:
    """Build pathfinder protobuf sections from route variants.

    Uses station/settlement repositories from ``protobuf_data_provider``, precomputed
    inter-stop geometry from ``pickle_data_provider`` and ``maps_client`` to route the
    maps-teleportation legs that connect consecutive threads.
    """

    def __init__(self, protobuf_data_provider, pickle_data_provider, maps_client, rasp_link):
        self._protobuf_data_provider = protobuf_data_provider
        self._pickle_data_provider = pickle_data_provider
        self._maps_client = maps_client
        self._rasp_link = rasp_link

    def _build_thread_name(self, title_parts):
        """Join thread title parts into a single human-readable name.

        :param title_parts: iterable of ``(settlement_id, station_id)`` pairs; either id may be falsy.
        :return: the resolved parts joined with ``' - '``. Each part is ``'Settlement (Station)'``,
            ``'Settlement'`` or ``'Station'`` depending on which entities resolve; parts where
            neither resolves are skipped.
        """
        res = []
        for settlement_id, station_id in title_parts:
            settlement = self._protobuf_data_provider.settlement_repo.get(settlement_id) if settlement_id else None
            station = self._protobuf_data_provider.station_repo.get(station_id) if station_id else None

            if settlement and station:
                res.append('{} ({})'.format(settlement.TitleDefault, station.TitleDefault))
            elif settlement:
                res.append(settlement.TitleDefault)
            elif station:
                res.append(station.TitleDefault)
        return ' - '.join(res)

    async def _get_maps_teleportation(self, departure_point, arrival_point, dtm, atm):
        """Route between two points with the maps client and pick an early-enough result.

        :param departure_point: start of the teleportation leg.
        :param arrival_point: end of the teleportation leg.
        :param dtm: departure datetime passed to the maps router.
        :param atm: arrival deadline; a result must arrive strictly before it.
        :return: ``(geo_objects, arrival_time)`` for the first reply geo object whose arrival
            time is before ``atm``, or ``None`` if there is no such geo object.
        """
        res = await self._maps_client.route((departure_point, arrival_point), dtm=dtm)
        for geo_object in res.reply.geo_object:
            arrival_time = get_arrival_time_from_geo_object(geo_object)
            if arrival_time < atm:
                # Copy nested geo objects into a plain list so the caller can extend with it.
                return geo_object.geo_object[:], arrival_time
        return None

    def segments_iter(self, route):
        """Yield ``(segment, stop_from, stop_to)`` for every consecutive pair of thread stations.

        Bus threads look geometry up in ``limepaths_geometry`` by ``(stop1, stop2)``. Other
        threads try ``railway_geometry`` by ``(stop1, stop2, thread_name)`` and fall back to
        ``(stop1, stop2, None)`` when the first lookup is falsy. ``segment`` is ``None`` when
        the geometry is missing or its first element is ``None``.
        """
        for stop1, stop2 in zip(route.thread_stations, route.thread_stations[1:]):
            if route.thread_type is TTYPE.bus:
                segment = self._pickle_data_provider.limepaths_geometry.get((stop1, stop2))
            else:
                segment = (
                    self._pickle_data_provider.railway_geometry.get((stop1, stop2, route.thread_name))
                    or
                    self._pickle_data_provider.railway_geometry.get((stop1, stop2, None))
                )

            if segment is None or segment[0] is None:
                segment = None

            yield segment, stop1, stop2

    @staticmethod
    def _teleportation_window(variant, index):
        """Return the time window for the maps teleportation at ``variant.routes[index]``.

        :return: ``(prev_route, next_route, maps_dtm, maps_atm)`` where ``maps_dtm`` is the previous
            route's arrival and ``maps_atm`` is the next route's departure minus the wait time for
            its thread type; ``None`` when the teleportation lacks a route on either side
            (the original negative-index access would silently wrap to the last route).
        """
        routes = variant.routes
        if index <= 0 or index + 1 >= len(routes):
            return None
        prev_route, next_route = routes[index - 1], routes[index + 1]
        maps_dtm = prev_route.arrival_datetime
        maps_atm = next_route.departure_datetime - WAIT_TIMES[next_route.thread_type]
        return prev_route, next_route, maps_dtm, maps_atm

    def _is_valid_variant(self, variant):
        # type: (ProtobufBuilder, Variant) -> bool
        """Check that every route of ``variant`` can be turned into protobuf sections.

        A variant is rejected (after logging why) when:

        * a maps teleportation has no route before or after it, or its time window
          (previous arrival .. next departure minus wait time) is negative;
        * a non-``'NULL'`` route's departure or arrival station is not in its thread stations;
        * a non-aero, non-``'NULL'`` route lacks geometry between any two consecutive stations.
        """
        for i, route in enumerate(variant.routes):
            if route.is_maps_teleportation:
                window = self._teleportation_window(variant, i)
                if window is None:
                    log.warning('Maps teleportation without neighbouring routes at position %s', i)
                    return False
                prev_route, next_route, maps_dtm, maps_atm = window
                if maps_atm < maps_dtm:
                    log.warning('No time for wait section between %s and %s', prev_route.thread_id, next_route.thread_id)
                    return False

            if route.thread_id == 'NULL':
                continue

            departure_station, arrival_station = route.departure_station['id'], route.arrival_station['id']

            if not (departure_station in route.thread_stations and arrival_station in route.thread_stations):
                log.warning('departure or arrival not in thread_stations for thread %s', route.thread_id)
                return False

            if route.thread_type is TTYPE.aero:
                # Aero threads are rendered as stops only and need no geometry.
                continue

            for segment, stop_from, stop_to in self.segments_iter(route):
                if segment is None:
                    log_missing_geometry(route.thread_id, {
                        'thread_type': route.thread_type,
                        'stop_from': stop_from,
                        'stop_to': stop_to,
                        'thread_name': route.thread_name,
                        'thread_title': route.thread_title
                    })
                    return False

        return True

    def filter_variants(self, typed_variants_list):
        # type: (ProtobufBuilder, List[List[Variant]]) -> List[List[Variant]]
        """Keep only valid variants in each group, dropping groups that end up empty."""
        res_variants = []
        for typed_variants in typed_variants_list:
            filtered_variants = [variant for variant in typed_variants if self._is_valid_variant(variant)]
            if filtered_variants:
                res_variants.append(filtered_variants)
        return res_variants

    def _parse_thread(self, route):
        # type: (ProtobufBuilder, Route) -> tuple[TTYPE, GeoObject] | None
        """Build the pathfinder section for one thread.

        Aero threads become a sequence of stops. Other threads alternate each stop with the
        polyline to the next stop and end with the last stop.

        :return: ``(thread_type, section)``, or ``None`` if any inter-stop geometry is missing.
        """
        station_repo = self._protobuf_data_provider.station_repo
        thread_stations = route.thread_stations
        geo_objects = []

        if route.thread_type is TTYPE.aero:
            geo_objects.extend([build_railway_stop(station_repo.get(ts)) for ts in thread_stations])
        else:
            for segment, stop_from, _ in self.segments_iter(route):
                # segments_iter already maps "first element is None" to None.
                if segment is None:
                    return None
                # Stored geometry is read as little-endian int64 pairs, accumulated with cumsum and
                # scaled by 1e-11 (presumably delta-encoded fixed-point coordinates — TODO confirm).
                coords = np.frombuffer(segment, dtype=np.dtype('<i8')).reshape(-1, 2).cumsum(axis=0) / 1e11
                geo_objects.append(build_railway_stop(station_repo.get(stop_from)))
                geo_objects.append(build_railway_polyline([RoutePoint(x, y) for x, y in coords]))
            geo_objects.append(build_railway_stop(station_repo.get(thread_stations[-1])))

        thread_name = route.thread_name if route.thread_name else route.thread_title
        section = build_pathfinder_section(
            self._rasp_link, route, geo_objects, route.thread_uid, thread_name, route.thread_type.name, len(thread_stations)
        )

        return route.thread_type, section

    async def build_pathfinder_protobuf(self, variant):
        # type: (ProtobufBuilder, Variant) -> tuple[List[RoutePoint], List[GeoObject]] | None
        """Convert a variant into route points and the ordered list of protobuf sections.

        Regular routes contribute departure/arrival stop sections (deduplicated against the
        previous point), a wait section and the thread section. Maps teleportations contribute
        the geo objects returned by the maps router.

        :return: ``(points, sections)``, or ``None`` if a teleportation has no neighbouring
            routes, cannot be routed in time, or a thread cannot be parsed.
        """
        points, sections = [], []
        # When the traveller is ready for the next thread; None means "unknown" (no wait section).
        last_arrival = None

        for i, route in enumerate(variant.routes):
            departure_point = RoutePoint(route.departure_station['longitude'], route.departure_station['latitude'])
            arrival_point = RoutePoint(route.arrival_station['longitude'], route.arrival_station['latitude'])

            if route.is_maps_teleportation:
                window = self._teleportation_window(variant, i)
                if window is None:
                    log.warning('Maps teleportation without neighbouring routes at position %s', i)
                    return None
                prev_route, next_route, maps_dtm, maps_atm = window
                maps_teleportation = await self._get_maps_teleportation(
                    departure_point, arrival_point,
                    maps_dtm, maps_atm
                )
                if maps_teleportation is None:
                    log.warning('No maps teleportation between %s and %s', prev_route.thread_id, next_route.thread_id)
                    return None

                maps_segments, last_arrival = maps_teleportation
                sections.extend(maps_segments)
                continue

            if route.thread_id == 'NULL':
                # No thread to render; the next wait is counted from the previous route's arrival.
                last_arrival = variant.routes[i - 1].arrival_datetime if i > 0 else None
                continue

            if not points or points[-1] != departure_point:
                points.append(departure_point)
                sections.append(build_stop_protobuf(
                    departure_point,
                    route.departure_station_id,
                    route.departure_station['title'])
                )

            parsed = self._parse_thread(route)
            if parsed is None:
                log.warning("Can't parse %s", route.thread_id)
                return None
            _, section = parsed

            # Previously an adjacent regular route left this None (TypeError) or stale.
            if last_arrival is not None:
                wait_time = (route.departure_datetime - last_arrival).total_seconds()
                sections.append(build_wait_section(wait_time))
            sections.append(section)

            if points[-1] != arrival_point:
                points.append(arrival_point)
                sections.append(build_stop_protobuf(
                    arrival_point,
                    route.arrival_station_id,
                    route.arrival_station['title'])
                )
            last_arrival = route.arrival_datetime
        return points, sections
