import math
import json
import itertools

import maps.wikimap.mapspro.services.mrc.eye.experiments.speed_limits.pylibs.road_graph_consts as rg_const
import maps.wikimap.mapspro.services.mrc.eye.experiments.speed_limits.pylibs.road_graph_utils as rg_utils


from collections import namedtuple

# Z-levels for which per-node "in"/"out" edge lists are created.
Z_LEVELS = list(range(-2, 4))

# Bit flags naming the maneuver directions a sign can prohibit.
PROHIBITED_DIRECTION_FORWARD = 1 << 0
PROHIBITED_DIRECTION_LEFT = 1 << 1
PROHIBITED_DIRECTION_RIGHT = 1 << 2


# direction        - bit mask of PROHIBITED_DIRECTION_* flags
# only_truck       - the prohibition affects only the truck validity flag
# pass_single_edge - when the signed edge has a single maneuver, the sign may
#                    be carried over to the following edges
ProhibitedType = namedtuple(
    'ProhibitedType',
    ['direction', 'only_truck', 'pass_single_edge'],
)


# Supported sign types and the directions each of them prohibits.
PROHIBITED_DICTIONARY = {
    "mandatory_proceed_straight": ProhibitedType(
        direction=PROHIBITED_DIRECTION_LEFT | PROHIBITED_DIRECTION_RIGHT,
        only_truck=False,
        pass_single_edge=True,
    ),
    "mandatory_proceed_straight_or_turn_right": ProhibitedType(
        direction=PROHIBITED_DIRECTION_LEFT,
        only_truck=False,
        pass_single_edge=True,
    ),
    "mandatory_proceed_straight_or_turn_left": ProhibitedType(
        direction=PROHIBITED_DIRECTION_RIGHT,
        only_truck=False,
        pass_single_edge=True,
    ),
    "mandatory_turn_right_ahead": ProhibitedType(
        direction=PROHIBITED_DIRECTION_LEFT | PROHIBITED_DIRECTION_FORWARD,
        only_truck=False,
        pass_single_edge=True,
    ),
    "mandatory_turn_left_ahead": ProhibitedType(
        direction=PROHIBITED_DIRECTION_RIGHT | PROHIBITED_DIRECTION_FORWARD,
        only_truck=False,
        pass_single_edge=True,
    ),
    "mandatory_turn_right_or_left": ProhibitedType(
        direction=PROHIBITED_DIRECTION_FORWARD,
        only_truck=False,
        pass_single_edge=True,
    ),
    "prohibitory_no_right_turn": ProhibitedType(
        direction=PROHIBITED_DIRECTION_RIGHT,
        only_truck=False,
        pass_single_edge=True,
    ),
    "prohibitory_no_left_turn": ProhibitedType(
        direction=PROHIBITED_DIRECTION_LEFT,
        only_truck=False,
        pass_single_edge=True,
    ),
    "prohibitory_max_height": ProhibitedType(
        direction=PROHIBITED_DIRECTION_FORWARD,
        only_truck=True,
        pass_single_edge=False,
    ),
}
# Sign types that are not handled yet:
#    prohibitory_no_uturn
#    prohibitory_no_entry
#    prohibitory_no_vehicles
#    prohibitory_no_heavy_goods_vehicles


def edge_str(edge):
    """Return a printable label for a directed edge: "<id> forward" or "<id> backward"."""
    label = str(edge["id"])
    if edge["forward"]:
        suffix = " forward"
    else:
        suffix = " backward"
    return label + suffix


