import errno
import json
import os
import re
from collections import defaultdict
from contextlib import closing
from dataclasses import dataclass

import psycopg2
from psycopg2.extras import DictCursor

from mail.tools.dbaas.helpers.constants import DB_NAME
from mail.tools.dbaas.helpers.yc_client import ResourceKinds
from mail.tools.dbaas.helpers.types.env import Env


def get_shard_id(sharpei_dsn: str, shard_name: str):
    """Look up the id of the shard named *shard_name* in sharddb.

    :param sharpei_dsn: connection string of the sharddb database.
    :param shard_name: value of ``shards.shards.name`` to look up.
    :returns: the matching ``shards.shards.shard_id``.
    :raises LookupError: if no shard with that name exists.
    """
    # psycopg2's ``with conn`` only commits/rolls back the transaction and
    # leaves the connection open; ``closing`` makes sure it is closed too.
    with closing(psycopg2.connect(sharpei_dsn)) as conn:
        with conn, conn.cursor() as cur:
            cur.execute(
                'select shard_id from shards.shards where name = %(shard_name)s',
                vars=dict(shard_name=shard_name),
            )
            row = cur.fetchone()
    if row is None:
        # Previously surfaced as an opaque "'NoneType' is not subscriptable".
        raise LookupError(f'Shard {shard_name!r} not found in sharddb')
    return row[0]


def get_users(sharpei_dsn: str, shard_id: int):
    """Return the uids that sharddb assigns to shard *shard_id*.

    The number of rows read from each table is printed.

    :param sharpei_dsn: connection string of the sharddb database.
    :param shard_id: shard whose users are listed.
    :returns: a pair ``(users, deleted_users)`` of uid sets read from
        ``shards.users`` and ``shards.deleted_users`` respectively.
    """
    # ``with conn`` alone does not close a psycopg2 connection.
    with closing(psycopg2.connect(sharpei_dsn)) as conn:
        with conn, conn.cursor() as cur:
            cur.execute(
                'select uid from shards.users where shard_id = %(shard_id)s',
                vars=dict(shard_id=shard_id),
            )
            users = [row[0] for row in cur.fetchall()]
            print(f'Found {len(users)} users in sharddb')

            cur.execute(
                'select uid from shards.deleted_users where shard_id = %(shard_id)s',
                vars=dict(shard_id=shard_id),
            )
            deleted_users = [row[0] for row in cur.fetchall()]
            print(f'Found {len(deleted_users)} deleted users in sharddb')

    return set(users), set(deleted_users)


def get_shard_users(shard_dsn):
    """Group every uid in the shard's ``mail.users`` by its presence flags.

    :param shard_dsn: connection string of the shard database.
    :returns: a ``defaultdict(set)`` mapping
        ``UserPresence(is_here, is_deleted)`` to the set of uids having
        exactly those flag values.
    """
    shard_users = defaultdict(set)
    # ``closing`` is needed because psycopg2's ``with conn`` does not close
    # the connection, only ends the transaction.
    with closing(psycopg2.connect(dsn=shard_dsn, cursor_factory=DictCursor)) as shard_conn:
        with shard_conn, shard_conn.cursor() as cur:
            cur.execute('SELECT uid, is_here, is_deleted FROM mail.users')
            # DictCursor rows are addressable by column name.
            for user in cur:
                presence = UserPresence(user['is_here'], user['is_deleted'])
                shard_users[presence].add(user['uid'])
    return shard_users


def report_users(shard_name: str, users, tag: str):
    """Report the users labelled *tag* on stdout and in a CSV file.

    The count is always printed. Individual uids are echoed only when there
    are at most 20 of them. When *users* is non-empty, the uids are also
    written one per line to ``<shard_name>/<tag>.csv``; the directory is
    created if needed.
    """
    count = len(users)
    print(f'Users marked as "{tag}": {count}')

    if count <= 20:
        for uid in users:
            print(uid)

    if not users:
        return

    mkdir_p(shard_name)
    with open(f'{shard_name}/{tag}.csv', 'w') as out:
        out.write(''.join(f'{uid}\n' for uid in users))


