from crypta.lib.python.zk import (
    zk_client,
    fake_zk_client,
)
from crypta.lib.python.bt.commons.cli.shell import writeln, colorize, bold
import crypta.lib.python.bt.conf.conf as conf

import crypta.lib.python.bt.workflow as workflow
from setproctitle import (
    setproctitle,
)

from pprint import pformat
import uuid
import multiprocessing
import time
import logging
import pdb
import sys
import traceback
import signal
from six.moves import range


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Status codes returned by main() and the handle_* command handlers.
OK_CODE = 0
FAIL_CODE = 2


def zk_hosts():
    """Return the configured ZooKeeper hosts as a list.

    The ``Zk.Host`` configuration value is split on commas.
    """
    raw_hosts = conf.proto.Zk.Host
    return raw_hosts.split(",")


def zk_port():
    """Return the ``Zk.Port`` value from the configuration."""
    zk_settings = conf.proto.Zk
    return zk_settings.Port


def esv_to_list(args):
    """Split every ``key=value`` string in ``args`` on its first ``=``.

    Returns a list holding one ``[key, value]`` list per input item, in
    input order.  An item that contains no ``=`` produces a single-element
    list; any further ``=`` characters stay in the value.
    """
    pairs = []
    for item in args:
        pairs.append(item.split('=', 1))
    return pairs


def create_task(args):
    """Instantiate the workflow task described by ``args``.

    ``args.Task`` is the task name; ``args.Arg`` holds ``key=value`` strings
    that are passed to the task as keyword arguments.

    Returns the task instance.  When the task is unknown or a parameter is
    missing, a message is printed and ``None`` is returned instead.
    """
    name = args.Task
    keyword_args = dict(esv_to_list(args.Arg))

    try:
        return workflow.task_instance(name, **keyword_args)
    except workflow.ParameterMissed as error:
        message = 'Parameter is missed: {}'.format(error)
    except workflow.TaskMissed as error:
        message = 'No such task: {}'.format(error)

    # Only reached when one of the handlers above built an error message.
    writeln(colorize(message, color=1))
    return None


def handle_task(args, local):
    """Execute a task synchronously or hand it over for async execution.

    The task description is taken from whichever field of the ``Task`` oneof
    is set on ``args``.  With a truthy ``local`` the task is executed
    synchronously against a fake ZooKeeper client (forced when the task
    arguments have ``Force`` set); otherwise it is submitted through
    ``execute_async`` using a real ZooKeeper client, with log records at
    WARNING level and below disabled.

    Returns ``FAIL_CODE`` when the task could not be created, ``OK_CODE``
    otherwise.  The ``Enqueued ...`` line is printed in both modes.
    """
    selected_field = args.WhichOneof("Task")
    task_args = getattr(args, selected_field)

    task = create_task(task_args)
    if not task:
        return FAIL_CODE

    if not local:
        logging.disable(logging.WARNING)
        with zk_client(hosts=zk_hosts(), port=zk_port()) as zk:
            workflow.scha.execute_async(task, zk)
    else:
        with fake_zk_client() as fake_zk:
            if task_args.Force:
                workflow.scha.execute_sync_force(task, fake_zk)
            else:
                workflow.scha.execute_sync(task, fake_zk, do_fork=False)

    writeln(colorize('Enqueued {}'.format(task), color=2))
    return OK_CODE


def handle_workers(args):
    """Start worker processes and restart any that die.

    ``args.Workers.Worker`` holds ``tag=count`` strings; ``count`` worker
    processes are started for every ``tag``.  The calling process then
    checks the children every ``CHECK_INTERVAL`` seconds and replaces the
    ones that are no longer alive.

    Returns ``FAIL_CODE`` when no worker tag is given.  Otherwise the
    supervising loop never finishes on its own; the function is left only
    through an exception, e.g. ``workflow.WorkerStopped`` raised by the
    SIGTERM/SIGINT handler installed here.
    """
    # Check if all child processes are alive every CHECK_INTERVAL seconds
    CHECK_INTERVAL = 10
    INFINITE_LOOP = True

    tags_counts = esv_to_list(args.Workers.Worker)
    tags_counts = [(x[0], int(x[1])) for x in tags_counts]

    def run_worker(worker_id, worker_tag):
        """Child process body: run the worker, reconnecting after any failure."""
        setproctitle('Worker for tag [{}] (ID {})'.format(worker_tag, worker_id))
        while INFINITE_LOOP:
            try:
                with zk_client(hosts=zk_hosts(), port=zk_port()) as zk:
                    workflow.scha.run_worker(zk, worker_id=worker_id, tag=worker_tag)
            except Exception:
                # Log and retry with a fresh ZooKeeper client.
                logger.exception('Worker %s failed', worker_id)

    def spawn_process(worker_tag):
        """Create (but do not start) a worker process for ``worker_tag``."""
        logger.info('Spawning worker process')
        worker_id = uuid.uuid4()
        process = multiprocessing.Process(target=run_worker, args=(worker_id, worker_tag))
        # Remember the identity on the process object so that loop() and the
        # signal handler can log it and respawn a worker with the same tag.
        process.id = worker_id
        process.tag = worker_tag
        return process

    def loop(processes):
        """Poll ``processes`` forever, replacing dead entries in place."""
        while INFINITE_LOOP:
            # Sleep to mitigate active wait cpu impact
            time.sleep(CHECK_INTERVAL)
            # Iterate over a copy: the list is mutated inside the loop.
            for process in list(processes):
                if not process.is_alive():
                    logger.info('Worker %s, pid %d is dead, restarting',
                                process.id, process.pid)
                    processes.remove(process)
                    replacement_process = spawn_process(process.tag)
                    logger.info('Starting worker %s', replacement_process.id)
                    replacement_process.start()
                    logger.info('Process started, worker %s, pid %d',
                                replacement_process.id,
                                replacement_process.pid)
                    processes.append(replacement_process)

    def shutdown(processes):
        """Build a signal handler that terminates every process in ``processes``."""
        def term(signo, _stack_frame):
            for process in processes:
                logger.info('Shutdown worker %s pid %d due to signal %d',
                            process.id, process.pid, signo)
                process.terminate()
            raise workflow.WorkerStopped()
        return term

    if not tags_counts:
        # Not inside an exception handler, so logger.exception() would attach
        # a meaningless "NoneType: None" traceback; log a plain error instead.
        logger.error('At least one worker tag should be specified')
        return FAIL_CODE

    processes = [
        spawn_process(tag)
        for (tag, count) in tags_counts
        for _ in range(count)
    ]
    for process in processes:
        logger.info('Starting process %s', process.id)
        process.start()

    # The handlers are installed only after the initial workers are started.
    # NOTE(review): presumably so the initial children do not inherit them;
    # replacements started later from loop() are created after this point -
    # confirm that is intended.
    terminate_workers = shutdown(processes)
    signal.signal(signal.SIGTERM, terminate_workers)
    signal.signal(signal.SIGINT, terminate_workers)

    loop(processes)


