from datetime import datetime
from typing import AsyncIterable, Iterable, Mapping, Optional

from sqlalchemy import func, select, update
from sqlalchemy.dialects.postgresql import insert

from mail.payments.payments.core.entities.worker import Worker, WorkerState
from mail.payments.payments.storage.db.tables import workers as t_workers
from mail.payments.payments.storage.mappers.base import BaseMapper


class WorkerMapper(BaseMapper):
    """Storage mapper between ``Worker`` entities and rows of the ``workers`` table.

    ``save`` writes only the mutable columns produced by ``unmap``
    (``state`` and ``task_id``); ``create`` additionally writes the identity
    columns (``worker_id``, ``worker_type``, ``host``, ``startup``) once.
    Deleting workers is not supported.
    """

    name = 'worker'
    model = Worker

    @staticmethod
    def map(row: Mapping) -> Worker:
        """Build a ``Worker`` entity from a ``workers`` table row."""
        return Worker(
            worker_id=row['worker_id'],
            worker_type=row['worker_type'],
            host=row['host'],
            state=row['state'],
            heartbeat=row['heartbeat'],
            startup=row['startup'],
            task_id=row['task_id'],
        )

    @staticmethod
    def unmap(obj: Worker) -> dict:
        """Return the column values that ``save`` updates and ``create`` inserts.

        Only ``state`` and ``task_id`` are included; identity columns are
        written by ``create`` alone and are never touched by ``save``.
        """
        return {
            'state': obj.state,
            'task_id': obj.task_id,
        }

    async def create(self, obj: Worker) -> Worker:
        """Insert ``obj`` as a new row and return the entity built from the stored row.

        ``heartbeat`` is not set here, so its initial value is whatever the
        table definition provides — NOTE(review): confirm the column default.
        """
        unmapped = self.unmap(obj)
        query = (
            insert(t_workers).
            values(
                worker_id=obj.worker_id,
                worker_type=obj.worker_type,
                host=obj.host,
                startup=obj.startup,
                **unmapped,
            ).
            returning(*t_workers.c)
        )
        return self.map(await self._query_one(query))

    async def find(self,
                   *,
                   state: Optional[WorkerState] = None,
                   states: Optional[Iterable[WorkerState]] = None,
                   beat_before: Optional[datetime] = None,
                   beat_after: Optional[datetime] = None,
                   limit: Optional[int] = None,
                   host: Optional[str] = None,
                   for_update: bool = False
                   ) -> AsyncIterable[Worker]:
        """Yield workers matching all of the given filters.

        Exactly one of ``state`` / ``states`` must be given (and truthy).
        Because this is an async generator, argument validation happens when
        iteration starts, not when the method is called.

        :param state: match workers in this single state.
        :param states: match workers whose state is in this collection.
        :param beat_before: match workers whose heartbeat is at or before this moment.
        :param beat_after: match workers whose heartbeat is at or after this moment.
        :param limit: maximum number of rows to return (no ordering is applied).
        :param host: match workers registered on this host.
        :param for_update: lock the selected rows (``key_share=True`` renders
            ``FOR NO KEY UPDATE`` on PostgreSQL) and skip rows already locked
            by other transactions (``SKIP LOCKED``).
        :raises ValueError: if both or neither of ``state`` and ``states`` are given.
        """
        # Explicit check rather than ``assert`` so the guard is not stripped under ``python -O``.
        if bool(state) == bool(states):
            raise ValueError('exactly one of state or states must be given')

        query = select([t_workers])

        if state:
            query = query.where(t_workers.c.state == state)
        if states:
            query = query.where(t_workers.c.state.in_(states))
        if beat_after is not None:
            query = query.where(t_workers.c.heartbeat >= beat_after)
        if beat_before is not None:
            query = query.where(t_workers.c.heartbeat <= beat_before)
        if host is not None:
            query = query.where(t_workers.c.host == host)
        if limit is not None:
            query = query.limit(limit)
        if for_update:
            query = query.with_for_update(skip_locked=True, key_share=True)

        async for row in self._query(query):
            yield self.map(row)

    async def get(self, worker_id: str) -> Worker:
        """Return the worker with ``worker_id``.

        :raises Worker.DoesNotExist: if no such row exists.
        """
        query = (
            select([t_workers]).
            where(t_workers.c.worker_id == worker_id)
        )
        return self.map(await self._query_one(query, raise_=Worker.DoesNotExist))

    async def heartbeat(self, worker_id: str) -> str:
        """Set the worker's ``heartbeat`` column to the database's ``now()``.

        Only a worker currently in ``WorkerState.RUNNING`` is updated.

        :returns: the ``worker_id`` of the updated row.
        :raises Worker.DoesNotExist: if there is no RUNNING worker with ``worker_id``.
        """
        query = (
            update(t_workers).
            where(t_workers.c.worker_id == worker_id).
            where(t_workers.c.state == WorkerState.RUNNING).
            values(heartbeat=func.now()).
            returning(t_workers.c.worker_id)
        )
        # The RUNNING filter makes "no row updated" a normal outcome; report it
        # the same way ``get`` does instead of subscripting a missing row.
        # NOTE(review): assumes ``_query_one`` returns None when ``raise_`` is omitted.
        row = await self._query_one(query, raise_=Worker.DoesNotExist)
        return row['worker_id']

    async def save(self, obj: Worker) -> Worker:
        """Update ``state`` and ``task_id`` of the row for ``obj.worker_id``.

        :returns: the entity rebuilt from the updated row.
        """
        unmapped = self.unmap(obj)
        query = (
            update(t_workers).
            values(**unmapped).
            where(t_workers.c.worker_id == obj.worker_id).
            returning(*t_workers.c)
        )
        return self.map(await self._query_one(query))

    def delete(self, obj: Worker) -> None:
        """Deleting workers is not supported by this mapper."""
        raise NotImplementedError