def mkdir_p(path, **kwargs):
    """Create *path* and any missing parents, like ``mkdir -p``.

    If *path* already exists as a directory, nothing happens. Any other
    ``OSError`` is re-raised, including the case where *path* exists but is
    not a directory. Extra keyword arguments are forwarded to
    ``os.makedirs``.
    """
    try:
        os.makedirs(path, **kwargs)
    except OSError as err:
        already_a_dir = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_dir:
            raise


@dataclass(eq=True, frozen=True)
class UserPresence:
    """Immutable pair of a user's ``is_here`` / ``is_deleted`` flags.

    Frozen, so equal instances hash equally and can serve as dict keys.
    """

    is_here: bool  # mail.users.is_here of the user row
    is_deleted: bool  # mail.users.is_deleted of the user row


# Calls code.add_cluster(shard_name, load_type, instances) in sharddb.
# %(shard_instances)s is a JSON array that json_populate_recordset expands
# into code.instance rows; array_agg collects them into the array argument.
# Executed by add_shard_to_sharddb, which reads the first result column.
ADD_CLUSTER_QUERY = '''
    select code.add_cluster(
        %(shard_name)s,
        %(load_type)s,
        array_agg(instances)
    ) from json_populate_recordset(
        null::code.instance,
        %(shard_instances)s
    ) instances
'''


# Captures the shard name from strings shaped like
# "<DB_NAME>_<shard_name>_pgcluster". The capture is non-greedy, so it ends
# at the first "_pgcluster" that follows the prefix.
# NOTE(review): DB_NAME is interpolated without re.escape, and the pattern is
# not anchored — confirm callers use match()/fullmatch() as intended.
SHARD_NAME_RE = re.compile(f'{DB_NAME}_(?P<shard_name>.*?)_pgcluster')


def add_shard_to_sharddb(env, shard_name, hosts, load_type='dbaas_hot'):
    """Register shard *shard_name* and its *hosts* in sharddb.

    Runs ``ADD_CLUSTER_QUERY`` (``code.add_cluster``) against
    ``env.sharpei_dsn``. Every host is registered on port 6432 with database
    ``DB_NAME``.

    :param env: environment object; only ``env.sharpei_dsn`` is used here.
    :param shard_name: name of the shard to create.
    :param hosts: iterable of host mappings with ``'name'`` and ``'zone_id'``
        keys (``zone_id`` is stored as the instance's ``dc``).
    :param load_type: load type passed through to ``code.add_cluster``.
    :returns: the first column of the ``code.add_cluster`` result (the new
        shard id).
    """
    # Build the payload before connecting so malformed host entries fail
    # without opening a connection.
    # NOTE(review): with an empty *hosts*, array_agg sees no rows and passes
    # NULL to code.add_cluster — confirm that case is rejected server-side.
    shard_instances = json.dumps([
        {
            'host': host['name'],
            'dc': host['zone_id'],
            'port': 6432,
            'dbname': DB_NAME,
        }
        for host in hosts
    ])
    # psycopg2's ``with conn`` only ends the transaction; ``closing`` also
    # closes the connection.
    with closing(psycopg2.connect(env.sharpei_dsn)) as conn:
        with conn, conn.cursor() as cur:
            cur.execute(ADD_CLUSTER_QUERY, vars=dict(
                shard_name=shard_name,
                load_type=load_type,
                shard_instances=shard_instances,
            ))
            shard_id = cur.fetchone()[0]
    return shard_id


# Max-workers overrides, keyed by the resource kind's description.
MAX_WORKERS_EXCEPTIONS = {ResourceKinds.moscow.value.description: 6}


# Upserts the max_workers override for one shard in buckets.workload.
# PostgreSQL requires a conflict target for ON CONFLICT DO UPDATE, so
# "(shard_id)" is given explicitly. The new value is read from EXCLUDED,
# the row proposed for insertion. The old WHERE clause was redundant, since
# the conflicting row already has this shard_id, and its unqualified
# shard_id was ambiguous.
# NOTE(review): assumes buckets.workload has a unique constraint / primary
# key on shard_id — confirm against the schema.
SET_SHARD_WORKLOAD_EXCEPTION_QUERY = '''
    INSERT INTO buckets.workload
    (shard_id, max_workers)
    VALUES (%(shard_id)s, %(max_workers)s)
    ON CONFLICT (shard_id) DO UPDATE
        SET max_workers = EXCLUDED.max_workers;
'''