class ManeuversWorkerBase(object):
    """Enumerates every maneuver (incoming edge -> outgoing edge) of each node.

    The result is available through the ``maneuvers`` property:
        {node_id: [{"edge_in": ..., "edge_out": ..., "valid": bool, "truck_valid": bool}, ...]}
    """

    def __init__(self, nodes, edges, no_uturn=True):
        self._make_all_maneuvers(nodes, edges, no_uturn)

    def _make_empty_node_to_edges_dict(self, nodes):
        """Return {node_id: {z_level: {"in": [], "out": []}}} with empty edge lists."""
        return {
            node["id"]: {level: {"in": [], "out": []} for level in Z_LEVELS}
            for node in nodes
        }

    @staticmethod
    def _directed_edge(edge, forward, direction_flag):
        """Describe one traversal direction of an edge.

        "valid" is True when the edge's "oneway" mask contains direction_flag.
        """
        return {
            "id": edge["id"],
            "forward": forward,
            "access_id": edge["access_id"],
            "valid": (edge["oneway"] & direction_flag) != 0,
        }

    def _fill_node_to_edges(self, nodes, edges):
        """Register both traversal directions of every edge at its two end nodes."""
        by_node = self._make_empty_node_to_edges_dict(nodes)
        for edge in edges:
            tail = by_node[edge["f_rd_jc_id"]][edge["f_zlev"]]
            head = by_node[edge["t_rd_jc_id"]][edge["t_zlev"]]

            # Forward traversal leaves the "f" junction and enters the "t" junction.
            along = self._directed_edge(edge, True, rg_const.DIRECTION_FORWARD)
            tail["out"].append(along)
            head["in"].append(along)

            # Backward traversal goes the opposite way.
            against = self._directed_edge(edge, False, rg_const.DIRECTION_BACKWARD)
            head["out"].append(against)
            tail["in"].append(against)
        return by_node

    def _make_all_maneuvers(self, nodes, edges, no_uturn):
        '''
             Fill self._maneuvers:
             {node_id: [{"edge_in": edge_in, "edge_out": edge_out,
                         "valid": bool, "truck_valid": bool}, ...]}

             With no_uturn set, pairs whose incoming and outgoing edges share
             the same id are left out.
        '''
        per_node = self._fill_node_to_edges(nodes, edges)
        self._maneuvers = {}
        for node_id, levels in per_node.items():
            node_maneuvers = []
            self._maneuvers[node_id] = node_maneuvers
            for level in Z_LEVELS:
                incoming = levels[level]["in"]
                outgoing = levels[level]["out"]
                for edge_in, edge_out in itertools.product(incoming, outgoing):
                    if no_uturn and edge_in["id"] == edge_out["id"]:
                        continue
                    # Access bits that both edges have in common.
                    shared_access = edge_in["access_id"] & edge_out["access_id"]
                    # Both edges must be traversable in the required direction.
                    passable = edge_in["valid"] and edge_out["valid"]
                    node_maneuvers.append({
                        "edge_in": edge_in,
                        "edge_out": edge_out,
                        "valid": (shared_access & rg_const.AID_CAR) != 0 and passable,
                        "truck_valid": (shared_access & rg_const.AID_TRUCK) != 0 and passable,
                    })

    @property
    def maneuvers(self):
        """Mapping of node id to the list of maneuvers through that node."""
        return self._maneuvers


class ManeuversWorkerConditions(ManeuversWorkerBase):
    """Maneuvers of the road graph restricted by explicit conditions.

    Every applied condition marks the matching maneuver as not valid and
    stores the condition id in the maneuver under "cond_id".
    """

    def __init__(self, nodes, edges, conds):
        super(ManeuversWorkerConditions, self).__init__(nodes, edges)
        self._apply_conditions(conds)

    def _apply_conditions(self, conds):
        """Invalidate the maneuvers described by the supported conditions.

        Only conditions whose "cond_type" is listed in USED_CONDITIONS and
        that have exactly one entry in "to_edges" are applied; the others
        are skipped (the unsupported "to_edges" case is reported via print).
        """
        # NOTE(review): the meaning of cond_type 1 is not visible in this file.
        USED_CONDITIONS = [1]

        for cond in conds:
            if cond["cond_type"] not in USED_CONDITIONS:
                continue
            to_edges = cond["to_edges"]
            # Exactly one target edge is supported; an empty list used to
            # slip through the check and crash on to_edges[0].
            if len(to_edges) != 1:
                print("Unable to apply condition with " + str(len(to_edges)) + "' to edges'")
                continue
            node_id = cond["via_node"]
            if node_id not in self.maneuvers:
                print("Unable to apply condition " + str(cond["id"]) + " to node with id: " + str(node_id))
                continue
            edge_in_id = cond["from_edge"]
            edge_out_id = to_edges[0]
            for edge_pair in self.maneuvers[node_id]:
                if edge_in_id != edge_pair["edge_in"]["id"] or edge_out_id != edge_pair["edge_out"]["id"]:
                    continue
                edge_pair["cond_id"] = cond["id"]
                # Trucks are affected by every applied condition; cars only
                # when the condition is not restricted to trucks.
                edge_pair["truck_valid"] = False
                if rg_const.AID_TRUCK != cond["access_id"]:
                    edge_pair["valid"] = False
                # Only the first matching maneuver is updated.
                break


