import asyncio
import concurrent.futures
from operator import itemgetter
from typing import List, Optional

from yql.api.v1.client import YqlClient, YqlTableReadIterator

from maps_adv.points.server.lib.db import DB

# Names exported by ``from <this module> import *``.
__all__ = ["BaseForecastsDataManager", "ForecastsDataManager", "YtSyncIsNotConfigured"]


# Extracts ``(yt_token, yt_table, yt_cluster)`` -- in that order -- from a config
# mapping; raises ``KeyError`` when any of the three keys is absent.
_unpack_yt_config = itemgetter("yt_token", "yt_table", "yt_cluster")


class YtSyncIsNotConfigured(Exception):
    """Raised when a forecasts sync is requested without complete YT settings."""


class BaseForecastsDataManager:
    """Interface of a forecasts data manager.

    Every method here raises ``NotImplementedError``; concrete managers are
    expected to override them.
    """

    async def forecast_billboard(self, polygons: List[List[dict]]) -> int:
        """Return the forecasted number of billboard shows inside ``polygons``.

        Args:
            polygons: polygons, each given as a list of point dicts.
        """
        raise NotImplementedError()

    async def forecast_zerospeed(self, polygons: List[List[dict]]) -> int:
        """Return the forecasted number of zero-speed banner shows inside ``polygons``.

        Args:
            polygons: polygons, each given as a list of point dicts.
        """
        raise NotImplementedError()

    async def forecast_overview(self, polygons: List[List[dict]]) -> int:
        """Return the forecasted number of overview shows inside ``polygons``.

        Args:
            polygons: polygons, each given as a list of point dicts.
        """
        raise NotImplementedError()

    # NOTE: the parameter keeps its historical name ``polygons`` so that keyword
    # callers of the base interface keep working; ``ForecastsDataManager`` names
    # it ``points`` and treats every item as a single point dict.
    async def forecast_pins(self, polygons: List[dict]) -> int:
        """Return the forecasted number of pin shows for the given points.

        Args:
            polygons: point dicts (see the note above about the name).
        """
        raise NotImplementedError()

    async def sync_forecasts(self) -> None:
        """Synchronise the stored forecasts data."""
        raise NotImplementedError()


