from typing import Any, Dict, List

import psycopg2

from django.conf import settings
from django.db import router as dj_router, transaction, connection
from django.http import JsonResponse
from django_replicated.router import ReplicationRouter

from staff.lib.json import JSONEncoder


def get_db_statuses() -> Dict:
    """Describe every configured database together with its liveness.

    Liveness is asked of the *last* ``ReplicationRouter`` instance found in
    ``django.db.router.routers``.

    Returns:
        ``{alias: {'host': ..., 'name': ..., 'is_alive': ...}}`` for each entry
        of ``settings.DATABASES``. The result is an empty dict when
        ``DATABASES`` is missing/empty or when no ``ReplicationRouter`` is
        configured.
    """
    configured = getattr(settings, 'DATABASES', None)
    if not configured:
        return {}

    candidates = [r for r in dj_router.routers if isinstance(r, ReplicationRouter)]
    replication_router = candidates[-1] if candidates else None
    if not replication_router:
        return {}

    return {
        alias: {
            'host': conf['HOST'],
            'name': conf['NAME'],
            'is_alive': replication_router.is_alive(alias),
        }
        for alias, conf in configured.items()
    }


def filterout_alive_databases(db_statuses: Dict) -> Dict:
    """Keep only the entries whose ``is_alive`` value is falsy.

    Args:
        db_statuses: mapping of alias to a status dict containing ``is_alive``.

    Returns:
        A new dict with the same keys/values for the non-alive entries only,
        in the original iteration order.
    """
    dead = {}
    for alias, status in db_statuses.items():
        if status['is_alive']:
            continue
        dead[alias] = status
    return dead


def check_db(request):
    """Django view: report database liveness as JSON.

    By default only databases whose ``is_alive`` status is falsy are returned.
    With ``?show_all=`` set to a true-ish value every database is returned.
    The values ``0``, ``false``, ``no`` and ``off`` (case-insensitive) and the
    empty string count as false.
    """
    # bool() of any non-empty query-string value -- including "0" and
    # "false" -- is True, so the flag has to be interpreted explicitly.
    raw_show_all = request.GET.get('show_all', '')
    show_all = str(raw_show_all).strip().lower() not in ('', '0', 'false', 'no', 'off')
    db_statuses = get_db_statuses()

    if show_all:
        return JsonResponse(data=db_statuses)

    return JsonResponse(data=filterout_alive_databases(db_statuses))


def get_stale_transactions() -> List[Dict[str, Any]]:
    """Return non-idle backends whose current query started over an hour ago.

    Reads ``pg_stat_activity`` through ``transaction.get_connection()``. It
    excludes this monitoring query itself (anything containing
    `` FROM pg_stat_activity ``) and ``START_REPLICATION`` sessions.

    Returns:
        A list of dicts with keys ``appname``, ``db``, ``pid``, ``state``,
        ``query`` and ``age``. ``age`` is converted to seconds via
        ``total_seconds()``. Rows are ordered by ascending age.
    """
    sql = """
    WITH stale_activity AS (
        SELECT application_name, datname, pid, state, query, age(clock_timestamp(), query_start) AS age
        FROM pg_stat_activity
        WHERE state <> 'idle'
        AND query NOT LIKE '% FROM pg_stat_activity %' AND query NOT LIKE 'START_REPLICATION%')
    SELECT * FROM stale_activity
    WHERE age > interval '1 hour'
    ORDER BY age;
    """
    fields = ('appname', 'db', 'pid', 'state', 'query', 'age')
    # Renamed from ``connection`` so the module-level django.db.connection
    # import is not shadowed.
    db_connection = transaction.get_connection()
    # This function does not own the connection, so at least release the
    # cursor it opened.
    with db_connection.cursor() as cursor:
        cursor.execute(sql)
        rows = cursor.fetchall()

    result = []
    for row in rows:
        item = dict(zip(fields, row))
        item['age'] = item['age'].total_seconds()
        result.append(item)
    return result


def _get_pg_stat_activity(connection: Any) -> List[Dict[str, Any]]:
    fields = (
        'username',
        'database',
        'application_name',
        'pid',
        'state',
        'query',
        'transaction_age',
        'last_query_age',
        'backend_age',
    )
    cursor = connection.cursor()
    query = (
        'SELECT usename, datname, application_name, pid, state, query, '
        'age(clock_timestamp(), xact_start) AS tran_age, '
        'age(clock_timestamp(), query_start) AS age, '
        'age(clock_timestamp(), backend_start) AS backend_age '
        'FROM pg_stat_activity '
        'ORDER BY tran_age DESC NULLS LAST'
    )
    cursor.execute(query)
    return [dict(zip(fields, row)) for row in cursor.fetchall()]


