import re
import subprocess
import sys
from dataclasses import dataclass
from enum import Enum, unique
from itertools import chain
from typing import Any, Dict, List, Optional, Sequence

import psutil
import yaml

from mail.tools.dbaas.helpers.constants import DB_NAME
from mail.tools.dbaas.helpers.types.user import User, UserParamKeyPrefix


@unique
class PoolingMode(Enum):
    """Connection-pooler mode of a managed PostgreSQL cluster.

    Member values are the mode strings as they appear in ``yc`` output.
    """

    # Listed first: this is the default mode used by ClusterSettings.
    Tx = 'TRANSACTION'
    Session = 'SESSION'


@dataclass
class ClusterSettings:
    """Snapshot of the cluster settings this tool compares and updates."""

    # User-overridden PostgreSQL options keyed by option name.
    # NOTE(review): values come from parsed YAML, so they may be non-str scalars.
    pg: Dict[str, Any]
    # Connection pooler mode; may be None when the cluster reports no explicit mode.
    pooling_mode: Optional[PoolingMode] = PoolingMode.Tx


@dataclass(frozen=True)
class Zone:
    """One host placement of a cluster: a zone id plus a priority.

    Frozen, like ResourceKind, so instances are hashable. A ResourceKind that
    holds a tuple of Zones can then be hashed as well.
    """

    # Rendered as ``zone-id`` in the ``--host`` flag value.
    dc: str
    # Rendered as ``priority`` in the ``--host`` flag value.
    prio: int = 0

    def as_cli_param(self) -> str:
        """Return the ``--host`` flag value: ``zone-id=<dc>,priority=<prio>``."""
        return f'zone-id={self.dc},priority={self.prio}'


@dataclass(frozen=True)
class ResourceKind:
    """Hardware layout of a cluster: description, preset, disk size and host placement."""

    # Passed as ``--description``.
    description: str
    # Passed as ``--resource-preset``.
    preset: str
    # Passed as ``--disk-size``; presumably gigabytes, per the field name.
    disk_size_gb: int
    # One ``--host`` flag is emitted per zone, in this order.
    hosts: Sequence[Zone]

    def as_cli_param(self) -> List[str]:
        """Render this layout as a flat list of CLI arguments."""
        args = [
            '--description', self.description,
            '--resource-preset', self.preset,
            '--disk-size', str(self.disk_size_gb),
        ]
        for zone in self.hosts:
            args.extend(('--host', zone.as_cli_param()))
        return args