class ManeuversWorkerSigns(ManeuversWorkerBase):
    """Maneuvers of the road graph restricted by traffic signs on edges.

    After the base class has enumerated all maneuvers, every maneuver gets
    an "angle" entry, and for each traversable direction of each edge the
    most recent supported sign is used to prohibit maneuvers that start
    from that edge. A prohibited maneuver stores the sign id in "sign_id".
    """

    def __init__(self, nodes, edges, signs):
        super(ManeuversWorkerSigns, self).__init__(nodes, edges)
        self._signs_dict = {x["id"]: x for x in signs}
        self._edge_by_id = {x["id"]: x for x in edges}
        self._apply_signs(edges, signs)

    def _fill_maneuvers_angle(self, edges):
        '''
             Add an "angle" entry (radians) to every maneuver.
             angle - the angle between the directions of the incoming and
                    the outgoing edge, in the range (-Pi, Pi]; both edges are
                    laid off from one point and the angle is measured from
                    edge_in (the last segment of its polyline).
                    A negative angle means a right turn, a positive one a
                    left turn.
        '''
        def calc_angle(s1, s2):
            # The incoming segment must end where the outgoing one starts.
            dx = s1[1][0] - s2[0][0]
            dy = s1[1][1] - s2[0][1]
            assert dx*dx + dy*dy < 1e-7, "Invalid points"

            x1 = s1[1][0] - s1[0][0]
            y1 = s1[1][1] - s1[0][1]
            x2 = s2[1][0] - s2[0][0]
            y2 = s2[1][1] - s2[0][1]
            angle = math.atan2(y2, x2) - math.atan2(y1, x1)
            # Bring the difference of two atan2 results back into (-Pi, Pi].
            if angle > math.pi:
                angle -= 2 * math.pi
            elif angle <= -math.pi:
                angle += 2 * math.pi
            return angle

        edges_dict = {x["id"]: x for x in edges}
        for edge_pairs in self.maneuvers.values():
            for edge_pair in edge_pairs:
                # If the incoming edge is traversed forward, its last two
                # points give the segment that enters the node:
                #   pt[-2] -> pt[-1]
                # If it is traversed backward, take the first two points in
                # the order
                #   pt[1] -> pt[0]
                edge_pair_in = edge_pair["edge_in"]
                in_points = edges_dict[edge_pair_in["id"]]["points"]
                if edge_pair_in["forward"]:
                    edge_in_segment = in_points[-2:]
                else:
                    edge_in_segment = in_points[:2][::-1]

                # For the outgoing edge: forward traversal uses
                # pt[0] -> pt[1], backward traversal uses pt[-1] -> pt[-2].
                edge_pair_out = edge_pair["edge_out"]
                out_points = edges_dict[edge_pair_out["id"]]["points"]
                if edge_pair_out["forward"]:
                    edge_out_segment = out_points[:2]
                else:
                    edge_out_segment = out_points[-2:][::-1]
                edge_pair["angle"] = calc_angle(edge_in_segment, edge_out_segment)

    def _get_last_sign(self, sign_ids):
        """Pick the supported sign with the latest "to_time".

        Ids missing from the sign dictionary, signs whose type is not in
        PROHIBITED_DICTIONARY and signs whose "to_time" is not later than
        the fixed baseline (2000-01-01 00:00:01 +00:00) are ignored.

        Returns a (sign_id, sign_type, is_truck) tuple. It is
        (None, None, None) when sign_ids is empty and (None, None, False)
        when no sign qualifies. is_truck tells whether
        "information_heavy_vehicle" is among the "slaves" of the sign.
        """
        if not sign_ids:
            return None, None, None
        last_sign_id = None
        last_sign_type = None
        last_time = rg_utils.convertSQLDateTimeToTimestamp("2000-01-01 00:00:01.00000000+00:00")
        last_is_truck = False
        for sign_id in sign_ids:
            sign = self._signs_dict.get(sign_id)
            if sign is None:
                continue
            sign_type = sign["sign_type"]
            if sign_type not in PROHIBITED_DICTIONARY:
                continue
            to_time = rg_utils.convertSQLDateTimeToTimestamp(sign["to_time"])
            if last_time < to_time:
                last_sign_id = sign["id"]
                last_sign_type = sign_type
                last_time = to_time
                last_is_truck = "information_heavy_vehicle" in sign["slaves"]
        return last_sign_id, last_sign_type, last_is_truck

    def _prohibite_direction(self, sign_id, maneuvers, direction, only_truck):
        '''
            Prohibit the maneuvers that go in the given direction(s).

            maneuvers = [{"edge_in": edge_in, "edge_out": edge_out,
                          "valid": bool, "truck_valid": bool,
                          "angle": radians}, ...]
            direction - bit mask of PROHIBITED_DIRECTION_* flags
            only_truck - when set, only "truck_valid" is cleared; "valid" is
                         left as it was, so a maneuver that was already not
                         valid never becomes valid again
        '''
        FORWARD_ANGLE_EPSILON = 30. * math.pi / 180.

        # The forward maneuver is the one with the smallest absolute angle,
        # provided that angle is below FORWARD_ANGLE_EPSILON.
        forward_maneuver = None
        forward_angle = FORWARD_ANGLE_EPSILON
        for maneuver in maneuvers:
            if abs(maneuver["angle"]) < forward_angle:
                forward_angle = abs(maneuver["angle"])
                forward_maneuver = maneuver

        def prohibit(maneuver):
            maneuver["truck_valid"] = False
            if not only_truck:
                maneuver["valid"] = False
            maneuver["sign_id"] = sign_id

        if forward_maneuver is not None and (0 != (direction & PROHIBITED_DIRECTION_FORWARD)):
            prohibit(forward_maneuver)

        no_left = 0 != (direction & PROHIBITED_DIRECTION_LEFT)
        no_right = 0 != (direction & PROHIBITED_DIRECTION_RIGHT)
        for maneuver in maneuvers:
            angle = maneuver["angle"]
            # Everything beyond the forward angle counts as a left turn
            # (positive angle) or a right turn (negative angle).
            if (no_left and angle > forward_angle) or (no_right and angle < -forward_angle):
                prohibit(maneuver)

    def _apply_sign(self, sign_id, sign_type, edge_maneuvers, only_truck):
        """Prohibit the directions that sign_type forbids among edge_maneuvers."""
        if sign_type in PROHIBITED_DICTIONARY:
            proh = PROHIBITED_DICTIONARY[sign_type]
            # The sign type itself may restrict the prohibition to trucks.
            only_truck = only_truck or proh.only_truck
            self._prohibite_direction(sign_id, edge_maneuvers, proh.direction, only_truck)
        else:
            print("Not found prohibited direction for sign type: " + sign_type)

    def _pass_edges(self, edge_maneuvers, truck):
        """Carry a sign over edges that offer a single maneuver only.

        Starting from the only maneuver in edge_maneuvers, follows the
        outgoing edge (at most MAX_PASSED_EDGES times) until maneuvers with
        more than one option are found and returns them.

        Returns None when a passed edge has its own supported sign with the
        same truck flag, when a passed edge has no maneuvers, or when the
        limit of passed edges is exhausted.
        """
        MAX_PASSED_EDGES = 2

        assert 0 < len(edge_maneuvers), "Invalid maneuvers count"

        if 1 < len(edge_maneuvers):
            return edge_maneuvers
        rest_passed_edges = MAX_PASSED_EDGES
        while 0 < rest_passed_edges:
            edge_out_m = edge_maneuvers[0]["edge_out"]
            edge_out = self._edge_by_id[edge_out_m["id"]]
            sign_ids = edge_out["sign_ids_f"] if edge_out_m["forward"] else edge_out["sign_ids_t"]
            _, sign_type, is_truck = self._get_last_sign(sign_ids)
            if (sign_type is not None) and (truck == is_truck):
                return None
            edge_maneuvers = self._maneuvers_by_edge_in.get((edge_out_m["id"], edge_out_m["forward"]))
            if not edge_maneuvers:
                return None
            if 1 < len(edge_maneuvers):
                return edge_maneuvers
            rest_passed_edges -= 1
        return None

    def _apply_edge_signs(self, edge, forward):
        """Apply the latest sign of one traversal direction of an edge."""
        sign_ids = edge["sign_ids_f"] if forward else edge["sign_ids_t"]
        sign_id, sign_type, is_truck = self._get_last_sign(sign_ids)
        if sign_type is None:
            return
        edge_maneuvers = self._maneuvers_by_edge_in.get((edge["id"], forward))
        if not edge_maneuvers:
            print("There are no maneuvers from " + ("forward" if forward else "backward") + " edge_id: " + str(edge["id"]) + " with sign")
            return
        if (1 == len(edge_maneuvers)) and PROHIBITED_DICTIONARY[sign_type].pass_single_edge:
            # There is a single outgoing edge, so nothing can really be
            # prohibited here. Try to move on to that edge and, if it has no
            # sign of its own, prohibit the maneuvers that start from it.
            edge_maneuvers = self._pass_edges(edge_maneuvers, is_truck)
            if edge_maneuvers is None:
                return
        self._apply_sign(sign_id, sign_type, edge_maneuvers, is_truck)

    def _apply_signs(self, edges, signs):
        """Compute maneuver angles and apply the signs of every edge.

        The signs argument is not read here; signs are looked up through
        the dictionary built in __init__.
        """
        self._fill_maneuvers_angle(edges)
        # Maneuvers grouped by their incoming edge; the key is (id, forward).
        self._maneuvers_by_edge_in = {}
        for edge_pairs in self._maneuvers.values():
            for edge_pair in edge_pairs:
                edge_in = edge_pair["edge_in"]
                key = (edge_in["id"], edge_in["forward"])
                self._maneuvers_by_edge_in.setdefault(key, []).append(edge_pair)

        for edge in edges:
            if 0 != (edge["oneway"] & rg_const.DIRECTION_FORWARD):
                self._apply_edge_signs(edge, True)
            if 0 != (edge["oneway"] & rg_const.DIRECTION_BACKWARD):
                self._apply_edge_signs(edge, False)


