from typing import List

from maps_adv.points.server.lib.db import DB
from maps_adv.points.server.lib.enums import PointType
from maps_adv.points.server.lib.exceptions import CollectionNotFound

__all__ = ["BasePointsDataManager", "PointsDataManager"]


class BasePointsDataManager:
    """Interface for looking up points that lie within a set of polygons.

    Subclasses must override :meth:`find_within_polygons`.
    """

    async def find_within_polygons(
        self, type_: PointType, version: int, polygons: List[List[dict]]
    ) -> List[dict]:
        """Return the points of a collection that lie within ``polygons``.

        :param type_: type of the points collection to search in.
        :param version: version of the points collection.
        :param polygons: list of polygons; each polygon is a list of vertex
            dicts with ``longitude`` and ``latitude`` keys.
        :return: list of dicts describing the matching points.
        :raises NotImplementedError: always; concrete subclasses implement it.
        """
        raise NotImplementedError()


class PointsDataManager(BasePointsDataManager):
    """Database-backed implementation of :class:`BasePointsDataManager`.

    Runs its queries on connections acquired from the ``DB`` instance passed
    to the constructor.
    """

    __slots__ = ("_db",)

    _db: DB

    # Finds the id of the collection with the given type ($1) and version ($2).
    _COLLECTION_ID_SQL = """
        SELECT id
        FROM collections
        WHERE type = $1 AND version = $2
        """

    # Unions the input EWKT polygons ($1, each passed through
    # ST_ShiftLongitude) and selects the points of collection $2 that are
    # ST_Within that union, ordered by point id.
    _SEARCH_IN_POLYGONS_SQL = """
        WITH input_geoms AS (
            SELECT ST_UNION(ARRAY_AGG(ST_ShiftLongitude(g))) as union_geometry
            FROM UNNEST($1::geometry[]) AS g
        )
        SELECT points_view.id,
            points_view.longitude,
            points_view.latitude
        FROM points_view JOIN input_geoms
            ON st_within(points_view.geometry, input_geoms.union_geometry)
        WHERE points_view.collection_id = $2
        ORDER BY points_view.id ASC
        """

    def __init__(self, db: DB):
        self._db = db

    async def find_within_polygons(
        self, point_type: PointType, version: int, polygons: List[List[dict]]
    ) -> List[dict]:
        """Return the points of a collection that lie within ``polygons``.

        :param point_type: type of the points collection to search in.
        :param version: requested collection version. NOTE: currently ignored;
            the lookup always uses version 1 (GEODISPLAY-1571).
        :param polygons: list of polygons; each polygon is a list of vertex
            dicts with ``longitude`` and ``latitude`` keys.
        :return: one dict per matching point with ``id``, ``longitude`` and
            ``latitude`` keys, ordered by ``id``.
        :raises CollectionNotFound: if no collection matches ``point_type``
            and the (forced) version.
        """
        # GEODISPLAY-1571: the caller-supplied version is deliberately
        # overridden; only collection version 1 is queried.
        version = 1
        async with self._db.acquire() as con:
            collection_id = await con.fetchval(
                self._COLLECTION_ID_SQL, point_type, version
            )

            # Explicit None check: fetchval presumably yields None when no row
            # matches (asyncpg semantics), and a falsy-but-valid id such as 0
            # must not be reported as a missing collection.
            if collection_id is None:
                raise CollectionNotFound(point_type, version)

            got = await con.fetch(
                self._SEARCH_IN_POLYGONS_SQL,
                self._format_polygons(polygons),
                collection_id,
            )

        return [dict(el) for el in got]

    @staticmethod
    def _format_polygons(polygons: List[List[dict]]) -> List[str]:
        """Convert polygons to EWKT strings in SRID 4326.

        Vertices are emitted in the given order as ``"longitude latitude"``
        pairs. The ring is not closed automatically; the vertex list is
        presumably expected to already repeat its first vertex at the end.
        """

        def ring(polygon: List[dict]) -> str:
            return ", ".join(
                f"{point['longitude']} {point['latitude']}" for point in polygon
            )

        return [f"SRID=4326;POLYGON (({ring(polygon)}))" for polygon in polygons]
