import logging
import multiprocessing
import os
import signal
import threading

from frozendict import frozendict
from six.moves import queue

from crypta.lib.python.worker_utils.inflight_tracker import InflightTracker
from crypta.lib.python.worker_utils.multiprocessing_metrics_registry import MultiprocessingMetricRegistry
from crypta.lib.python.worker_utils.task_metrics import TaskMetrics
from crypta.lib.python.worker_utils.worker_config import WorkerConfig

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class TaskQueue(object):
    """Distribute tasks to worker processes and commit cookies on completion.

    Messages passed to ``schedule`` are registered with an ``InflightTracker``
    together with a callback that pushes the batch's ``cookie`` onto
    ``cookie_queue``; the tasks returned by the tracker are put on a
    ``multiprocessing.Queue`` that is handed to the workers via ``WorkerConfig``.
    A background thread reads tasks from ``done_queue`` and reports them to the
    tracker. The tracker is presumably responsible for invoking the callback
    once every task of a batch has completed -- confirm against
    ``InflightTracker``.

    Use as a context manager: ``__enter__`` starts the workers, the completion
    thread and the metrics registry and returns ``self``; ``__exit__`` stops them.
    """

    # Seconds to wait for each worker process to exit on shutdown before it is
    # sent SIGKILL.
    PROCESS_JOIN_TIMEOUT_SEC = 10

    def __init__(self, cookie_queue, context, worker_count, worker_cls, metrics, common_labels=None):
        """Create the shared queues and worker configuration; no process is started yet.

        :param cookie_queue: queue that receives a batch's cookie (via
            ``put_nowait``) when the batch is committed.
        :param context: opaque object forwarded to ``WorkerConfig``.
        :param worker_count: number of worker processes started by ``__enter__``.
        :param worker_cls: class instantiated inside each worker process as
            ``worker_cls(worker_config)``; its ``run()`` method is then called.
        :param metrics: iterable of label mappings; one ``TaskMetrics`` is built
            per mapping, keyed by ``frozendict(labels)``.
        :param common_labels: labels passed to ``MultiprocessingMetricRegistry``.
        """
        self.cookie_queue = cookie_queue
        self.mp_metric_registry = MultiprocessingMetricRegistry(common_labels)
        # Both queues are handed to the workers through WorkerConfig; workers
        # presumably consume task_queue and report finished tasks on done_queue.
        self.done_queue = multiprocessing.Queue()
        self.task_queue = multiprocessing.Queue()

        self.worker_count = worker_count
        self.worker_cls = worker_cls

        # Guards the tracker, which is used by both schedule() and the
        # completion thread.
        self.lock = threading.Lock()
        self.tracker = InflightTracker()
        # Shared run flag: set to 1 by __enter__ and back to 0 by __exit__.
        self.running = multiprocessing.Value('i', 0)

        task_metrics = {
            frozendict(labels): TaskMetrics(self.mp_metric_registry, labels)
            for labels in metrics
        }

        self.worker_config = WorkerConfig(
            self.done_queue,
            self.task_queue,
            self.mp_metric_registry,
            self.running,
            context,
            task_metrics,
        )

    def schedule(self, msgs, cookie):
        """Send ``msgs`` to the workers and arrange for ``cookie`` to be committed.

        An empty ``msgs`` commits ``cookie`` immediately. Otherwise the messages
        are registered with the tracker together with a callback that commits
        ``cookie``, and every task the tracker returns is put on the task queue.
        """
        if not msgs:
            logger.info("Nothing to process, committing: %s", cookie)
            self._commit(cookie)
            return

        # Registration and enqueueing happen under the same lock that the
        # completion thread takes before calling tracker.complete().
        with self.lock:
            tasks = self.tracker.register(msgs, lambda: self._commit(cookie))
            for task in tasks:
                self.send_to_task_queue(task)

    def _commit(self, cookie):
        """Hand ``cookie`` to ``cookie_queue`` without blocking."""
        self.cookie_queue.put_nowait(cookie)

    def __enter__(self):
        """Start the worker processes, the completion thread and the metrics registry.

        :returns: ``self``, so the queue can be bound with ``with ... as``.
        """
        self.running.value = 1
        self.processes = []

        for _ in range(self.worker_count):
            # NOTE(review): the target is a bound method, so under the 'spawn'
            # start method the whole TaskQueue (including its threading.Lock)
            # would have to be pickled; this appears to rely on 'fork'.
            process = multiprocessing.Process(target=self._run_worker)
            process.start()
            self.processes.append(process)

        self.complete_thread = threading.Thread(target=self._check_complete_loop)
        self.complete_thread.start()

        self.mp_metric_registry.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the workers and the completion thread, then the metrics registry.

        Each worker gets ``PROCESS_JOIN_TIMEOUT_SEC`` seconds to exit after the
        run flag is cleared and is sent SIGKILL otherwise. The metrics registry
        is stopped even if joining fails. Exceptions from the ``with`` body are
        not suppressed.
        """
        self.running.value = 0
        try:
            for process in self.processes:
                logger.info("Joining process %s", process.name)
                process.join(timeout=self.PROCESS_JOIN_TIMEOUT_SEC)
                if process.exitcode is None:
                    logger.warning("Process %s did not exit in time, killing it", process.name)
                    os.kill(process.pid, signal.SIGKILL)
                    process.join()

            logger.info("Joining completion thread")
            self.complete_thread.join()
            logger.info("Done joining")
        finally:
            logger.info("Stop mp_metric_registry")
            self.mp_metric_registry.stop()
            logger.info("Done stopping")

    def _check_complete_loop(self):
        """Report tasks from ``done_queue`` to the tracker until the run flag is cleared.

        ``get`` uses a 1 second timeout so the flag is re-checked regularly.
        Tasks still sitting in ``done_queue`` when the loop exits are not
        reported to the tracker.
        """
        while self.running.value:
            try:
                task = self.done_queue.get(timeout=1)
            except queue.Empty:
                continue
            with self.lock:
                self.tracker.complete(task)

    def send_to_task_queue(self, task):
        """Put ``task`` on the queue shared with the worker processes."""
        self.task_queue.put_nowait(task)

    def _run_worker(self):
        """Entry point of each worker process: build ``worker_cls`` and run it.

        Ordinary exceptions are logged instead of propagated. ``SystemExit`` and
        ``KeyboardInterrupt`` are left to propagate to multiprocessing rather
        than being reported as a worker failure.
        """
        try:
            worker = self.worker_cls(self.worker_config)
            worker.run()
        except Exception:
            logger.exception("worker failed")