def handle_status(args):
    """Print the status returned by ``workflow.scha.status``, one key at a time.

    Each key is pretty-printed together with the string forms of its
    entries.  ``LOCKED`` and ``COOLING_DOWN`` keys get their own colors, any
    other key uses color 7.  ``args`` is not used.  Always returns
    ``OK_CODE``.
    """
    logging.disable(logging.WARNING)

    default_color = 7
    status_colors = {
        workflow.Task.Status.LOCKED: 2,
        workflow.Task.Status.COOLING_DOWN: 1,
    }

    with zk_client(hosts=zk_hosts(), port=zk_port()) as zk:
        status = workflow.scha.status(zk)
        for key in status.keys():
            entries = list(map(str, status[key]))
            rendered = pformat({key: entries})
            color = status_colors.get(key, default_color)
            writeln(colorize(rendered, color=color))

    return OK_CODE


def handle_list(args):
    """Print every available task with the first line of its description.

    Tasks are listed in order of the first element of each entry returned by
    ``workflow.available_tasks()``.  ``args`` is not used.  Always returns
    ``OK_CODE``.
    """
    logging.disable(logging.WARNING)

    entries = sorted(workflow.available_tasks(), key=lambda entry: entry[0])
    for _name, task in entries:
        declaration = task.declaration()
        # Only the first line of a possibly multi-line description is shown.
        first_line = task.description().split('\n')[0]
        summary = '%s ...' % first_line
        title = colorize(bold(declaration), color=2)
        writeln('%s: %s' % (title, summary))

    return OK_CODE


def handle_requirements(args):
    """Print the transitive requirements of the task in ``args.Requirements``.

    Returns ``FAIL_CODE`` when the task cannot be created, ``OK_CODE``
    otherwise.
    """
    logging.disable(logging.WARNING)

    task = create_task(args.Requirements)
    if task:
        for requirement in workflow.transitive_requirements(task):
            writeln(colorize(bold(requirement), color=2))
        return OK_CODE

    return FAIL_CODE


def handle_purge(args):
    """Call ``workflow.scha.purge`` on a ZooKeeper client and report it.

    ``args`` is not used.  Always returns ``OK_CODE``.
    """
    logging.disable(logging.WARNING)

    hosts, port = zk_hosts(), zk_port()
    with zk_client(hosts=hosts, port=port) as zk:
        workflow.scha.purge(zk)

    writeln(colorize('Purged all tasks', color=1))
    return OK_CODE


def main(args):
    """Dispatch to the handler for the command set in the ``Task`` oneof.

    Returns the handler's result.  ``FAIL_CODE`` is returned when the run is
    interrupted from the keyboard or a handler raises; with ``args.Debug``
    set, the traceback is printed and a post-mortem debugger is opened
    before returning.
    """
    try:
        command = args.WhichOneof("Task")

        handlers = {
            'Workers': handle_workers,
            'Status': handle_status,
            'Run': lambda parsed: handle_task(parsed, local=True),
            'Enqueue': lambda parsed: handle_task(parsed, local=False),
            'List': handle_list,
            'Requirements': handle_requirements,
            'Purge': handle_purge,
        }
        handler = handlers.get(command)
        if handler is None:
            raise RuntimeError("Unknown command")
        return handler(args)
    except KeyboardInterrupt:
        writeln(colorize('Stopped', color=1))
        return FAIL_CODE
    except Exception as error:
        if not args.Debug:
            error_name = error.__class__.__name__
            writeln(colorize('Error happened: {} ({})'.format(error_name, str(error)), color=1))
            return FAIL_CODE
        # Debug mode: show the traceback and inspect the failure interactively.
        failure_tb = sys.exc_info()[2]
        traceback.print_exc()
        pdb.post_mortem(failure_tb)

    return FAIL_CODE
