from psycopg2.sql import SQL
from psycopg2.extensions import cursor
import json
from typing import List
from tractor.disk.models import (
    UserMigration,
    UserMigrationStatus,
    UserMigrationInfo,
    UserMigrationExportInfo,
    UserStatistics,
)
from tractor.models import Task
from tractor.error import SYNC_ERROR
from tractor.settings import TractorDBSettings
from tractor.util.dataclasses import (
    construct_from_dict_db,
    transform_into_json_str_without_none,
)
from tractor.db import BaseDatabase, dataobject_from_record


class Database(BaseDatabase):
    """Data-access layer for the ``tractor_disk`` user-migration tables.

    Every method receives an explicit psycopg2 ``cursor`` and runs its
    statements on it. Migration rows live in ``tractor_disk.user_migrations``
    and are addressed by ``(org_id, login)``. Task rows live in
    ``{tasks_schema}.tasks``, where the schema is ``self._tasks_schema``
    (not defined in this class; presumably set by ``BaseDatabase``).

    Apart from ``cancel_migration`` and ``process_finished_sync_migrations``,
    whose SQL scripts embed ``BEGIN;``/``COMMIT;``, the queries issued here are
    not committed by this class; transaction handling is left to the caller.

    NOTE(review): ``self._tasks_schema`` is spliced in with
    ``psycopg2.sql.SQL.format``, which only accepts ``Composable`` values
    (e.g. ``sql.Identifier``), so it presumably holds one — confirm.
    """

    def __init__(self, settings: TractorDBSettings):
        """Create the accessor.

        Args:
            settings: database settings, forwarded to ``BaseDatabase.__init__``
                together with the name ``"tractor_disk"``.
        """
        super().__init__(settings, "tractor_disk")

    def create_user_migration(
        self, org_id: str, domain: str, login: str, list_task_id: int, cur: cursor
    ):
        """Insert a new user migration row in status ``listing``.

        Args:
            org_id: organisation id; together with ``login`` identifies the row.
            domain: user's domain, stored as-is.
            login: user login.
            list_task_id: id of the listing task tied to this migration.
            cur: cursor the ``INSERT`` runs on (not committed here).
        """
        cur.execute(
            """
                INSERT INTO tractor_disk.user_migrations (
                    org_id,
                    domain,
                    login,
                    status,
                    list_task_id
                )
                VALUES (
                    %(org_id)s::text,
                    %(domain)s::text,
                    %(login)s::text,
                    %(status)s::tractor_disk.user_migration_status,
                    %(list_task_id)s::bigint
                )
            """,
            {
                "org_id": org_id,
                "domain": domain,
                "login": login,
                "status": "listing",
                "list_task_id": list_task_id,
            },
        )

    def reset_user_migration(
        self,
        org_id: str,
        login: str,
        domain: str,
        list_task_id: int,
        cur: cursor,
    ):
        """Put an existing migration back into status ``listing`` for a new run.

        Clears ``error_reason``, ``sync_task_ids`` and ``stats``, and stores the
        new ``list_task_id`` and ``domain`` on the row identified by
        ``(org_id, login)``. If no such row exists the ``UPDATE`` affects
        nothing. Not committed here.
        """
        cur.execute(
            """
            UPDATE tractor_disk.user_migrations
            SET
                status = 'listing'::tractor_disk.user_migration_status,
                error_reason = '',
                list_task_id = %(list_task_id)s::bigint,
                sync_task_ids = array[]::bigint[],
                stats = '{}'::jsonb,
                domain = %(domain)s::text
            WHERE
                org_id = %(org_id)s::text
                AND
                login = %(login)s::text
            """,
            {
                "org_id": org_id,
                "login": login,
                "list_task_id": list_task_id,
                "domain": domain,
            },
        )

    def mark_migration_finished(
        self,
        org_id: str,
        login: str,
        status: str,
        error_reason: str,
        stats: UserStatistics,
        cur: cursor,
    ):
        """Store the final ``status``, ``error_reason`` and ``stats`` of a migration.

        Args:
            org_id: organisation id of the migration.
            login: user login of the migration.
            status: cast in SQL to ``tractor_disk.user_migration_status``.
            error_reason: stored verbatim.
            stats: serialised with ``transform_into_json_str_without_none`` and
                stored as ``jsonb``.
            cur: cursor the ``UPDATE`` runs on (not committed here).
        """
        cur.execute(
            """
            UPDATE tractor_disk.user_migrations
            SET
                status = %(status)s::tractor_disk.user_migration_status,
                error_reason = %(error_reason)s::text,
                stats = %(stats)s::jsonb
            WHERE
                org_id = %(org_id)s::text
                AND
                login = %(login)s::text
            """,
            {
                "org_id": org_id,
                "login": login,
                "status": status,
                "error_reason": error_reason,
                "stats": transform_into_json_str_without_none(stats),
            },
        )

    def update_migration_to_sync_state(
        self, org_id: str, login: str, sync_task_ids: List[int], stats: UserStatistics, cur: cursor
    ):
        """Set a migration to status ``syncing`` and record its sync task ids.

        Clears ``error_reason`` and stores ``sync_task_ids`` and the serialised
        ``stats`` (via ``transform_into_json_str_without_none``) on the row
        identified by ``(org_id, login)``. Not committed here.
        """
        cur.execute(
            """
                UPDATE tractor_disk.user_migrations
                SET
                    status = %(status)s::tractor_disk.user_migration_status,
                    error_reason = %(error_reason)s::text,
                    sync_task_ids = %(sync_task_ids)s::bigint[],
                    stats = %(stats)s::jsonb
                WHERE
                    org_id = %(org_id)s::text
                    AND
                    login = %(login)s::text
            """,
            {
                "org_id": org_id,
                "login": login,
                "status": "syncing",
                "error_reason": "",
                "sync_task_ids": sync_task_ids,
                "stats": transform_into_json_str_without_none(stats),
            },
        )

    def get_user_migration(self, org_id: str, login: str, cur: cursor):
        """Fetch the migration of ``login`` in ``org_id``.

        Returns:
            A ``UserMigration`` built from the first matching row (``LIMIT 1``),
            with ``stats`` decoded by ``construct_from_dict_db``, or ``None``
            when no row matches.
        """
        cur.execute(
            """
                SELECT
                    domain,
                    status,
                    error_reason,
                    list_task_id,
                    sync_task_ids,
                    stats
                FROM tractor_disk.user_migrations
                WHERE
                    org_id = %(org_id)s::text
                    AND
                    login = %(login)s::text
                LIMIT 1
            """,
            {"org_id": org_id, "login": login},
        )
        resp = cur.fetchone()
        if resp is None:
            return None
        stats = construct_from_dict_db(UserStatistics, resp[5])
        user_migration = UserMigration(
            org_id=org_id,
            domain=resp[0],
            login=login,
            status=UserMigrationStatus(resp[1]),
            error_reason=resp[2],
            list_task_id=resp[3],
            sync_task_ids=resp[4],
            stats=stats,
        )
        return user_migration

    def finish_sync_task(self, output: str, task_id: int, cur: cursor) -> None:
        """Delegate to ``self.finish_task`` for sync task ``task_id`` with ``output``."""
        return self.finish_task(output, task_id, cur)

    def fail_listing_task(self, error: str, task_id: int, cur: cursor, **kwargs) -> None:
        """Fail listing task ``task_id`` via ``self.fail_task``.

        The payload is the JSON object ``{"error": error, **kwargs}``, so any
        extra keyword arguments must be JSON-serialisable.
        """
        return self.fail_task(json.dumps({"error": error, **kwargs}), task_id, cur)

    def fail_sync_task(self, error: str, task_id: int, cur: cursor) -> None:
        """Fail sync task ``task_id`` via ``self.fail_task`` with payload ``{"error": error}``."""
        return self.fail_task(json.dumps({"error": error}), task_id, cur)

    def get_user_migration_statuses(
        self, org_id: str, cur: cursor
    ) -> List[UserMigrationInfo]:
        """List login, status and error reason of every migration of ``org_id``.

        Args:
            org_id: organisation whose migrations are listed.
            cur: cursor the ``SELECT`` runs on.

        Returns:
            One ``UserMigrationInfo`` per row. The query has no ``ORDER BY``,
            so rows come back in whatever order the database returns them.
        """
        cur.execute(
            """
                SELECT
                    login,
                    status,
                    error_reason
                FROM tractor_disk.user_migrations
                WHERE
                    org_id = %(org_id)s::text
            """,
            {"org_id": org_id},
        )
        return [
            UserMigrationInfo(
                login=login, status=UserMigrationStatus(status), error_reason=error_reason
            )
            for login, status, error_reason in cur.fetchall()
        ]

    def get_user_migrations_for_export(
        self, org_id: str, cur: cursor
    ) -> List[UserMigrationExportInfo]:
        """Collect every migration of ``org_id`` together with its statistics.

        Args:
            org_id: organisation whose migrations are exported.
            cur: cursor the ``SELECT`` runs on.

        Returns:
            One ``UserMigrationExportInfo`` per row, with ``stats`` decoded by
            ``construct_from_dict_db``. The order is unspecified (no ``ORDER BY``).
        """
        cur.execute(
            """
            SELECT login, status, error_reason, stats
            FROM tractor_disk.user_migrations
            WHERE org_id = %(org_id)s::text
        """,
            {
                "org_id": org_id,
            },
        )
        return [
            UserMigrationExportInfo(
                login=login,
                status=UserMigrationStatus(status),
                error_reason=error_reason,
                stats=construct_from_dict_db(UserStatistics, raw_stats),
            )
            for login, status, error_reason, raw_stats in cur.fetchall()
        ]

    def cancel_migration(self, org_id: str, cur: cursor):
        """Request cancellation of every in-flight migration of ``org_id``.

        Runs a single SQL script that:

        1. locks (``FOR UPDATE``) the organisation's ``user_migrations`` rows
           whose status is ``listing`` or ``syncing`` and sets them to
           ``canceling``;
        2. sets ``canceled = True`` on the listing task and all sync tasks of
           those migrations in ``{tasks_schema}.tasks``. Only tasks that are
           not yet canceled and whose ``worker_status`` is still ``pending``
           are changed; tasks in any other worker status are left untouched.

        Migrations in any other status are not modified.

        NOTE(review): the script carries its own ``BEGIN;``/``COMMIT;``. If the
        connection is not in autocommit mode (psycopg2's default), the driver
        has already opened a transaction. The embedded ``BEGIN`` then only emits
        a warning, and ``COMMIT`` also commits earlier uncommitted work on this
        connection. Confirm this is intended.
        """
        cur.execute(
            SQL(
                """
                    BEGIN;
                    WITH migrations_to_cancel AS (
                        WITH migration_ids AS (
                            SELECT
                                org_id,
                                login
                            FROM tractor_disk.user_migrations
                            WHERE
                                org_id=%(org_id)s::text
                                AND
                                (
                                    status='listing'
                                    OR
                                    status='syncing'
                                )
                            FOR UPDATE
                        )
                        UPDATE tractor_disk.user_migrations
                        SET
                            status='canceling'
                        WHERE
                            (org_id, login) IN (
                                SELECT
                                    org_id,
                                    login
                                FROM
                                migration_ids
                            )
                        RETURNING
                            user_migrations.sync_task_ids,
                            user_migrations.list_task_id
                    )
                    UPDATE {tasks_schema}.tasks
                    SET
                        canceled=True
                    WHERE
                        task_id IN (
                            SELECT
                                UNNEST(migrations_to_cancel.sync_task_ids || array[migrations_to_cancel.list_task_id])
                            FROM migrations_to_cancel
                        )
                        AND
                        canceled=False
                        AND
                        worker_status='pending'
                        ;
                    COMMIT;
                """
            ).format(tasks_schema=self._tasks_schema),
            {"org_id": org_id},
        )

    def poll_listing_migration(self, cur: cursor):
        """Pick one migration whose listing task is no longer pending.

        Selects at most one ``user_migrations`` row, joined with its listing
        task (``list_task_id = task_id``), where the migration status is
        ``canceling`` or ``listing`` and the task's ``worker_status`` is not
        ``pending``. ``FOR UPDATE SKIP LOCKED`` without ``OF`` locks the
        selected row in both tables and skips rows another transaction has
        already locked, so concurrent pollers get different rows. The locks are
        held until the caller's transaction ends.

        Returns:
            ``(migration, task)`` built from the same joined row with
            ``dataobject_from_record``, or ``(None, None)`` if nothing matched.

        NOTE(review): ``SELECT *`` returns the columns of both tables. If the
        tables share column names, which value each dataobject receives depends
        on ``dataobject_from_record`` — confirm.
        """
        cur.execute(
            SQL(
                """
                    SELECT *
                    FROM
                        tractor_disk.user_migrations AS m
                        INNER JOIN {tasks_schema}.tasks AS t
                        ON m.list_task_id = t.task_id
                    WHERE
                        (
                            m.status = 'canceling'
                            OR
                            m.status = 'listing'
                        )
                        AND
                        (t.worker_status != 'pending')
                    LIMIT 1
                    FOR UPDATE SKIP LOCKED
                """
            ).format(tasks_schema=self._tasks_schema),
            {},
        )
        record = cur.fetchone()
        if record is None:
            return None, None
        migration = dataobject_from_record(UserMigration, cur.description, record)
        task = dataobject_from_record(Task, cur.description, record)
        return migration, task

    def process_finished_sync_migrations(self, cur: cursor, limit: int = 100):
        """Finalise ``syncing`` migrations whose sync tasks have all finished.

        A migration in status ``syncing`` with a non-empty ``sync_task_ids`` is
        ready when the number of its tasks whose ``worker_status`` is not
        ``pending`` reaches ``cardinality(sync_task_ids)``. Up to ``limit`` ready
        migrations are taken, those with the most finished tasks first. They are
        locked with ``FOR UPDATE SKIP LOCKED``, so rows locked elsewhere are
        skipped this round. Each one is then updated to:

        * status ``error`` and ``error_reason = SYNC_ERROR`` if any of its tasks
          has ``worker_status = 'error'``;
        * status ``success`` with an empty ``error_reason`` otherwise.

        The readiness check is in ``HAVING`` so that it applies *before*
        ``LIMIT``. When it was applied only after the limit, up to ``limit``
        unfinished migrations with many completed tasks could fill every slot.
        Smaller migrations that were already finished then waited until those
        large migrations completed. The later ``u.count >= u.total`` filter is
        now redundant but kept as a harmless safeguard.

        Assumes ids in ``sync_task_ids`` are distinct and each matches at most
        one task row. A duplicated id is counted twice by ``cardinality`` but
        only once by the join, so such a migration would never be finalised.

        NOTE(review): the script carries its own ``BEGIN;``/``COMMIT;``. With
        psycopg2's default non-autocommit mode, ``BEGIN`` only warns and
        ``COMMIT`` also commits earlier uncommitted work on this connection.
        Confirm this is intended.

        Args:
            cur: cursor the script runs on.
            limit: maximum number of migrations finalised per call.
        """
        cur.execute(
            SQL(
                """
                    BEGIN;
                    WITH migrations_to_update AS (
                        SELECT
                            m.org_id,
                            m.login,
                            COALESCE(cardinality(m.sync_task_ids), 0) AS total,
                            count(*) count,
                            bool_or(t.worker_status = 'error') AS error
                        FROM
                            tractor_disk.user_migrations AS m
                            INNER JOIN
                            {tasks_schema}.tasks AS t
                            ON
                            t.task_id=ANY(m.sync_task_ids)
                        WHERE
                            m.status = 'syncing'
                            AND
                            cardinality(m.sync_task_ids) > 0
                            AND
                            t.worker_status <> 'pending'
                        GROUP BY
                            m.org_id,
                            m.login
                        HAVING
                            count(*) >= COALESCE(cardinality(m.sync_task_ids), 0)
                        ORDER BY
                            count DESC
                        LIMIT %(limit)s
                    ),
                    ready_migrations AS (
                        SELECT
                            m.org_id,
                            m.login,
                            u.error
                        FROM
                            tractor_disk.user_migrations AS m
                            JOIN
                            migrations_to_update AS u
                            ON
                            m.org_id = u.org_id AND m.login = u.login
                        WHERE
                            u.count >= u.total
                        FOR UPDATE SKIP LOCKED
                    )

                    UPDATE
                        tractor_disk.user_migrations AS m
                    SET
                        status=CASE WHEN r.error THEN
                            'error'::tractor_disk.user_migration_status ELSE
                            'success'::tractor_disk.user_migration_status END,
                        error_reason=CASE WHEN r.error THEN
                            %(error_reason)s::text ELSE
                            '' END
                    FROM ready_migrations AS r
                    WHERE
                        m.org_id = r.org_id
                        AND
                        m.login = r.login;

                    COMMIT;
            """
            ).format(tasks_schema=self._tasks_schema),
            {
                "error_reason": SYNC_ERROR,
                "limit": limit,
            },
        )
