import logging
from logging.handlers import QueueHandler
import os
from functools import partial
from multiprocessing import Lock, Queue
from multiprocessing.pool import Pool
import threading

import click
from tqdm import tqdm

import yaml
from mail.tools.dbaas.helpers.constants import EXTENSIONS, POSTGRE_CONFIG_OPTS
from mail.tools.dbaas.helpers.pgmigrate import pgmigrate
from mail.tools.dbaas.helpers.types.env import Envs, Env
from mail.tools.dbaas.helpers.types.user import User
from mail.tools.dbaas.helpers.yc_client import YcClient
from mail.tools.dbaas.helpers.yav import get_users_from_yav
from mail.tools.dbaas.helpers.infra import create_event, finish_event
from mail.tools.dbaas.bin.migrate_shards.utils import (
    split, BufferingLogger, TqdmLoggingHandler,
    TqdmProgressHandler, update_progress,
    green
)

# Module-level logger. Handlers are attached by setup_logger() and
# setup_progress_bar() in the main process and by setup_worker() in pool workers.
log = logging.getLogger(__name__)

# Key used to look up the password in the users mapping; the same name is
# passed to pgmigrate as the database user.
OWNER_NAME = 'maildb'

# Lock handed to BufferingLogger in migrate_clusters(); pool workers rebind this
# global in setup_worker() to the lock object they receive from the parent.
lock = Lock()

# tqdm progress bar instance; stays None until setup_progress_bar() is called.
progress_bar = None


def setup_progress_bar(total):
    """Create the global progress bar and hook it up to the module logger.

    :param total: value passed to tqdm as the bar's ``total`` (the bar's unit
        is labelled ``cluster``).
    """
    global progress_bar
    bar = tqdm(total=total, desc='Migration progress', unit='cluster')
    progress_bar = bar
    # The handler receives the bar so that records sent through `log` can
    # drive it; the handler itself is defined in migrate_shards.utils.
    progress_handler = TqdmProgressHandler(bar, logging.DEBUG)
    log.addHandler(progress_handler)


def wait_for_enter(msg):
    """Show ``msg`` as an ``input()`` prompt and block until the operator answers.

    The global progress bar is cleared before prompting and refreshed after,
    so setup_progress_bar() must have been called beforehand.
    """
    bar = progress_bar
    bar.clear()
    input(msg)
    bar.refresh()


def setup_logger():
    """Set the module logger to DEBUG and attach a TqdmLoggingHandler at that level."""
    level = logging.DEBUG
    log.setLevel(level)
    log.addHandler(TqdmLoggingHandler(level))


def log_consumer_thread(q):
    """Drain log records from ``q`` into the module logger until ``None`` arrives.

    :param q: queue whose ``get()`` yields records to pass to ``log.handle``;
        a ``None`` item is the stop signal and is not handled.
    """
    record = q.get()
    while record is not None:
        log.handle(record)
        record = q.get()


def start_log_consumer(maxsize):
    """Start a thread that forwards queued log records to the module logger.

    :param maxsize: capacity passed to the multiprocessing ``Queue``.
    :return: ``(queue, thread)`` tuple; pass both to stop_log_consumer() to
        shut the thread down.
    """
    records = Queue(maxsize)
    worker = threading.Thread(target=log_consumer_thread, args=(records,))
    worker.start()
    return records, worker


def stop_log_consumer(log_queue, consumer):
    """Signal the log consumer to stop and wait for it to finish.

    :param log_queue: queue the consumer reads from.
    :param consumer: the consumer thread returned by start_log_consumer().
    """
    # None is the sentinel log_consumer_thread() breaks on.
    log_queue.put(None)
    consumer.join()