class ForecastsDataManager(BaseForecastsDataManager):
    """Forecasts data manager backed by the ``shows_forecasts`` database table.

    The ``forecast_*`` methods sum per-geohash show counters stored in
    ``shows_forecasts``; ``sync_forecasts`` replaces the table contents with
    rows read from a YT table.
    """

    # NOTE: the base class declares no ``__slots__``, so instances still carry a
    # ``__dict__``; this declaration mainly documents the instance attributes.
    __slots__ = "_db", "yt_config"

    # Columns of ``shows_forecasts``, in the order they are requested from YT
    # and copied into the database.
    _FORECAST_COLUMNS = (
        "geohash",
        "pin_shows",
        "billboard_shows",
        "zsb_shows",
        "overview_shows",
    )

    _db: DB
    yt_config: dict  # never None: __init__ substitutes an empty dict

    def __init__(self, db: DB, *, yt_config: Optional[dict] = None):
        """Store the database handle and the optional YT sync settings.

        Args:
            db: database object whose ``acquire()`` yields connections.
            yt_config: mapping with the ``yt_token``, ``yt_table`` and
                ``yt_cluster`` keys; when omitted or empty, ``sync_forecasts``
                raises ``YtSyncIsNotConfigured``.
        """
        self._db = db
        self.yt_config = yt_config or {}

    async def forecast_billboard(self, polygons: List[List[dict]]) -> int:
        """Return the sum of ``billboard_shows`` over geohashes inside ``polygons``.

        Args:
            polygons: polygons, each a list of dicts with ``longitude`` and
                ``latitude`` keys.
        """
        return await self._sum_shows_in_polygons("billboard_shows", polygons)

    async def forecast_zerospeed(self, polygons: List[List[dict]]) -> int:
        """Return the sum of ``zsb_shows`` over geohashes inside ``polygons``.

        Args:
            polygons: polygons, each a list of dicts with ``longitude`` and
                ``latitude`` keys.
        """
        return await self._sum_shows_in_polygons("zsb_shows", polygons)

    async def forecast_overview(self, polygons: List[List[dict]]) -> int:
        """Return the sum of ``overview_shows`` over geohashes inside ``polygons``.

        Args:
            polygons: polygons, each a list of dicts with ``longitude`` and
                ``latitude`` keys.
        """
        return await self._sum_shows_in_polygons("overview_shows", polygons)

    async def forecast_pins(self, points: List[dict]) -> int:
        """Return the sum of ``pin_shows`` over the geohashes of ``points``.

        Every point is converted to a geohash of precision 6 in SQL; duplicate
        geohashes are collapsed, so each matching row is counted once.

        Args:
            points: dicts with ``longitude`` and ``latitude`` keys.
        """
        sql = """
            WITH input_points AS (
                SELECT DISTINCT ST_GeoHash(g, 6) as geohash
                FROM UNNEST($1::geometry[]) AS g
            )
            SELECT COALESCE(SUM(pin_shows), 0)
            FROM shows_forecasts JOIN input_points
                ON shows_forecasts.geohash = input_points.geohash
        """
        ewkt_points = [
            f"SRID=4326;POINT({point['longitude']} {point['latitude']})"
            for point in points
        ]

        async with self._db.acquire() as con:
            return await con.fetchval(sql, ewkt_points)

    async def sync_forecasts(self) -> None:
        """Replace the contents of ``shows_forecasts`` with rows read from YT.

        The YT table is read in a worker thread so the blocking client does not
        stall the event loop. If it yields no rows the database is left
        untouched; otherwise the table is emptied and refilled within a single
        transaction.

        Raises:
            YtSyncIsNotConfigured: if ``yt_token``, ``yt_table`` or
                ``yt_cluster`` is missing from ``yt_config`` or is falsy.
        """
        try:
            yt_token, yt_table, yt_cluster = _unpack_yt_config(self.yt_config)
        except KeyError as exc:
            # Chain explicitly so the traceback names the missing key.
            raise YtSyncIsNotConfigured() from exc
        if not (yt_token and yt_table and yt_cluster):
            raise YtSyncIsNotConfigured()

        columns = list(self._FORECAST_COLUMNS)

        def _read_yt_table() -> list:
            # The iterator is drained while the client context is still active.
            with YqlClient(token=yt_token):
                return list(
                    YqlTableReadIterator(
                        yt_table, cluster=yt_cluster, column_names=columns
                    )
                )

        with concurrent.futures.ThreadPoolExecutor() as pool:
            records = await asyncio.get_event_loop().run_in_executor(
                pool, _read_yt_table
            )

        if not records:
            # Nothing was read: keep the current forecasts instead of wiping them.
            return

        async with self._db.acquire() as con:
            async with con.transaction():
                await con.execute("DELETE FROM shows_forecasts")
                # Name the target columns explicitly so the copy does not depend
                # on the physical column order of the table.
                await con.copy_records_to_table(
                    "shows_forecasts", records=records, columns=columns
                )

    async def _sum_shows_in_polygons(
        self, column: str, polygons: List[List[dict]]
    ) -> int:
        """Sum ``column`` over forecast rows whose geohash point lies in ``polygons``.

        The input polygons are merged into one geometry with ``ST_UNION``, and a
        row is counted when ``ST_Contains`` holds for the point decoded from its
        geohash. ``ST_ShiftLongitude`` is applied to both sides of the test.

        Args:
            column: counter column of ``shows_forecasts`` to sum.
            polygons: polygons, each a list of dicts with ``longitude`` and
                ``latitude`` keys.

        Raises:
            ValueError: if ``column`` is not a known ``shows_forecasts`` column.
        """
        # ``column`` is interpolated into the SQL text (identifiers cannot be
        # bound as query parameters), so only whitelisted names are accepted.
        if column not in self._FORECAST_COLUMNS:
            raise ValueError(f"Unknown forecast column: {column!r}")

        sql = f"""
            WITH input_geoms AS (
                SELECT ST_UNION(ARRAY_AGG(ST_ShiftLongitude(g))) as union_geometry
                FROM UNNEST($1::geometry[]) AS g
            )
            SELECT COALESCE(SUM({column}), 0)
            FROM shows_forecasts JOIN input_geoms
            ON ST_Contains(
                input_geoms.union_geometry,
                ST_ShiftLongitude(
                    ST_SetSRID(ST_PointFromGeoHash(shows_forecasts.geohash), 4326)
                )
            )
        """

        async with self._db.acquire() as con:
            return await con.fetchval(sql, self._format_polygons(polygons))

    @staticmethod
    def _format_polygons(polygons: List[List[dict]]) -> list:
        """Render polygons as EWKT strings ``SRID=4326;POLYGON ((lon lat, ...))``.

        Points are emitted in the order given; no closing point is appended.

        Args:
            polygons: polygons, each a list of dicts with ``longitude`` and
                ``latitude`` keys.

        Returns:
            One EWKT string per input polygon, in input order.
        """
        return [
            "SRID=4326;POLYGON (({}))".format(
                ", ".join(
                    f"{point['longitude']} {point['latitude']}" for point in polygon
                )
            )
            for polygon in polygons
        ]