@unique
class ResourceKinds(Enum):
    """Named catalogue of cluster layouts; each member's value is a ResourceKind."""

    common = ResourceKind(
        description="MDB prod shard",
        preset="s2.xlarge",
        disk_size_gb=2048,
        hosts=(Zone('vla', 10), Zone('sas', 10), Zone('man', 0)),
    )
    common2 = ResourceKind(
        description="MDB prod shard",
        preset="db1.xlarge",
        disk_size_gb=2048,
        hosts=(Zone('vla', 10), Zone('sas', 10), Zone('man', 0)),
    )
    cold = ResourceKind(
        description="MDB prod shard with 3 TB SSD",
        preset="s2.xlarge",
        disk_size_gb=3072,
        hosts=(Zone('vla', 10), Zone('sas', 10), Zone('man', 0)),
    )
    cold_moscow2 = ResourceKind(
        description="MDB prod shard iva-myt-sas with 2 TB SSD",
        preset="db1.xlarge",
        disk_size_gb=2300,
        hosts=(Zone('iva', 10), Zone('myt', 10), Zone('sas', 0)),
    )
    cold_moscow3 = ResourceKind(
        description="MDB prod shard iva-myt-sas with 3 TB SSD",
        preset="db1.xlarge",
        disk_size_gb=3200,
        hosts=(Zone('iva', 10), Zone('myt', 10), Zone('sas', 0)),
    )
    cold_moscow4 = ResourceKind(
        description="MDB prod shard iva-myt-sas with 4 TB SSD",
        preset="db1.xlarge",
        disk_size_gb=4096,
        hosts=(Zone('iva', 10), Zone('myt', 10), Zone('sas', 0)),
    )
    mosman = ResourceKind(
        description="MDB prod shard myt-sas-man",
        preset="db1.xlarge",
        disk_size_gb=1650,
        hosts=(Zone('sas', 10), Zone('myt', 10), Zone('man', 0)),
    )
    moscow = ResourceKind(
        description="MDB prod shard iva-myt-sas",
        preset="db1.xlarge",
        disk_size_gb=1680,
        hosts=(Zone('iva', 10), Zone('myt', 10), Zone('sas', 0)),
    )
    moscow_s2 = ResourceKind(
        description="MDB prod shard myt-sas-man",
        preset="s2.xlarge",
        disk_size_gb=2048,
        hosts=(Zone('sas', 10), Zone('myt', 5), Zone('man', 0)),
    )
    common_s3 = ResourceKind(
        description="MDB prod shard vla-sas-man",
        preset="s3.xlarge",
        disk_size_gb=2048,
        hosts=(Zone('vla', 10), Zone('sas', 10), Zone('man', 0)),
    )
    big_s3 = ResourceKind(
        description="MDB prod shard vla-sas-man with 3 TB SSD",
        preset="s3.xlarge",
        disk_size_gb=3072,
        hosts=(Zone('vla', 10), Zone('sas', 10), Zone('man', 0)),
    )


