from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

from sqlalchemy import asc, desc, insert, Column
from sqlalchemy.dialects.postgresql import insert as pg_insert

from src.common import ContinuationTokenManager


class SqlExecutor:
    """Build and run the SQL statements used to read, page through and upsert source rows.

    Statements are executed through an async ``connection``: its ``execute`` is
    awaited and its result is consumed with ``async for``. Paging across
    ``select`` calls is delegated to a ``ContinuationTokenManager``.
    """

    def __init__(
        self,
        continuation_token_manager: Optional[ContinuationTokenManager] = None,
        *,
        source_id_field_name: str = 'source_id',
        mutable_window: timedelta = timedelta(hours=2),
    ) -> None:
        """Create an executor.

        Args:
            continuation_token_manager: Builds continuation-token filters and
                tokens. If omitted (or falsy), a default
                ``ContinuationTokenManager()`` is created.
            source_id_field_name: Column used as the upsert conflict key and as
                the column maximised by ``get_last_processed_source_id``.
            mutable_window: Rows whose source date is newer than
                ``now - mutable_window`` are ignored by
                ``get_last_processed_source_id``. Presumably this is because
                such rows may still change; confirm with the data owner.
        """
        super().__init__()
        self._source_id_field_name = source_id_field_name
        self._mutable_window = mutable_window
        self._continuation_token_manager = continuation_token_manager or ContinuationTokenManager()

    async def get_last_processed_source_id(self, connection, table, source_date_field: str):
        """Return the largest source id among rows older than the mutable window.

        Args:
            connection: Async connection whose ``execute`` accepts a raw SQL
                string plus a parameter tuple (``%s`` placeholders).
            table: Table object. Only its ``name`` is used.
            source_date_field: Name of the column compared against
                ``datetime.now() - mutable_window``.

        Returns:
            ``MAX`` of the source-id column over the matching rows. Returns
            ``None`` when no row matches (``MAX`` of an empty set is ``NULL``)
            or when the query yields no rows at all.
        """
        cutoff = datetime.now() - self._mutable_window
        # NOTE(review): the table name and ``source_date_field`` are interpolated
        # verbatim into the SQL text, so they must come from trusted code, never
        # from user input. ``cutoff`` is a naive local time; confirm this matches
        # how the date column is stored.
        rows = await connection.execute(
            f'SELECT MAX({self._source_id_field_name}) FROM {table.name} WHERE {source_date_field} < %s',
            (cutoff,),
        )
        # An aggregate without GROUP BY yields a single row. The result is still
        # consumed fully, and the guard keeps the ``None`` fallback if nothing
        # comes back.
        values = [row[0] async for row in rows]
        return values[0] if values else None

    async def select(
        self,
        connection,
        table,
        where_clauses: Dict[str, Dict[str, Any]],
        order_by: List[str],
        limit: int,
        continuation_token: Optional[str],
    ):
        """Yield up to ``limit`` rows of ``table`` as ``{column name: value}`` dicts.

        Args:
            connection: Async connection used to execute the statement.
            table: SQLAlchemy table to select from.
            where_clauses: Mapping of ``{column name: {op: value}}``. All
                conditions are combined with AND. Supported ops are ``gt``,
                ``ge``, ``lt``, ``le``, ``ne`` and ``in``; any other op is
                treated as equality.
            order_by: Column names to sort by. A leading ``-`` means descending.
            limit: Maximum number of rows to return.
            continuation_token: Token from a previous page, or ``None``/empty
                for the first page.

        Raises:
            AttributeError: If a name in ``where_clauses`` or ``order_by`` is
                not a column of ``table``.
        """
        request = table.select()

        if continuation_token:
            request = request.where(
                self._continuation_token_manager.get_filter(table, continuation_token, order_by)
            )

        for column_name, column_filter in where_clauses.items():
            column = getattr(table.c, column_name)
            for op, value in column_filter.items():
                request = self._convert_where_clause(request, column, op, value)

        sorting = await self._convert_sorting_statement(table, order_by)
        request = request.order_by(*sorting).limit(limit)

        async for row in await connection.execute(request):
            yield {column.name: row[column] for column in table.columns}

    def get_continuation_token(self, result: List[Dict[str, Any]], order_by: List[str], limit: int) -> Optional[str]:
        """Delegate to the continuation-token manager to build the next-page token.

        Returns:
            The manager's token for ``result``. The ``Optional`` type assumes
            the manager returns ``None`` when there is no further page; confirm
            against ``ContinuationTokenManager``.
        """
        return self._continuation_token_manager.get_continuation_token(result, order_by, limit)

    @staticmethod
    def _convert_where_clause(request, column: Column, op: str, value: Any):
        """Return ``request`` narrowed by ``column <op> value``.

        ``op`` is one of ``gt``, ``ge``, ``lt``, ``le``, ``ne`` or ``in``. For
        ``in``, ``value`` must be an iterable. Any other op (e.g. ``eq``) means
        equality, so existing callers keep working.
        """
        if op == 'gt':
            clause = column > value
        elif op == 'ge':
            clause = column >= value
        elif op == 'lt':
            clause = column < value
        elif op == 'le':
            clause = column <= value
        elif op == 'ne':
            clause = column != value
        elif op == 'in':
            clause = column.in_(value)
        else:
            clause = column == value
        return request.where(clause)

    @staticmethod
    async def _convert_sorting_statement(table, order_by: List[str]):
        """Translate ``order_by`` names into SQLAlchemy ``asc``/``desc`` clauses.

        A leading ``-`` selects descending order on the column named by the
        rest of the string. Nothing here awaits; the method stays ``async`` so
        existing ``await`` call sites keep working.
        """
        return [
            desc(getattr(table.c, name[1:])) if name.startswith('-') else asc(getattr(table.c, name))
            for name in order_by
        ]

    async def insert(self, connection, table, row):
        """Upsert ``row`` into ``table``, keyed on the source-id column.

        If a row with the same source id already exists, it is updated with the
        values from ``row``. PostgreSQL only accepts this ``ON CONFLICT`` target
        if the source-id column has a unique index or constraint.

        Returns:
            Whatever ``connection.execute`` returns for the statement.
        """
        # ``on_conflict_do_update`` exists only on the PostgreSQL dialect's Insert
        # construct. The generic ``sqlalchemy.insert`` has no such method.
        statement = (
            pg_insert(table)
            .values(**row)
            .on_conflict_do_update(index_elements=[self._source_id_field_name], set_=row)
        )
        return await connection.execute(statement)
