from typing import Type

import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB

from sendr_taskqueue.worker.base.entites import BaseTaskType, BaseWorkerType
from sendr_taskqueue.worker.storage.db.entities import TaskState, WorkerState
from sendr_utils import enum_values


def get_task_state(metadata: sa.MetaData) -> sa.Enum:
    """Return the SQLAlchemy enum type backing ``TaskState``.

    The type is named ``task_state``. It must not share a name with the
    ``task_type`` enum built by ``get_task_type``: ``get_tasks_table`` attaches
    both enums to the same ``metadata``, and two distinct named enum types
    cannot coexist under one name in the database.

    :param metadata: metadata object the enum type is associated with.
    :return: enum type whose persisted values are produced by ``enum_values``.
    """
    return sa.Enum(
        TaskState,
        name='task_state',
        metadata=metadata,
        values_callable=enum_values,
    )


def get_task_type(metadata: sa.MetaData, cls: Type[BaseTaskType]) -> sa.Enum:
    """Build the ``task_type`` SQLAlchemy enum from the given task type class.

    :param metadata: metadata object the enum type is associated with.
    :param cls: enum class listing the task types of the application.
    :return: enum type named ``task_type``; persisted values come from
        ``enum_values``.
    """
    task_type_enum = sa.Enum(
        cls,
        name='task_type',
        metadata=metadata,
        values_callable=enum_values,
    )
    return task_type_enum


def get_worker_state_type(metadata: sa.MetaData) -> sa.Enum:
    """Build the ``worker_state`` SQLAlchemy enum from ``WorkerState``.

    :param metadata: metadata object the enum type is associated with.
    :return: enum type named ``worker_state``; persisted values come from
        ``enum_values``.
    """
    enum_options = {
        'name': 'worker_state',
        'metadata': metadata,
        'values_callable': enum_values,
    }
    return sa.Enum(WorkerState, **enum_options)


def get_worker_type(metadata: sa.MetaData, cls: Type[BaseWorkerType]) -> sa.Enum:
    """Build the ``worker_type`` SQLAlchemy enum from the given worker type class.

    :param metadata: metadata object the enum type is associated with.
    :param cls: enum class listing the worker types of the application.
    :return: enum type named ``worker_type``; persisted values come from
        ``enum_values``.
    """
    worker_type_enum = sa.Enum(
        cls,
        name='worker_type',
        metadata=metadata,
        values_callable=enum_values,
    )
    return worker_type_enum


def get_tasks_table(metadata: sa.MetaData, task_type_cls: Type[BaseTaskType]) -> sa.Table:
    """Declare the ``tasks`` table on ``metadata`` and return it.

    :param metadata: metadata object the table and its enum types are
        registered on.
    :param task_type_cls: enum class used for the ``task_type`` column.
    :return: the ``tasks`` table, keyed by ``task_id``.
    """

    def required_timestamp(column_name: str) -> sa.Column:
        # Timezone-aware, NOT NULL timestamp column.
        return sa.Column(column_name, sa.DateTime(timezone=True), nullable=False)

    # Columns are constructed in declaration order, enum types included.
    columns = [
        sa.Column('task_id', sa.BigInteger(), primary_key=True, nullable=False),
        sa.Column('task_type', get_task_type(metadata, task_type_cls), nullable=False),
        sa.Column('state', get_task_state(metadata), nullable=False),
        # Nullable descriptive/payload columns.
        sa.Column('action_name', sa.Text()),
        sa.Column('params', JSONB()),
        sa.Column('details', JSONB()),
        sa.Column('retries', sa.Integer(), nullable=False, default=0),
        required_timestamp('run_at'),
        required_timestamp('created'),
        required_timestamp('updated'),
    ]
    return sa.Table('tasks', metadata, *columns)


def get_workers_table(metadata: sa.MetaData, worker_type_cls: Type[BaseWorkerType]) -> sa.Table:
    """Declare the ``workers`` table on ``metadata`` and return it.

    ``workers.task_id`` is a nullable foreign key to ``tasks.task_id``, the
    primary key of the table declared by ``get_tasks_table``. The target is
    given by name, so SQLAlchemy looks the ``tasks`` table up in the same
    ``metadata`` when the constraint is resolved.

    :param metadata: metadata object the table and its enum types are
        registered on.
    :param worker_type_cls: enum class used for the ``worker_type`` column.
    :return: the ``workers`` table, keyed by ``worker_id``.
    """
    return sa.Table(
        'workers', metadata,
        sa.Column('worker_id', sa.String(32), nullable=False),
        sa.Column('worker_type', get_worker_type(metadata, worker_type_cls), nullable=False),
        sa.Column('host', sa.String(), nullable=False),
        sa.Column('state', get_worker_state_type(metadata), nullable=False),
        sa.Column('heartbeat', sa.DateTime(timezone=True)),
        sa.Column('startup', sa.DateTime(timezone=True)),
        sa.Column('task_id', sa.BigInteger()),
        sa.PrimaryKeyConstraint('worker_id'),
        # Must name the real table/column: the table is ``tasks`` and its
        # key column is ``task_id`` (there is no ``task.id``).
        sa.ForeignKeyConstraint(('task_id',), ('tasks.task_id',)),
    )