def setup_worker(q, l):
    """Pool initializer: route worker logging into ``q`` and install the shared lock.

    :param q: queue that a ``QueueHandler`` will push this process's log
        records into.
    :param l: lock object that replaces the module-global ``lock``.
    """
    global lock
    # NOTE(review): hasHandlers() also looks at ancestor loggers, and a worker
    # created with the "fork" start method may inherit the parent's handlers;
    # in either case no QueueHandler is installed here - confirm this is intended.
    needs_queue_handler = not log.hasHandlers()
    if needs_queue_handler:
        log.setLevel(logging.DEBUG)
        log.addHandler(QueueHandler(q))
    lock = l


def check_positive(ctx, param, value):
    """Click option callback that accepts only values greater than zero.

    :param ctx: click context (unused).
    :param param: click parameter (unused).
    :param value: the parsed option value.
    :return: ``value`` unchanged.
    :raises click.BadParameter: if ``value`` is zero or negative.
    """
    is_non_positive = value <= 0
    if is_non_positive:
        raise click.BadParameter("should be positive")
    return value


@click.command('migrate-shards')
@click.option('--env', 'env_name', default=Envs.prod.value.name)
@click.option('--migrations-dir', required=True)
@click.option('--seq-prefix', help='first seq-prefix clusters will be migrated sequentially', callback=check_positive, default=50)
@click.option('--threads', help='how many clusters can be being migrated simultaneously', callback=check_positive, default=1)
@click.option('--update-cluster/--no-update-cluster', default=False)
@click.option('--one-by-one/--no-one-by-one', default=True)
@click.option('--infra-severity', help='major/minor', default='major')
def main(env_name: str, migrations_dir: str, seq_prefix: int, threads: int, update_cluster: bool = False,
         one_by_one: bool = True, infra_severity: str = 'major'):
    '''
        \b
        Example:
            ./migrate_shards --env test --migrations-dir $ARCADIA_ROOT/mail/pg/mdb --seq-prefix 50 --threads 8
    '''
    # NOTE: the docstring above is rendered by click as the command help, so
    # implementation notes live in comments instead.
    #
    # Flow: canary clusters (if any) -> the first `seq_prefix` main clusters
    # sequentially -> the remaining main clusters in a pool of `threads`
    # worker processes. The operator confirms each stage with Enter.
    # `--threads` is validated up front (like `--seq-prefix`) so that a bad
    # value fails before any cluster is touched rather than inside Pool().
    setup_logger()
    env: Env = Envs[env_name].value
    users = get_users_from_yav(env.users_file)
    owner_passwd = users[OWNER_NAME]['password']
    yc = YcClient(cloud_id=env.cloud_id, folder_name=env.folder_name)
    # TODO :: get clusters from shiva.shards
    # current implementation races with concurrent cluster creations
    clusters = yc.get_clusters()

    migrations_path = os.path.join(os.path.expanduser(migrations_dir), 'migrations')
    migration_files = sorted(os.listdir(migrations_path))
    if not migration_files:
        # Fail with a readable message instead of an IndexError on [-1].
        raise click.ClickException(f'No migration files found in {migrations_path}')
    last_migration_file = migration_files[-1]
    log.info(green(f'Last migration is: {last_migration_file}\n'))
    canary_clusters, main_clusters = split(clusters, pred=lambda cluster: cluster['name_info'].get('suffix') == 'canary')

    infra_event = create_event(env.infra, infra_severity)
    try:
        setup_progress_bar(len(clusters))

        if canary_clusters:
            log.info(green('Canary clusters found, going to migrate them first'))
            wait_for_enter(green('Press Enter to start: '))
            migrate_clusters(migrations_dir, owner_passwd, update_cluster, users, yc, one_by_one, canary_clusters)
            wait_for_enter(green('Canary migration is finished, press Enter to proceed to main cluster set: '))

        log.info(green('Starting migration of the sequential prefix of clusters...'))
        migrate_clusters(migrations_dir, owner_passwd, update_cluster, users, yc, one_by_one, main_clusters[:seq_prefix])

        if len(main_clusters) > seq_prefix:
            wait_for_enter(green('Done, press Enter to start the parallel part: '))
            log_queue, consumer = start_log_consumer(threads * 100)
            try:
                # Workers cannot prompt the operator, hence prompt_on_exception=False.
                migrate_tmpl = partial(migrate_clusters, migrations_dir, owner_passwd, update_cluster, users, yc,
                                       one_by_one, prompt_on_exception=False)
                with Pool(threads, setup_worker, (log_queue, lock,)) as pool:
                    # One single-element list per task: migrate_clusters expects a list.
                    transformed = [[cluster] for cluster in main_clusters[seq_prefix:]]
                    pool.map(migrate_tmpl, transformed, chunksize=1)
            finally:
                # The consumer is a non-daemon thread: without the sentinel it
                # would keep the process alive forever if the pool raised.
                stop_log_consumer(log_queue, consumer)
    finally:
        # Always close the infra event, even if the migration was interrupted.
        finish_event(infra_event)