def find_diff_maneuvers(mcond, msign):
    """Compare condition-based maneuvers with sign-based maneuvers.

    mcond and msign map a node id to a list of maneuvers. For every maneuver
    of mcond the maneuver with equal "edge_in" and "edge_out" is looked up in
    msign under the same node id.

    Returns {node_id: [diff_pair, ...]} with an entry for every node of
    mcond. A diff_pair is created when the "valid" flags differ
    ("only_truck" is False) or, if they agree, when the "truck_valid" flags
    differ ("only_truck" is True). It holds "edge_in", "edge_out", "angle"
    (taken from the sign maneuver), "valid_cond", "valid_sign", "only_truck"
    and, when present in the sources, "cond_id" and "sign_id".

    Raises AssertionError when a maneuver of mcond has no counterpart in
    msign.
    """
    diff = {}
    for node_id, edge_pairs_cond in mcond.items():
        edge_pairs_sign = msign.get(node_id, [])
        node_diff = []
        diff[node_id] = node_diff
        for edge_pair_cond in edge_pairs_cond:
            edge_in_cond = edge_pair_cond["edge_in"]
            edge_out_cond = edge_pair_cond["edge_out"]
            edge_pair_sign = next(
                (
                    pair for pair in edge_pairs_sign
                    if edge_in_cond == pair["edge_in"] and edge_out_cond == pair["edge_out"]
                ),
                None,
            )
            # Report the conditional maneuver that is missing; the sign-side
            # loop variables may be unbound or refer to an unrelated pair.
            assert edge_pair_sign is not None, "Unable to found conditional maneuver from " + edge_str(edge_in_cond) + " to " + edge_str(edge_out_cond) + " in signs maneuvers"

            if edge_pair_cond["valid"] != edge_pair_sign["valid"]:
                flag, only_truck = "valid", False
            elif edge_pair_cond["truck_valid"] != edge_pair_sign["truck_valid"]:
                flag, only_truck = "truck_valid", True
            else:
                continue

            diff_pair = {
                "edge_in": edge_in_cond,
                "edge_out": edge_out_cond,
                "angle": edge_pair_sign["angle"],
                "valid_cond": edge_pair_cond[flag],
                "valid_sign": edge_pair_sign[flag],
                "only_truck": only_truck,
            }
            if "cond_id" in edge_pair_cond:
                diff_pair["cond_id"] = edge_pair_cond["cond_id"]
            if "sign_id" in edge_pair_sign:
                diff_pair["sign_id"] = edge_pair_sign["sign_id"]
            node_diff.append(diff_pair)
    return diff