def set_shard_workload_exception(env: Env, shard_id: int, max_workers: int) -> None:
    """Store a ``max_workers`` override for *shard_id* in ``buckets.workload``.

    Runs ``SET_SHARD_WORKLOAD_EXCEPTION_QUERY`` against ``env.sharpei_dsn``.
    The transaction is committed when the block exits without an exception.
    """
    # ``with conn`` commits but does not close the psycopg2 connection.
    with closing(psycopg2.connect(env.sharpei_dsn)) as conn:
        with conn, conn.cursor() as cur:
            cur.execute(
                SET_SHARD_WORKLOAD_EXCEPTION_QUERY,
                vars={'shard_id': shard_id, 'max_workers': max_workers},
            )


def _quote_dsn_value(value) -> str:
    """Quote *value* for use in a libpq ``keyword=value`` connection string.

    Backslashes and single quotes are backslash-escaped. Empty values and
    values containing whitespace are wrapped in single quotes. Any other
    value is returned unchanged.
    """
    escaped = str(value).replace('\\', '\\\\').replace("'", "\\'")
    if not escaped or any(ch.isspace() for ch in escaped):
        return f"'{escaped}'"
    return escaped


def get_shard_master_dsn(sharpei_dsn: str, shard_id: int, user: str, passwd: str):
    """Build a libpq DSN that reaches the writable host of shard *shard_id*.

    Every host registered for the shard in ``shards.instances`` is listed,
    and ``target_session_attrs=read-write`` lets libpq pick the one
    accepting writes.

    :param sharpei_dsn: connection string of the sharddb database.
    :param shard_id: shard whose hosts are used.
    :param user: user name placed in the DSN.
    :param passwd: password placed in the DSN. It is quoted, so spaces,
        quotes and backslashes are safe.
    :returns: the DSN string.
    :raises ValueError: if sharddb lists no instances for *shard_id*.
    """
    # psycopg2's ``with conn`` leaves the connection open; close it here.
    with closing(psycopg2.connect(sharpei_dsn)) as conn:
        with conn, conn.cursor() as cur:
            cur.execute(
                'select host from shards.instances where shard_id = %(shard_id)s',
                vars=dict(shard_id=shard_id),
            )
            hosts = [row[0] for row in cur.fetchall()]

    if not hosts:
        # An empty "host=" would yield a malformed/misleading DSN.
        raise ValueError(f'No instances registered in sharddb for shard_id={shard_id}')

    return ' '.join((
        f'host={_quote_dsn_value(",".join(hosts))}',
        'port=6432',
        # NOTE(review): add_shard_to_sharddb registers instances with DB_NAME;
        # confirm DB_NAME == 'maildb' or use it here.
        'dbname=maildb',
        f'user={_quote_dsn_value(user)}',
        f'password={_quote_dsn_value(passwd)}',
        'target_session_attrs=read-write',
    ))


def get_shard_migration_version(shard_dsn):
    """Return the highest ``version`` in the shard's ``public.schema_version``.

    :param shard_dsn: connection string of the shard database.
    :returns: the largest ``version`` value, ordered by the column's own type.
    :raises LookupError: if ``public.schema_version`` has no rows.
    """
    # ``closing`` is required: psycopg2's ``with conn`` does not close it.
    with closing(psycopg2.connect(dsn=shard_dsn)) as shard_conn:
        with shard_conn, shard_conn.cursor() as cur:
            cur.execute(
                'SELECT version FROM public.schema_version ORDER BY version desc limit 1'
            )
            row = cur.fetchone()
    if row is None:
        # Previously an opaque "'NoneType' is not subscriptable".
        raise LookupError('public.schema_version is empty')
    return row[0]