def migrate_clusters(migrations_dir, owner_passwd, update_cluster, users, yc, one_by_one, clusters, prompt_on_exception=True):
    """Migrate every cluster in ``clusters`` in order, logging via a BufferingLogger.

    A summary (name, id, description of each cluster) is logged first. A failure
    on one cluster is logged and does not stop the remaining ones.

    :param clusters: list of cluster mappings; each must have ``name``, ``id``
        and ``description`` keys.
    :param prompt_on_exception: when true, flush the log and wait for Enter
        after a failed cluster before moving on.

    The remaining parameters are forwarded unchanged to migrate_cluster().
    """
    buffered = BufferingLogger(log, lock)
    summary_fields = ('name', 'id', 'description')

    buffered.info(green(f'\nYou are going to migrate {len(clusters)} following clusters:'))
    summary = [
        {field: cluster[field] for field in summary_fields}
        for cluster in clusters
    ]
    buffered.info(yaml.dump(summary, indent=4, sort_keys=False))

    for cluster in clusters:
        try:
            migrate_cluster(cluster, migrations_dir, owner_passwd, update_cluster, users, yc, one_by_one, buffered)
        except Exception as exc:
            buffered.exception(f'An exception occured during migration:\n{exc!s}')
            if prompt_on_exception:
                # Flush first so the operator sees the error before the prompt.
                buffered.flush()
                wait_for_enter('Press Enter to continue: ')
        finally:
            # Progress is advanced whether the cluster succeeded or failed.
            update_progress(log)
            buffered.flush()


def migrate_cluster(cluster, migrations_dir, owner_passwd, update_cluster, users, yc, one_by_one, logger):
    """Optionally refresh one cluster's configuration, then run pgmigrate on it.

    :param cluster: mapping with at least ``name`` and ``id`` keys.
    :param migrations_dir: passed to pgmigrate as ``basedir``.
    :param owner_passwd: password passed to pgmigrate for ``OWNER_NAME``.
    :param update_cluster: when true, update users, extensions and settings
        through ``yc`` before migrating.
    :param users: mapping of user name -> keyword options for ``User``.
    :param yc: client used for the cluster updates.
    :param one_by_one: forwarded to pgmigrate.
    :param logger: logger used here and handed to pgmigrate.
    """
    if update_cluster:
        cluster_name = cluster["name"]
        cluster_users = [User(uname, **opts) for uname, opts in users.items()]
        yc.update_users(cluster_name=cluster_name, users=cluster_users)
        yc.create_extensions(cluster_name=cluster_name, extensions=EXTENSIONS)
        yc.update_settings(cluster_name=cluster_name, **POSTGRE_CONFIG_OPTS)

    logger.info(f'\nApplying migration to {cluster["name"]}...')

    # Host name is derived from the cluster id.
    rw_host = f'c-{cluster["id"]}.rw.db.yandex.net'
    pgmigrate(
        basedir=migrations_dir,
        host=rw_host,
        user=OWNER_NAME,
        passwd=owner_passwd,
        one_by_one=one_by_one,
        logger=logger,
    )


# Run the click command when this file is executed as a script.
if __name__ == '__main__':
    main()
