# -*- coding: UTF-8 -*-
import os
import sys
import copy

from yandex.maps.geolib3 import Point2, BoundingBox
from yandex.maps.jams.graph4 import Graph, PersistentIndex, EdgesRTree
import yandex.maps.road_graph as road_graph

from analytics.geo.tools.geometry.lib.coords_lib import (
    bounds_for_points, bounds_enlarged_meters,
    distance_coords_near, is_close_coord_to_segment,
    point_in_polygon
)


class EdgeData:
    """Access to one version of the road graph stored under ``graph_path``.

    Wraps the persistent-id index, the graph topology/data and the edge R-tree
    of a graph version, and provides id conversion (persistent <-> short ids),
    edge geometry and spatial edge queries (box, circle, polygon).
    The ``road_graph`` turn-restriction data is opened lazily on first use.

    Points are ``(lat, lon)`` tuples throughout.
    """

    def get_graph_path(self, data_type):
        """Return the path of the file holding ``data_type`` (a key of ``type_dict``)."""
        return self.graph_path + self.graph_version + '/' + self.type_dict[data_type]

    def __init__(self, graph_version=None):
        """Open the graph files of ``graph_version``.

        :param graph_version: name of a subdirectory of ``graph_path``. When
            omitted, the version is chosen by ``_find_latest_graph_version``.
        :raises AssertionError: if no complete graph version can be found.
        """
        self.graph_path = '/var/spool/yandex/maps/graph/'
        # Logical data kind -> file name inside a graph version directory.
        self.type_dict = {
            'persistent_index': 'edges_persistent_index.mms.1',
            'topology': 'topology.mms.2',
            'data': 'data.mms.2',
            'edges_rtree': 'edges_rtree.mms.2',
            'road_graph': 'road_graph.fb',
        }

        if graph_version is None:
            graph_version = self._find_latest_graph_version()
            print('EdgeData found graph: {}'.format(graph_version))
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; AssertionError is kept for existing callers.
            if graph_version is None:
                raise AssertionError(
                    'EdgeData: no complete graph version in {}'.format(self.graph_path))
        self.graph_version = graph_version

        self.persistent_index = PersistentIndex(self.get_graph_path('persistent_index'))
        self.graph = Graph(self.get_graph_path('topology'), self.get_graph_path('data'), False)
        self.edges_rtree = EdgesRTree(self.get_graph_path('edges_rtree'))
        # Opened on demand by _get_road_graph(); only turn queries need it.
        self.road_graph = None

    def _find_latest_graph_version(self):
        """Return the greatest complete version directory name, or None.

        A candidate must start with a digit, must not end with ``'testing'``
        and must contain every file listed in ``type_dict``. Names are compared
        as plain strings (assumes version names sort chronologically).
        """
        latest = None
        for name in os.listdir(self.graph_path):
            if not name[0].isdigit() or name.endswith('testing'):
                continue
            version_dir = os.path.join(self.graph_path, name)
            if not all(os.path.exists(os.path.join(version_dir, file_name))
                       for file_name in self.type_dict.values()):
                continue
            if latest is None or latest < name:
                latest = name
        return latest

    def _get_road_graph(self):
        """Return the turn-restriction graph, opening it on first use."""
        if self.road_graph is None:
            self.road_graph = road_graph.RoadGraph(self.get_graph_path('road_graph'))
        return self.road_graph

    def get_short_ids(self, persistent_id_list):
        """Map each persistent id to its short id via the persistent index."""
        return [self.persistent_index.find_short_id(pid)
                for pid in persistent_id_list]

    def get_persistent_ids(self, short_id_list):
        """Map each short id to a persistent id.

        The short id is first passed through ``graph.edges_index().base``.
        """
        edges_index = self.graph.edges_index()
        return [self.persistent_index.find_long_id(edges_index.base(sid))
                for sid in short_id_list]

    def get_edge_geometry(self, persistent_id):
        """Return the polyline of an edge as a list of ``(lat, lon)`` points.

        The list holds the start of every segment followed by the end of the
        last segment. An edge without segments yields ``[]``.
        """
        edge = self.get_edge_data(persistent_id)
        segments = [edge.segment_at(i) for i in range(edge.segments_number)]
        if not segments:
            return []
        points = [(segment.start.lat, segment.start.lon) for segment in segments]
        last = segments[-1]
        points.append((last.end.lat, last.end.lon))
        return points

    def get_edge_length(self, persistent_id=None, short_id=None):
        """Return the ``length`` attribute of the edge's data.

        Pass either ``persistent_id`` or ``short_id``; ``short_id`` wins.
        """
        return self.get_edge_data(persistent_id, short_id).length

    def get_edge_data(self, persistent_id=None, short_id=None):
        """Return ``graph.edge_data`` for an edge.

        :param persistent_id: used (after ``int()`` conversion) only when
            ``short_id`` is None.
        :param short_id: short edge id; takes precedence if given.
        """
        if short_id is None:
            short_id = self.persistent_index.find_short_id(int(persistent_id))
        return self.graph.edge_data(short_id)

    def get_in_edges_short(self, edge_id):
        """Return short ids of incoming edges a car may turn from onto ``edge_id``.

        This works with short ids only, not persistent ids. Persistent ids
        are tied to geometry, and two different edges may share one geometry.
        """
        edge = self.graph.edge(edge_id)
        candidates = [epp.id for epp in self.graph.in_edges(edge.source)]
        turns = self._get_road_graph()
        access = road_graph.AccessId.Automobile
        return [from_eid for from_eid in candidates
                if not turns.is_forbidden_turn(from_eid, edge_id, access)]

    def get_out_edges_short(self, edge_id):
        """Return short ids of outgoing edges a car may turn onto from ``edge_id``.

        Like ``get_in_edges_short``, this works with short ids only.
        """
        edge = self.graph.edge(edge_id)
        candidates = [epp.id for epp in self.graph.out_edges(edge.target)]
        turns = self._get_road_graph()
        access = road_graph.AccessId.Automobile
        return [to_eid for to_eid in candidates
                if not turns.is_forbidden_turn(edge_id, to_eid, access)]

    def get_edge_segments_list(self, persistent_id):
        """Return the raw segment objects of an edge.

        ``persistent_id`` is converted with ``int()``, like in the other
        persistent-id lookups of this class.
        """
        edge = self.get_edge_data(persistent_id)
        return [edge.segment_at(i) for i in range(edge.segments_number)]

    def get_edges_in_bounds(self, bounds, return_short=False):
        """Return edges reported by the edge R-tree for a bounding box.

        :param bounds: two corner points ``[(lat, lon), (lat, lon)]``. They are
            passed to ``BoundingBox`` as ``Point2(lon, lat)``.
        :param return_short: return short ids instead of persistent ids.
        :returns: list of persistent ids (or short ids). Persistent ids may
            presumably be None for unindexed edges; callers in this class
            filter those out.
        """
        box = BoundingBox(
            Point2(bounds[0][1], bounds[0][0]),
            Point2(bounds[1][1], bounds[1][0]),
        )
        edges = self.edges_rtree.edges_in_window(self.graph, box)
        if return_short:
            return list(edges)
        return [self.persistent_index.find_long_id(edge_id) for edge_id in edges]

    def get_edges_in_circle(self, center, radius):
        """Return persistent ids of edges having a segment close to ``center``.

        :param center: ``(lat, lon)`` point.
        :param radius: threshold passed to ``is_close_coord_to_segment``
            (presumably meters, like ``bounds_enlarged_meters``; confirm).
        """
        def edge_in_circle(edge):
            points = self.get_edge_geometry(edge)
            return any(is_close_coord_to_segment(center, start, end, radius)
                       for start, end in zip(points, points[1:]))

        # Pre-filter with a box around the center. The extra 250 m margin
        # assumes segments are never longer than 200 m (plus 50 m of slack).
        bounds = bounds_enlarged_meters(
            [copy.copy(center), copy.copy(center)],
            radius + 250,
        )
        return [e for e in self.get_edges_in_bounds(bounds)
                if e is not None and edge_in_circle(e)]

    def get_edges_in_polygon(self, poly, catch='all'):
        """Return persistent ids of edges related to polygon ``poly``.

        :param poly: sequence of ``(lat, lon)`` vertices.
        :param catch:
            - ``'all'``: edges with at least one vertex inside the polygon, or
              with a segment strictly crossing a polygon side;
            - ``'cross_in'``: edges whose first point is outside and whose last
              point is inside (edges entering the polygon);
            - ``'cross_out'``: edges whose first point is inside and whose last
              point is outside (edges leaving the polygon).
        :raises AssertionError: if ``catch`` is not one of the values above.
        """
        # Explicit raise so an unknown mode is not silently accepted under -O.
        if catch not in ('all', 'cross_in', 'cross_out'):
            raise AssertionError('unknown catch mode: {!r}'.format(catch))

        def polygon_sides(polygon):
            # Closing side first, then consecutive sides; fewer than 2 vertices
            # give no sides.
            if len(polygon) < 2:
                return []
            return [(polygon[-1], polygon[0])] + list(zip(polygon, polygon[1:]))

        def segments_cross(seg_a, seg_b):
            # Strict test on planar coordinates: each segment's endpoints must
            # lie strictly on opposite sides of the other segment's line, so
            # touching or collinear segments do not count.
            def side(segment, point):
                start, end = segment[0], segment[1]
                return ((point[0] - start[0]) * (end[1] - start[1])
                        - (point[1] - start[1]) * (end[0] - start[0]))

            return (side(seg_a, seg_b[0]) * side(seg_a, seg_b[1]) < 0
                    and side(seg_b, seg_a[0]) * side(seg_b, seg_a[1]) < 0)

        # Computed once per query instead of once per edge segment.
        poly_sides = polygon_sides(poly)

        def edge_in_poly(edge, catch_cross=False):
            points = self.get_edge_geometry(edge)
            if any(point_in_polygon(p, poly) for p in points):
                return True
            if not catch_cross:
                return False
            return any(segments_cross(edge_segment, poly_side)
                       for edge_segment in zip(points, points[1:])
                       for poly_side in poly_sides)

        def crosses_boundary(edge, starts_inside):
            # True when the first point is inside iff ``starts_inside`` and the
            # last point is on the other side; edges without geometry are skipped.
            geometry = self.get_edge_geometry(edge)
            if not geometry:
                return False
            if bool(point_in_polygon(geometry[0], poly)) != starts_inside:
                return False
            return bool(point_in_polygon(geometry[-1], poly)) != starts_inside

        # The 250 m margin assumes segments are never longer than 200 m
        # (plus 50 m of slack), so edges poking in from outside are not missed.
        bounds = bounds_enlarged_meters(bounds_for_points(poly), 250)
        edges = [e for e in self.get_edges_in_bounds(bounds) if e is not None]
        try:
            if catch == 'all':
                edges = [e for e in edges if edge_in_poly(e, catch_cross=True)]
            elif catch == 'cross_in':
                edges = [e for e in edges if crosses_boundary(e, starts_inside=False)]
            elif catch == 'cross_out':
                edges = [e for e in edges if crosses_boundary(e, starts_inside=True)]
        except Exception:
            # Dump the candidate list to help debug the failing edge, then re-raise.
            print(edges)
            raise

        return edges

if __name__ == "__main__":
    # Library module: there is no command-line entry point to run.
    pass