def save_maneuvers(maneuvers, json_path):
    """Write maneuvers to json_path as UTF-8 JSON under the top-level key "maneuvers"."""
    payload = {"maneuvers": maneuvers}
    with open(json_path, mode='w', encoding='utf-8') as out_file:
        json.dump(payload, out_file, ensure_ascii=False, indent=4)


def valid_maneuvers_cnt(maneuvers):
    """Return the number of maneuvers, over all nodes, whose "valid" flag is truthy."""
    return sum(
        1
        for pairs in maneuvers.values()
        for pair in pairs
        if pair["valid"]
    )


def maneuvers_cnt(maneuvers):
    """Return the total number of maneuvers over all nodes."""
    return sum(len(pairs) for pairs in maneuvers.values())


def print_maneuvers(maneuvers, valid_only=True):
    """Print the maneuvers of every node that has at least one maneuver.

    With valid_only set, maneuvers whose "valid" flag is false are not
    printed (the node header is still printed). The angle, when present,
    is printed in degrees.
    """
    for node_id, pairs in maneuvers.items():
        if not len(pairs):
            continue
        print("Node: ", node_id)
        for pair in pairs:
            source = pair["edge_in"]
            target = pair["edge_out"]
            is_valid = pair["valid"]
            if "angle" in pair:
                angle_suffix = ", angle: " + str(180 * pair["angle"] / math.pi)
            else:
                angle_suffix = ""
            if valid_only and not is_valid:
                continue
            status = " valid" if is_valid else " not valid"
            print("".join([
                "    from edge: ", edge_str(source),
                " to edge: ", edge_str(target),
                status, " maneuver", angle_suffix,
            ]))