def _get_pg_is_master(connection: Any) -> bool:
    cursor = connection.cursor()
    cursor.execute('SELECT pg_is_in_recovery()')
    return not cursor.fetchone()[0]


def _get_pg_locks(connection: Any) -> List[Dict[str, Any]]:
    fields = ('locktype', 'database', 'relation_name', 'pid', 'query', 'username', 'mode', 'granted', 'fastpath')
    cursor = connection.cursor()
    query = (
        'SELECT locktype, pd.datname, pc.relname, pg_locks.pid, psa.query, psa.usename, mode, granted, fastpath '
        'FROM pg_locks '
        'LEFT JOIN pg_stat_activity psa on pg_locks.pid = psa.pid '
        'LEFT JOIN pg_database pd on pg_locks.database = pd.oid '
        'LEFT JOIN pg_class pc on pg_locks.relation = pc.oid'
    )
    cursor.execute(query)
    return [dict(zip(fields, row)) for row in cursor.fetchall()]


def get_pg_hosts_stat() -> Dict[str, Dict[str, Any]]:
    """Collect activity, role and lock information from each PostgreSQL host.

    ``settings.DATABASES['default']['HOST']`` may list several comma-separated
    hosts. Each host gets its own psycopg2 connection:

    - made with the ``default`` database credentials;
    - switched to a read-only autocommit session;
    - closed once the host has been inspected.

    Returns:
        A dict keyed by (whitespace-stripped) host name. Each value is either
        ``{'error': "Can't connect"}`` when connecting raised, or a dict with
        ``queries`` (rows of ``pg_stat_activity``), ``is_master`` and
        ``locks`` (rows of ``pg_locks``).

    Errors raised while querying an already-connected host are not caught
    here. They propagate to the caller, though the connection is still
    closed.
    """
    default_db = settings.DATABASES['default']
    # Strip whitespace so "host1, host2" connects to "host2", not " host2".
    hosts = [host.strip() for host in default_db['HOST'].split(',')]
    result = {}
    for host in hosts:
        host_stat = result[host] = {}
        try:
            pg_conn = psycopg2.connect(
                # Without an explicit dbname libpq falls back to a default
                # database (typically named after the user), which may not
                # exist. None (NAME unset) keeps that old default.
                dbname=default_db.get('NAME') or None,
                user=default_db['USER'],
                password=default_db['PASSWORD'],
                host=host,
                port=default_db['PORT'],
            )
        except Exception:
            # Best effort: an unreachable host is reported, not raised.
            host_stat['error'] = 'Can\'t connect'
            continue

        try:
            pg_conn.set_session(readonly=True, autocommit=True)
            host_stat['queries'] = _get_pg_stat_activity(pg_conn)
            host_stat['is_master'] = _get_pg_is_master(pg_conn)
            host_stat['locks'] = _get_pg_locks(pg_conn)
        finally:
            pg_conn.close()

    return result


def pg_top(request):
    """Django view: per-host PostgreSQL activity, role and locks as JSON.

    Serialization goes through the project's ``JSONEncoder``.
    """
    return JsonResponse(get_pg_hosts_stat(), encoder=JSONEncoder)


def get_dead_tuples_top(min_dead_tuples: int = 1000, schema: str = 'public') -> Dict[str, int]:
    """Return user tables with many dead tuples, largest count first.

    Args:
        min_dead_tuples: only tables with strictly more dead tuples than this
            are reported. Defaults to 1000, the previously hard-coded
            threshold.
        schema: schema whose tables are inspected. Defaults to ``public``.

    Returns:
        Mapping of ``relname`` to ``n_dead_tup`` from ``pg_stat_user_tables``,
        in descending order of dead tuples (dict insertion order).
    """
    # Plain string (the old f-string had no placeholders); values are bound
    # as query parameters rather than interpolated.
    query = """
        SELECT relname, n_dead_tup AS "dead_tuples" FROM pg_stat_user_tables
        WHERE schemaname = %s AND n_dead_tup > %s
        ORDER BY n_dead_tup DESC;
    """
    with connection.cursor() as cursor:
        cursor.execute(query, [schema, min_dead_tuples])
        return {relname: dead_tuples for relname, dead_tuples in cursor.fetchall()}
