import logging
from dataclasses import dataclass
import tenacity

from .cursor_provider import create_cursor_provider
from .task import TaskParams

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Timeout value passed as ``timeout=`` to ``cur.execute`` for each ANALYZE statement
# (presumably seconds -- confirm against the cursor implementation's contract).
ANALYZE_TIMEOUT = 600


@dataclass
class PgPartmanMaintenanceParams(TaskParams):
    """Parameters for the pg_partman maintenance task.

    Declares two defaulted string fields on top of whatever ``TaskParams``
    provides.
    """
    # Task identifier; defaults to 'pg_partman_maintenance'.
    task_name: str = 'pg_partman_maintenance'
    # Database user name; defaults to 'maildb'
    # (presumably the role used when connecting -- verify in create_cursor_provider).
    db_user: str = 'maildb'


@tenacity.retry(reraise=True, wait=tenacity.wait_fixed(600), stop=tenacity.stop_after_attempt(7))
async def pg_partman_run_maintenance(conn):
    """Run pg_partman's ``run_maintenance`` inside an explicit transaction.

    The whole function is retried by tenacity: up to 7 attempts with a fixed
    wait of 600 between them, and with ``reraise=True`` the last underlying
    exception is re-raised once attempts are exhausted.

    Args:
        conn: Connection-like object whose ``cursor()`` is an async context
            manager yielding a cursor with an awaitable ``execute``.

    Raises:
        Exception: Whatever a statement in the transaction raised. The error
            is logged and a ROLLBACK is attempted before it propagates.
    """
    async with conn.cursor() as cur:
        try:
            await cur.execute('BEGIN')
            # Flag every row in part_config as infinite_time_partitions.
            await cur.execute("UPDATE part_config SET infinite_time_partitions=True")
            # NOTE(review): in PostgreSQL a lock_timeout of 0 disables the timeout;
            # a plain SET (not SET LOCAL) presumably outlives the COMMIT -- confirm intended.
            await cur.execute("SET lock_timeout to 0")
            # Arguments are presumably (parent_table=NULL, analyze=FALSE, jobmon=FALSE);
            # verify against the installed pg_partman version.
            await cur.execute("SELECT run_maintenance(NULL, FALSE, FALSE)")
            await cur.execute('COMMIT')
        except Exception as exc:
            # Log the root cause first: if the ROLLBACK below fails as well
            # (e.g. the connection is gone) the original error must not be lost.
            log.warning('got exception during pg_partman_maintenance: %s', exc)
            try:
                await cur.execute('ROLLBACK')
            except Exception as rollback_exc:
                log.warning('failed to roll back pg_partman_maintenance transaction: %s', rollback_exc)
            # Re-raise the original error (not the rollback error) so tenacity retries on it.
            raise
        else:
            log.info("Completed pg_partman_maintenance.")


async def pg_partman_maintenance(conn):
    """Run partition maintenance, then ANALYZE every table listed in part_config.

    Both phases are best-effort: failures are logged as warnings and never
    propagate to the caller. A failing ANALYZE of one table does not prevent
    the remaining tables from being analyzed.

    Args:
        conn: Connection-like object whose ``cursor()`` is an async context
            manager yielding an async-iterable cursor with an awaitable
            ``execute``.
    """
    try:
        await pg_partman_run_maintenance(conn)
    except Exception:
        # Deliberately swallowed: the ANALYZE phase below still runs.
        log.warning('pg_partman_maintenance failed')

    async with conn.cursor() as cur:
        try:
            await cur.execute("SELECT parent_table FROM part_config")
            # Materialize the names first so the cursor can be reused for ANALYZE.
            tables = [row['parent_table'] async for row in cur]
        except Exception as exc:
            log.warning('got exception during pg_partman_maintenance ANALYZE: %s', exc)
            return

        for table in tables:
            try:
                # The table name comes from part_config and is interpolated as-is;
                # identifiers cannot be passed as bound query parameters.
                await cur.execute(f"ANALYZE {table}", timeout=ANALYZE_TIMEOUT)
            except Exception as exc:
                # Isolate the failure to this table and move on to the next one.
                log.warning('got exception during pg_partman_maintenance ANALYZE of %s: %s', table, exc)
            else:
                log.info("Completed ANALYZE of %s.", table)


async def shard_pg_partman_maintenance(params: PgPartmanMaintenanceParams, stats):
    """Task entry point: run pg_partman maintenance for one set of params.

    Enters the async context manager returned by
    ``create_cursor_provider(params, stats)`` and hands the object it yields
    to ``pg_partman_maintenance``.

    Args:
        params: Task parameters, forwarded to ``create_cursor_provider``.
        stats: Forwarded unchanged to ``create_cursor_provider``.
    """
    provider = create_cursor_provider(params, stats)
    async with provider as connection:
        await pg_partman_maintenance(connection)
