from datetime import timedelta
from typing import Optional, Sequence, Type

from psycopg2 import connect
from psycopg2.extensions import Column, cursor
from psycopg2.sql import SQL, Identifier

from tractor.models import (
    ExternalProvider,
    ExternalSecretStatus,
    Task,
)
from tractor.settings import TractorDBSettings
from tractor.util.dataclasses import (
    construct_from_dict_db,
)


def dataobject_from_record(dataclass: Type, description: Sequence[Column], record: tuple):
    """Convert a single database row into an instance of ``dataclass``.

    Pairs each column in ``description`` (typically ``cursor.description``)
    with the value at the same position in ``record``, keyed by the column's
    ``name``, and passes the resulting mapping to ``construct_from_dict_db``.
    If two columns share a name, the later one wins.

    Raises:
        RuntimeError: when ``description`` and ``record`` have different lengths.
    """
    if len(description) != len(record):
        raise RuntimeError("Length mismatch between column metadata and content")
    column_names = (column.name for column in description)
    row_mapping = dict(zip(column_names, record))
    return construct_from_dict_db(dataclass, row_mapping)


class BaseDatabase:
    """Query helpers for tractor's task queue and external-secret storage.

    Every query method executes on a caller-supplied psycopg2 cursor and never
    commits; transaction boundaries are the caller's responsibility. Task rows
    live in ``<tasks_schema>.tasks`` (schema configurable per instance), while
    external secrets live in the fixed ``tractor.external_secrets`` table.
    """

    def __init__(self, settings: TractorDBSettings, tasks_schema: str):
        """Store connection settings and the schema that holds the tasks table.

        Args:
            settings: provides ``conninfo`` for ``make_connection``.
            tasks_schema: schema name, safely quoted as an SQL identifier when
                interpolated into queries.
        """
        self._settings = settings
        # Wrapped in Identifier so it is quoted safely when composed into SQL.
        self._tasks_schema = Identifier(tasks_schema)

    def make_connection(self):
        """Open and return a new psycopg2 connection; the caller must close it."""
        return connect(self._settings.conninfo)

    def set_external_secret(
        self,
        org_id: str,
        domain: str,
        provider: ExternalProvider,
        encrypted_external_secret: bytes,
        cur: cursor,
    ) -> None:
        """Insert or update the encrypted secret for ``org_id`` and ``provider``.

        Conflicts on constraint ``pk_external_secrets`` (presumably keyed on
        org_id/provider; verify against the schema) overwrite ``domain`` and
        ``encrypted_secret`` on the existing row.
        """
        cur.execute(
            """
                INSERT INTO tractor.external_secrets (
                    org_id,
                    domain,
                    provider,
                    encrypted_secret
                )
                VALUES (
                    %(org_id)s::text,
                    %(domain)s::text,
                    %(provider)s,
                    %(secret)s::bytea
                )
                ON CONFLICT
                    ON CONSTRAINT pk_external_secrets
                        DO UPDATE
                        SET
                            domain = EXCLUDED.domain,
                            encrypted_secret = EXCLUDED.encrypted_secret
            """,
            {
                "org_id": org_id,
                "domain": domain,
                "provider": provider.value,
                "secret": encrypted_external_secret,
            },
        )

    def get_external_secret(
        self,
        org_id: str,
        provider: ExternalProvider,
        cur: cursor,
    ) -> Optional[bytes]:
        """Return the encrypted secret for ``org_id``/``provider``, or None if absent.

        Raises:
            TypeError: if the driver returns the secret as something other
                than ``memoryview`` or ``bytes`` (e.g. a NULL column).
        """
        cur.execute(
            """
                SELECT
                    encrypted_secret
                FROM tractor.external_secrets
                WHERE
                    org_id = %(org_id)s::text
                    AND
                    provider = %(provider)s
                LIMIT 1
            """,
            {
                "org_id": org_id,
                "provider": provider.value,
            },
        )
        record = cur.fetchone()
        if record is None:
            return None
        encrypted_secret = record[0]
        # psycopg2 hands bytea back as memoryview; normalize to immutable bytes.
        # Explicit checks instead of `assert` so validation survives `python -O`.
        if isinstance(encrypted_secret, memoryview):
            return encrypted_secret.tobytes()
        if isinstance(encrypted_secret, bytes):
            return encrypted_secret
        raise TypeError(
            f"Unexpected type for encrypted_secret: {type(encrypted_secret)!r}"
        )

    def create_task(self, type: str, org_id: str, domain: str, worker_input: str, cur: cursor) -> int:
        """Insert a new task and return its generated ``task_id``.

        Args:
            type: task type name, cast to ``<tasks_schema>.task_type``.
            worker_input: JSON text, cast to ``jsonb``.

        Raises:
            RuntimeError: if the INSERT yields no row or a NULL task_id.
            TypeError: if the returned task_id is not an int.
        """
        cur.execute(
            SQL(
                """
                    INSERT INTO {tasks_schema}.tasks (
                        org_id,
                        domain,
                        type,
                        input
                    )
                    VALUES (
                        %(org_id)s::text,
                        %(domain)s::text,
                        %(type)s::{tasks_schema}.task_type,
                        %(input)s::jsonb
                    )
                    RETURNING
                        task_id
                """
            ).format(tasks_schema=self._tasks_schema),
            {
                "org_id": org_id,
                "domain": domain,
                "type": type,
                "input": worker_input,
            },
        )
        row = cur.fetchone()
        # Guard the missing-row case too, rather than crashing on `None[0]`.
        task_id = None if row is None else row[0]
        if task_id is None:
            raise RuntimeError("Empty task_id from create_task")
        if not isinstance(task_id, int):
            raise TypeError(f"Unexpected task_id type from create_task: {type(task_id)!r}")
        return task_id

    def get_external_secret_status_any_provider(
        self, org_id: str, cur: cursor
    ) -> ExternalSecretStatus:
        """Report whether ``org_id`` has a stored secret for any provider.

        If several providers are stored, the one that sorts first by
        ``provider`` is reported, so the result is deterministic.
        """
        cur.execute(
            """
                SELECT
                    provider,
                    TRUE AS external_secret_loaded
                FROM tractor.external_secrets
                WHERE
                    org_id = %(org_id)s::text
                ORDER BY provider
                LIMIT 1
            """,
            {"org_id": org_id},
        )
        record = cur.fetchone()
        if record is None:
            return ExternalSecretStatus(provider=None, external_secret_loaded=False)
        # external_secret_loaded is selected explicitly so the found-row path
        # does not rely on a dataclass default for that field.
        return dataobject_from_record(ExternalSecretStatus, cur.description, record)

    def acquire_task(
        self,
        type: str,
        expiry_timeout: timedelta,
        worker_id: str,
        cur: cursor,
    ) -> Optional[Task]:
        """Claim the oldest claimable pending task of ``type`` for ``worker_id``.

        A task is claimable when ``worker_status = 'pending'`` and its
        ``worker_ts`` is NULL or at least ``expiry_timeout`` old, i.e. it has
        no live lease. Rows locked by concurrent transactions are skipped.
        The claimed row gets ``worker_id`` and ``worker_ts = NOW()``.

        The query does not filter on ``canceled``. Presumably this lets the
        worker pick up canceled tasks and close them out via
        ``set_error_for_cancelled_task``; confirm with callers.

        Returns:
            The claimed task, or None if nothing is claimable.
        """
        cur.execute(
            SQL(
                """
                    WITH selected AS (
                        SELECT task_id
                        FROM {tasks_schema}.tasks
                        WHERE
                            type = %(type)s::{tasks_schema}.task_type
                            AND
                            worker_status = 'pending'
                            AND
                            (
                                worker_ts IS NULL
                                OR
                                NOW() - worker_ts >= %(expiry_timeout)s
                            )
                        ORDER BY created_ts
                        LIMIT 1
                        FOR UPDATE SKIP LOCKED
                    )
                    UPDATE {tasks_schema}.tasks
                    SET
                        worker_id = %(worker_id)s::text,
                        worker_ts = NOW()
                    FROM selected
                    WHERE
                        {tasks_schema}.tasks.task_id = selected.task_id
                    RETURNING {tasks_schema}.tasks.*
                """
            ).format(tasks_schema=self._tasks_schema),
            {
                "type": type,
                "worker_id": worker_id,
                "expiry_timeout": expiry_timeout,
            },
        )
        # RETURNING is limited to tasks.* because a bare `*` would also emit
        # the columns of `selected`, duplicating task_id in the row.
        record = cur.fetchone()
        if record is None:
            return None
        return dataobject_from_record(Task, cur.description, record)

    def refresh_task(self, task_id: int, worker_id: str, cur: cursor) -> bool:
        """Renew ``worker_id``'s lease on ``task_id`` by setting ``worker_ts = NOW()``.

        The update applies only while the task is still pending, owned by
        ``worker_id`` and not canceled.

        Returns:
            True if the lease was renewed. False means the task is missing,
            no longer pending, owned by another worker, or canceled.
        """
        cur.execute(
            SQL(
                """
                    UPDATE {tasks_schema}.tasks
                    SET
                        worker_ts = NOW()
                    WHERE
                        task_id = %(task_id)s::bigint
                        AND
                        worker_id = %(worker_id)s::text
                        AND
                        worker_status = 'pending'
                        AND
                        NOT canceled
                """
            ).format(tasks_schema=self._tasks_schema),
            {
                "task_id": task_id,
                "worker_id": worker_id,
            },
        )
        return cur.rowcount > 0

    def finish_task(self, output: str, task_id: int, cur: cursor) -> None:
        """Mark ``task_id`` as succeeded and store ``output`` (JSON text) as jsonb."""
        cur.execute(
            SQL(
                """
                    UPDATE {tasks_schema}.tasks
                    SET
                        worker_status = 'success',
                        worker_output = %(output)s::jsonb
                    WHERE
                        task_id = %(task_id)s::bigint
                """
            ).format(tasks_schema=self._tasks_schema),
            {"output": output, "task_id": task_id},
        )

    def set_error_for_cancelled_task(self, task_id: int, cur: cursor) -> None:
        """Mark ``task_id`` as errored with a cancellation message.

        Has no effect unless the task is flagged ``canceled``.
        """
        cur.execute(
            SQL(
                """
                    UPDATE {tasks_schema}.tasks
                    SET
                        worker_status = 'error',
                        worker_output = jsonb_build_object('error', 'task is cancelled')
                    WHERE
                        task_id = %(task_id)s::bigint
                        AND
                        canceled = TRUE
                """
            ).format(tasks_schema=self._tasks_schema),
            {"task_id": task_id},
        )

    def fail_task(self, output: str, task_id: int, cur: cursor) -> None:
        """Mark ``task_id`` as errored and store ``output`` (JSON text) as jsonb."""
        cur.execute(
            SQL(
                """
                    UPDATE {tasks_schema}.tasks
                    SET
                        worker_status = 'error',
                        worker_output = %(output)s::jsonb
                    WHERE
                        task_id = %(task_id)s::bigint
                """
            ).format(tasks_schema=self._tasks_schema),
            {"output": output, "task_id": task_id},
        )

    def get_task_by_task_id(self, task_id: int, cur: cursor) -> Optional[Task]:
        """Return the task row for ``task_id``, or None if it does not exist."""
        cur.execute(
            SQL(
                """
                    SELECT *
                    FROM {tasks_schema}.tasks
                    WHERE
                        task_id = %(task_id)s::bigint
                """
            ).format(tasks_schema=self._tasks_schema),
            {"task_id": task_id},
        )
        record = cur.fetchone()
        if record is None:
            return None
        return dataobject_from_record(Task, cur.description, record)