class YcClient:
    """Drive the ``yc managed-postgresql`` CLI for one cloud and folder.

    Every command is scoped with ``--cloud-id`` / ``--folder-name`` and run via
    ``psutil.Popen``. Query methods parse the command's YAML stdout. Mutating
    methods stream the command's output to this process. Both kinds raise
    ``RuntimeError`` when ``yc`` exits with a non-zero code.
    """

    # Parses names built by find_next_cluster_name: f'{DB_NAME}_<env>[-<suffix>]<NNN>_pgcluster'.
    CLUSTER_NAME_RE = re.compile(fr'{DB_NAME}_(?P<env>\w+?)(-(?P<suffix>\w+?))?(?P<num>\d+)_pgcluster')
    # NOTE(review): not referenced by this class; disk size comes from ResourceKind.
    DISK_SIZE_GB = 2048

    def __init__(
        self,
        cloud_id: str,
        folder_name: str,
        owner_name: Optional[str] = None,
        owner_passwd: Optional[str] = None,
    ):
        """
        :param cloud_id: value of ``--cloud-id`` for every command.
        :param folder_name: value of ``--folder-name`` for every command.
        :param owner_name: owner of the ``DB_NAME`` database created by create_cluster.
        :param owner_passwd: stored on the instance; not read by any method of this class.
        """
        self.cloud_id = cloud_id
        self.folder_name = folder_name
        self.owner_name = owner_name
        self.owner_passwd = owner_passwd

    @staticmethod
    def cmd_prefix() -> List[str]:
        """Return the ``yc`` sub-command shared by every call."""
        return ['yc', 'managed-postgresql']

    def cmd_common_args(self) -> List[str]:
        """Return the flags that scope every command to this cloud and folder."""
        return [
            '--cloud-id', self.cloud_id,
            '--folder-name', self.folder_name,
        ]

    @staticmethod
    def _read_yaml(cmd: List[str]) -> Any:
        """Run *cmd*, parse its stdout as YAML and return the result.

        :raises RuntimeError: if the command exits with a non-zero code.
        """
        with psutil.Popen(cmd, stdout=subprocess.PIPE, stderr=sys.stderr) as proc:
            data = yaml.safe_load(proc.stdout)
        # Leaving the ``with`` block waits for the process, so returncode is set here.
        if proc.returncode:
            raise RuntimeError(f'Command {cmd} finished with retcode={proc.returncode}')
        return data

    @staticmethod
    def _run(cmd: List[str], action: str) -> None:
        """Run *cmd* with output streamed to this process's stdout/stderr.

        :param action: label used at the start of the error message.
        :raises RuntimeError: if the command exits with a non-zero code.
        """
        proc = psutil.Popen(cmd, stdout=sys.stdout, stderr=sys.stderr)
        retcode = proc.wait()
        if retcode:
            raise RuntimeError(f'{action} finished with retcode={retcode}')

    def find_next_cluster_name(self, env: str = 'prod', suffix: Optional[str] = None) -> str:
        """Return the next unused cluster name for *env* and *suffix*.

        The number is one more than the highest number among existing clusters with
        the same env and suffix (1 when there are none). It is rendered with three
        digits, e.g. ``f'{DB_NAME}_prod-cold004_pgcluster'``. An empty *suffix* is
        treated like ``None``.
        """
        print('Looking up for next vacant cluster name...')
        # '' would build a suffix-less name but never equal a parsed suffix (None).
        suffix = suffix or None
        cluster_num = max(
            (
                int(cluster['name_info']['num'])
                for cluster in self.get_clusters(env=env)
                if cluster['name_info']['suffix'] == suffix
            ),
            default=0,
        ) + 1
        print(f'Found vacant cluster num {cluster_num:03} for suffix {suffix}')
        shard_name = env if suffix is None else f'{env}-{suffix}'
        return f'{DB_NAME}_{shard_name}{cluster_num:03}_pgcluster'

    def get_clusters(self, env: Optional[str] = None) -> List[Dict[str, Any]]:
        """List clusters in the folder whose names match CLUSTER_NAME_RE.

        :param env: when given, keep only clusters whose parsed ``env`` equals it.
        :returns: cluster dicts from ``yc`` output. Each dict gets an extra
            ``name_info`` key holding the regex groups (``env``, ``suffix``, ``num``).
        """
        cmd = self.cmd_prefix() + [
            'cluster', 'list',
            '--limit', '1000',
            '--format', 'yaml',
        ] + self.cmd_common_args()
        # NOTE(review): '--limit 1000' presumably caps the listing; larger folders would be truncated.
        clusters = self._read_yaml(cmd) or []
        result = []
        for cluster in clusters:
            match = self.CLUSTER_NAME_RE.match(cluster['name'])
            if match and (env is None or match.group('env') == env):
                result.append({'name_info': match.groupdict(), **cluster})
        return result

    def create_cluster(self, cluster_name: str, mdb_env: str, resource_kind: ResourceKind, users: List[User]):
        """Create *cluster_name* unless get_clusters already lists it.

        The cluster is created with PostgreSQL 11 on local SSD. It gets one database,
        ``DB_NAME``, owned by ``self.owner_name``. The layout flags come from
        *resource_kind*, and one ``--user`` flag is added per entry of *users*.

        :raises RuntimeError: if ``yc`` exits with a non-zero code.
        """
        # NOTE(review): get_clusters only returns names matching CLUSTER_NAME_RE, so an
        # existing cluster with a non-conforming name is not detected here.
        existing_names = {cluster['name'] for cluster in self.get_clusters()}
        if cluster_name in existing_names:
            print(f'Cluster {cluster_name} is already created')
            return
        print(f'Creating cluster {cluster_name}...')
        user_args = list(chain.from_iterable(
            ('--user', ','.join(user.as_cli_param(key_prefix=UserParamKeyPrefix.ForCreateCluster)))
            for user in users
        ))
        cmd = (
            self.cmd_prefix()
            + [
                'cluster', 'create',
                '--name', cluster_name,
                '--environment', mdb_env,
                # NOTE(review): a single space is passed as the network id; presumably
                # intentional for this installation - confirm before changing.
                '--network-id', " ",
                '--disk-type', 'local-ssd',
                '--postgresql-version', '11',
                '--database', f'name={DB_NAME},owner={self.owner_name}',
            ]
            + resource_kind.as_cli_param()
            + user_args
            + self.cmd_common_args()
        )
        self._run(cmd, 'Cluster creation')

    def create_users(self, cluster_name: str, users: List[User]):
        """Create those of *users* whose names are not yet present in the cluster."""
        print('Looking up for already created users...')
        existing_users = {user['name'] for user in (self.get_users(cluster_name) or [])}
        users_to_create = [u for u in users if u.name not in existing_users]
        print(f'Users already created: {existing_users}')
        print(f'Users to be created: {[us.name for us in users_to_create]}')
        for user in users_to_create:
            self.create_user(cluster_name=cluster_name, user=user)

    def get_users(self, cluster_name: str) -> List[Dict[str, Any]]:
        """Return the users of *cluster_name* as parsed from ``yc`` YAML output ([] if empty)."""
        cmd = self.cmd_prefix() + [
            'user', 'list',
            '--cluster-name', cluster_name,
            '--limit', '0',
            '--format', 'yaml',
        ] + self.cmd_common_args()
        return self._read_yaml(cmd) or []

    def create_user(self, cluster_name, user: User):
        """Create *user* in *cluster_name*.

        :raises RuntimeError: if ``yc`` exits with a non-zero code.
        """
        print(f'Creating user {user}...')
        cmd = self.cmd_prefix() + [
            'user', 'create', user.name,
            '--cluster-name', cluster_name,
        ] + user.as_cli_param(key_prefix=UserParamKeyPrefix.ForCreateUser) + self.cmd_common_args()
        self._run(cmd, 'User creation')

    def update_users(self, cluster_name, users: Sequence[User]):
        """Bring the cluster's users in line with *users*.

        Users missing from the cluster are created. Existing users that differ from
        the proposed settings have the differing fields printed and are then updated.
        """
        current_users = {
            user.name: user
            for user in (User.from_yc(**user_data) for user_data in self.get_users(cluster_name))
        }
        for user in users:
            current_user = current_users.get(user.name)
            if current_user is None:
                print(f'User {user.name} not found in YC')
                self.create_user(cluster_name, user)
                # Just created from the proposed settings; there is nothing to diff or update.
                continue
            if current_user == user:
                continue
            print('Current user is not equal to proposed settings')
            print(f'Current user: {current_user}')
            print(f'Proposed user: {user}')
            # Compare by key, so fields present on only one side cannot shift the pairing.
            current_items = dict(current_user.as_items())
            proposed_items = dict(user.as_items())
            diff = {
                key: f'Current: {current_items[key]} vs proposed: {proposed_items[key]}'
                for key in sorted(current_items.keys() & proposed_items.keys())
                if current_items[key] != proposed_items[key]
            }
            print('Diff: ', diff)
            self.update_user(cluster_name=cluster_name, user=user)

    def update_user(self, cluster_name, user: User):
        """Apply *user*'s settings to the existing user of the same name.

        :raises RuntimeError: if ``yc`` exits with a non-zero code.
        """
        print(f'Updating user {user}...')
        cmd = self.cmd_prefix() + [
            'user', 'update', user.name,
            '--cluster-name', cluster_name,
        ] + user.as_cli_param(key_prefix=UserParamKeyPrefix.ForCreateUser) + self.cmd_common_args()
        self._run(cmd, 'User update')

    def get_extensions(self, cluster_name: str) -> List[Dict[str, Any]]:
        """Return the ``extensions`` list of database ``DB_NAME`` ([] if absent)."""
        cmd = self.cmd_prefix() + [
            'database', 'get', DB_NAME,
            '--cluster-name', cluster_name,
            '--format', 'yaml',
        ] + self.cmd_common_args()
        return (self._read_yaml(cmd) or {}).get('extensions') or []

    def create_extensions(self, cluster_name: str, extensions: Dict[str, str]):
        """Ensure every extension in *extensions* (name -> version) is enabled on ``DB_NAME``.

        Nothing is run when all names are already present. Otherwise the full
        *extensions* mapping is sent in one update.

        :raises RuntimeError: if ``yc`` exits with a non-zero code.
        """
        print('Looking up for already created extensions...')
        existing_extensions = {ext['name'] for ext in self.get_extensions(cluster_name)}
        exts_to_create = {
            ext: version for ext, version in extensions.items() if ext not in existing_extensions
        }
        print(f'Extensions to be created: {exts_to_create}')

        if exts_to_create:
            print(f'Adding extensions {exts_to_create} for {cluster_name}...')
            cmd = self.cmd_prefix() + [
                'database', 'update', DB_NAME,
                '--cluster-name', cluster_name,
                # Full list of extensions because skipped extensions will be deleted
                '--extensions', ','.join(f'{ext}={version}' for ext, version in extensions.items()),
            ] + self.cmd_common_args()
            self._run(cmd, 'Adding extensions')

    def get_settings(self, cluster_name: str) -> ClusterSettings:
        """Read the user PostgreSQL config and pooler mode of *cluster_name*.

        ``pg`` is ``{}`` when no user config is reported. ``pooling_mode`` is
        ``None`` when no pooler mode is reported.
        """
        cmd = self.cmd_prefix() + [
            'cluster', 'get',
            '--name', cluster_name,
            '--format', 'yaml',
        ] + self.cmd_common_args()
        config_data = self._read_yaml(cmd)['config']
        version = config_data['version']
        pg_config = config_data.get(f'postgresql_config_{version}') or {}
        raw_mode = (config_data.get('pooler_config') or {}).get('pooling_mode')
        # NOTE(review): modes other than TRANSACTION/SESSION raise ValueError here.
        pooling_mode = PoolingMode(raw_mode) if raw_mode is not None else None
        return ClusterSettings(
            pg=pg_config.get('user_config') or {},
            pooling_mode=pooling_mode,
        )

    def update_settings(self, cluster_name: str, pooling_mode: Optional[PoolingMode] = None, **config_options):
        """Apply *pooling_mode* and PostgreSQL *config_options* where they differ from the cluster.

        A cluster that reports no pooler mode is treated as being in transaction mode.

        :raises RuntimeError: if a ``yc`` update exits with a non-zero code.
        """
        print('Looking up for current cluster settings...')
        current_settings = self.get_settings(cluster_name)
        current_mode = current_settings.pooling_mode or PoolingMode.Tx
        if pooling_mode is not None and pooling_mode != current_mode:
            print(f'Setting pooling mode to {pooling_mode} for {cluster_name}...')
            cmd = self.cmd_prefix() + [
                'cluster', 'update',
                '--name', cluster_name,
                # The CLI expects the mode string; Popen rejects Enum members in argv.
                '--connection-pooling-mode', pooling_mode.value,
            ] + self.cmd_common_args()
            self._run(cmd, 'Settings update')

        current_pg = current_settings.pg
        settings_to_update = {
            key: value for key, value in config_options.items()
            # Compare textual forms: YAML may have parsed the current value into a non-str scalar.
            if key not in current_pg or str(value) != str(current_pg[key])
        }
        if settings_to_update:
            print(f'Apply settings {settings_to_update} for {cluster_name}...')
            cmd = self.cmd_prefix() + [
                'cluster', 'update-config',
                '--name', cluster_name,
                '--set', ','.join(f'{key}={value}' for key, value in settings_to_update.items()),
            ] + self.cmd_common_args()
            self._run(cmd, 'Settings update')

    def get_hosts(self, cluster_name) -> List[Dict[str, Any]]:
        """Return the hosts of *cluster_name* as parsed from ``yc`` YAML output ([] if empty)."""
        cmd = self.cmd_prefix() + [
            'hosts', 'list',
            '--cluster-name', cluster_name,
            '--limit', '0',
            '--format', 'yaml',
        ] + self.cmd_common_args()
        return self._read_yaml(cmd) or []