def print_difference(diff):
    """Print the differences produced by find_diff_maneuvers and their totals.

    For every node with at least one difference the node id and one line per
    differing maneuver are printed. The line lists both validity values and
    the ids of the sign and of the condition attached to the maneuver (both
    are printed when both are present). Two summary lines follow: the number
    of maneuvers that are not valid by condition and the number that are not
    valid by sign.
    """
    diff_cnt_cond = 0
    diff_cnt_sign = 0
    for node_id, edge_pairs in diff.items():
        if not edge_pairs:
            continue
        print("Node: ", node_id)
        for edge_pair in edge_pairs:
            edge_in = edge_pair["edge_in"]
            edge_out = edge_pair["edge_out"]
            valid_cond = ", valid_cond: " + str(edge_pair.get("valid_cond", "undefined"))
            valid_sign = ", valid_sign: " + str(edge_pair.get("valid_sign", "undefined"))
            # A missing flag is treated as valid and therefore not counted.
            if not edge_pair.get("valid_cond", True):
                diff_cnt_cond += 1
            if not edge_pair.get("valid_sign", True):
                diff_cnt_sign += 1
            # Accumulate the ids so that a condition id does not hide the
            # sign id when the maneuver carries both.
            ids = ""
            if "sign_id" in edge_pair:
                ids += ", sign_id: " + str(edge_pair["sign_id"])
            if "cond_id" in edge_pair:
                ids += ", cond_id: " + str(edge_pair["cond_id"])
            print("    from edge: " + edge_str(edge_in) + " to edge: " + edge_str(edge_out) + valid_cond + valid_sign + ids + (" only truck" if edge_pair["only_truck"] else ""))

    print("Invalid by condition, valid by sign: ", diff_cnt_cond)
    print("Invalid by sign, valid by condition: ", diff_cnt_sign)
