from typing import List

from marshmallow import fields

from maps_adv.common.protomallow import ProtobufSchema
from maps_adv.points.proto.forecasts_pb2 import (
    ForecastOutput,
    ForecastPointsInput,
    ForecastPolygonsInput,
)
from maps_adv.points.proto.primitives_pb2 import Point, Polygon
from maps_adv.points.server.lib.data_managers import ForecastsDataManager
from maps_adv.points.server.lib.exceptions import (
    InvalidPolygon,
    NonClosedPolygon,
    NoPointsPassed,
    NoPolygonsPassed,
)

from .base import with_schemas

__all__ = ["ForecastsApiProvider"]


def _validate_polygons(polygons: List[List[dict]]) -> None:
    if not polygons:
        raise NoPolygonsPassed


def _validate_points(points: List[List[dict]]) -> None:
    if not points:
        raise NoPointsPassed


def _validate_polygon(polygon: List[dict]) -> None:
    if len(polygon) < 4:
        raise InvalidPolygon

    if polygon[0] != polygon[-1]:
        raise NonClosedPolygon


class PointSchema(ProtobufSchema):
    """Schema for a single coordinate, bound to the ``Point`` protobuf message."""

    class Meta:
        # Protobuf message class associated with this schema by ProtobufSchema.
        pb_message_class = Point

    # NOTE(review): coordinates are declared as strings, not floats; presumably
    # the proto carries them as decimal strings to avoid precision loss — confirm.
    longitude = fields.String()
    latitude = fields.String()


class PolygonSchema(ProtobufSchema):
    """Schema for one polygon, bound to the ``Polygon`` protobuf message."""

    class Meta:
        # Protobuf message class associated with this schema by ProtobufSchema.
        pb_message_class = Polygon

    # Ordered vertices; _validate_polygon requires at least four points with
    # the first equal to the last (a closed ring).
    points = fields.List(fields.Nested(PointSchema), validate=_validate_polygon)


class ForecastPolygonsInputSchema(ProtobufSchema):
    """Input schema for polygon-based forecasts (``ForecastPolygonsInput``)."""

    class Meta:
        # Protobuf message class associated with this schema by ProtobufSchema.
        pb_message_class = ForecastPolygonsInput

    # Each polygon is validated by PolygonSchema; _validate_polygons additionally
    # rejects an empty list with NoPolygonsPassed.
    polygons = fields.List(fields.Nested(PolygonSchema), validate=_validate_polygons)


class ForecastPointsInputSchema(ProtobufSchema):
    """Input schema for point-based forecasts (``ForecastPointsInput``)."""

    class Meta:
        # Protobuf message class associated with this schema by ProtobufSchema.
        pb_message_class = ForecastPointsInput

    # _validate_points rejects an empty list with NoPointsPassed.
    points = fields.List(fields.Nested(PointSchema), validate=_validate_points)


class ForecastOutputSchema(ProtobufSchema):
    """Output schema for every forecast endpoint (``ForecastOutput``)."""

    class Meta:
        # Protobuf message class associated with this schema by ProtobufSchema.
        pb_message_class = ForecastOutput

    # Value returned by the data manager's forecast_* call.
    shows = fields.Integer()


class ForecastsApiProvider:
    """API layer that forwards forecast requests to a ``ForecastsDataManager``.

    Every public method is wrapped by ``with_schemas(input_schema,
    output_schema)``. Presumably the decorator deserializes the request with
    the input schema, passes its fields as keyword arguments, and serializes
    the returned dict with the output schema — confirm in ``.base``.
    """

    def __init__(self, dm: ForecastsDataManager):
        self._dm = dm

    @staticmethod
    def _polygon_points(polygons: List[dict]) -> List[List[dict]]:
        """Return each deserialized polygon's ``"points"`` list, preserving order.

        The data manager expects bare point lists rather than polygon mappings.
        """
        return [polygon["points"] for polygon in polygons]

    @with_schemas(ForecastPolygonsInputSchema, ForecastOutputSchema)
    async def forecast_billboard(self, **kwargs) -> dict:
        """Delegate to ``ForecastsDataManager.forecast_billboard``.

        Returns:
            ``{"shows": <data manager result>}``.
        """
        got = await self._dm.forecast_billboard(
            polygons=self._polygon_points(kwargs["polygons"])
        )

        return {"shows": got}

    @with_schemas(ForecastPolygonsInputSchema, ForecastOutputSchema)
    async def forecast_zerospeed(self, **kwargs) -> dict:
        """Delegate to ``ForecastsDataManager.forecast_zerospeed``.

        Returns:
            ``{"shows": <data manager result>}``.
        """
        got = await self._dm.forecast_zerospeed(
            polygons=self._polygon_points(kwargs["polygons"])
        )

        return {"shows": got}

    @with_schemas(ForecastPointsInputSchema, ForecastOutputSchema)
    async def forecast_pins(self, **kwargs) -> dict:
        """Delegate to ``ForecastsDataManager.forecast_pins``.

        All deserialized fields are forwarded unchanged. Presumably this is
        just ``points``, the only field on ``ForecastPointsInputSchema``.

        Returns:
            ``{"shows": <data manager result>}``.
        """
        got = await self._dm.forecast_pins(**kwargs)

        return {"shows": got}

    @with_schemas(ForecastPolygonsInputSchema, ForecastOutputSchema)
    async def forecast_overview(self, **kwargs) -> dict:
        """Delegate to ``ForecastsDataManager.forecast_overview``.

        Returns:
            ``{"shows": <data manager result>}``.
        """
        got = await self._dm.forecast_overview(
            polygons=self._polygon_points(kwargs["polygons"])
        )

        return {"shows": got}
