from operator import itemgetter
from typing import List

from marshmallow import fields
from marshmallow_enum import EnumField

from maps_adv.common.protomallow import ProtobufSchema
from maps_adv.points.proto import points_in_polygons_pb2, primitives_pb2
from maps_adv.points.server.lib.data_managers import PointsDataManager
from maps_adv.points.server.lib.enums import PointType
from maps_adv.points.server.lib.exceptions import (
    InvalidPolygon,
    InvalidVersion,
    NonClosedPolygon,
    NoPolygonsPassed,
)

from .base import StrictSchema, with_schemas

__all__ = ["PointsApiProvider"]


def _validate_polygons(polygons: List[List[dict]]) -> None:
    if not polygons:
        raise NoPolygonsPassed


def _validate_polygon(polygon: List[dict]) -> None:
    if len(polygon) < 4:
        raise InvalidPolygon

    if polygon[0] != polygon[-1]:
        raise NonClosedPolygon


def _validate_version(version: int) -> None:
    if version < 1:
        raise InvalidVersion


_destruct_parameters = itemgetter("point_type", "version")


class PolygonPoint(ProtobufSchema):
    """Schema for a single polygon vertex (``primitives_pb2.Point``).

    Both coordinates are declared as string fields.
    """

    longitude = fields.Str()
    latitude = fields.Str()

    class Meta:
        pb_message_class = primitives_pb2.Point


class Polygon(ProtobufSchema):
    """Schema for one polygon (``primitives_pb2.Polygon``).

    The vertex list is checked by ``_validate_polygon``: it must contain at
    least four points, and its first and last points must be equal.
    """

    class Meta:
        pb_message_class = primitives_pb2.Polygon

    # Ordered vertices; the whole list is validated by _validate_polygon,
    # which raises InvalidPolygon / NonClosedPolygon on bad input.
    points = fields.List(fields.Nested(PolygonPoint), validate=_validate_polygon)


class PointsInPolygonInput(ProtobufSchema):
    """Request schema for ``points_in_polygons_pb2.PointsInPolygonsInput``.

    Holds the list of polygons to search within.
    """

    class Meta:
        pb_message_class = points_in_polygons_pb2.PointsInPolygonsInput

    # Each item is deserialized via the Polygon schema; _validate_polygons
    # raises NoPolygonsPassed when the resulting list is empty.
    polygons = fields.List(fields.Nested(Polygon), validate=_validate_polygons)


class ParametersSchema(StrictSchema):
    """Validates the non-polygon parameters of a points-in-polygons query.

    Both fields are required; ``version`` is additionally checked by
    ``_validate_version``.
    """

    point_type = EnumField(PointType, required=True)
    version = fields.Int(required=True, validate=_validate_version)


class PointsApiProvider:
    """API provider for "points within polygons" queries.

    Delegates the actual lookup to a ``PointsDataManager`` and returns the
    result as a serialized ``PointsInPolygonsOutput`` protobuf message.
    """

    def __init__(self, dm: PointsDataManager):
        self._dm = dm

    @with_schemas(input_schema=PointsInPolygonInput)
    async def find_within_polygons(self, polygons, **kwargs) -> bytes:
        """Look up points of the requested type/version within ``polygons``.

        Args:
            polygons: polygons decoded (presumably by ``with_schemas``) with
                ``PointsInPolygonInput``; each item is a dict whose
                ``"points"`` key holds that polygon's vertex list.
            **kwargs: request parameters, loaded through ``ParametersSchema``;
                ``point_type`` and ``version`` are required.

        Returns:
            The serialized ``PointsInPolygonsOutput`` message (``bytes`` from
            ``SerializeToString``).

        Raises:
            Whatever ``ParametersSchema`` loading raises, e.g. ``InvalidVersion``
            from ``_validate_version`` when ``version`` < 1.
        """
        parameters = ParametersSchema().load(kwargs)
        point_type, version = _destruct_parameters(parameters.data)

        # The data manager takes bare vertex lists, not the Polygon dicts.
        polygon_list = [polygon["points"] for polygon in polygons]

        found = await self._dm.find_within_polygons(
            point_type=point_type, version=version, polygons=polygon_list
        )

        # Each record is splatted as keyword arguments into IdentifiedPoint.
        points = [primitives_pb2.IdentifiedPoint(**record) for record in found]
        return points_in_polygons_pb2.PointsInPolygonsOutput(
            points=points
        ).SerializeToString()
