import logging
from typing import Dict

import psutil
from multiprocessing import Process, Queue
from queue import Empty as QueueEmpty

from psutil import STATUS_ZOMBIE, STATUS_DEAD

from mail.husky.husky.husky_subproc import main as husky_loop
from mail.husky.stages.worker.interactions.db.huskydb import HuskydbTransfer
from mail.husky.stages.worker.settings.task_dispatcher import TaskDispatcherSettings
from mail.python.theatre.roles import Cron
from ora2pg.app.transfer_app import TransferApp

log = logging.getLogger(__name__)


def init_workers(app, in_q, out_q, cnt) -> Dict[int, psutil.Process]:
    """Start ``cnt`` worker processes and return them keyed by pid.

    Every worker runs ``husky_loop(app, in_q, out_q)`` in its own
    ``multiprocessing.Process``; the returned mapping holds a
    ``psutil.Process`` handle for each started pid.
    """
    spawned: Dict[int, psutil.Process] = {}
    for _ in range(cnt):
        child = Process(target=husky_loop, args=(app, in_q, out_q))
        child.start()
        child_pid = child.pid
        # Wrap each pid right after its process is started, before spawning the next one.
        spawned[child_pid] = psutil.Process(child_pid)
    return spawned


def reanimate_dead_workers(app, workers, transfer_try_resp_q, users_to_transfer_q, target_workers_cnt):
    """Prune dead workers and start new ones until the pool reaches the target size.

    ``users_to_transfer_q`` is handed to new workers as their input queue and
    ``transfer_try_resp_q`` as their output queue. Returns the updated
    ``workers`` mapping; nothing is started when the pool is already at or
    above ``target_workers_cnt``.
    """
    workers = remove_dead_workers(workers)

    shortfall = target_workers_cnt - len(workers)
    if shortfall <= 0:
        return workers

    log.info('Run new workers in place of dead ones')
    replacements = init_workers(app, users_to_transfer_q, transfer_try_resp_q, shortfall)
    workers.update(replacements)
    return workers


def remove_dead_workers(workers):
    """Find dead worker ids and remove them from dictionary

    Works in two passes over ``workers`` (a pid -> ``psutil.Process`` mapping):

    1. Workers that still show up as running but whose status is zombie or
       dead are waited for, so their exit code can be logged, and removed.
    2. Workers that are no longer running at all are logged as dead and
       removed.

    The dictionary is modified in place and also returned.
    """
    # Iterate over a snapshot because entries are deleted inside the loop.
    for worker_id, w in list(workers.items()):
        if w.is_running() and w.status() in (STATUS_ZOMBIE, STATUS_DEAD):
            exit_code = w.wait()
            # psutil documents that wait() may return None when the exit status
            # is not available, so the value is formatted with %s rather than
            # %d; %d would make the log record fail to format in that case.
            log.info('Worker #%d exited with exit_code %s', worker_id, exit_code)
            del workers[worker_id]

    # Collect the ids first: the dict must not change size while it is iterated.
    dead_workers = [
        worker_id for worker_id, w in workers.items()
        if not w.is_running()
    ]
    for worker_id in dead_workers:
        log.error('Sadly, worker #%d is dead', worker_id)
        del workers[worker_id]
    return workers


def shutdown_workers(workers, in_q, wait=False):
    """Send "stop working" message to all workers and join them

    One ``None`` message per worker is put on ``in_q``. With ``wait=True`` the
    workers get up to 3 seconds to exit before the survivors are killed; with
    ``wait=False`` every worker is killed right away. Finally the killed
    processes are waited for (up to 1 second).

    :param workers: mapping of pid -> ``psutil.Process``.
    :param in_q: queue the workers read their tasks from.
    :param wait: give workers a grace period before killing them.
    """
    for _ in workers:
        in_q.put(None)

    if wait:
        _, still_alive = psutil.wait_procs(workers.values(), timeout=3)
    else:
        still_alive = workers.values()
    for p in still_alive:
        try:
            p.kill()
        except psutil.NoSuchProcess:
            # The worker exited on its own before it could be killed. Keep
            # going, otherwise one missing process would leave the remaining
            # workers running.
            log.debug('Worker #%d is already gone, nothing to kill', p.pid)
    psutil.wait_procs(still_alive, timeout=1)


async def process_out_q(huskydb_transfer, out_q):
    """Drain ``out_q`` without blocking, passing each response to huskydb.

    Every response taken from the queue is awaited through
    ``huskydb_transfer.update_user_status``; the coroutine finishes as soon as
    the queue reports that it is empty.
    """
    drained = False
    while not drained:
        try:
            response = out_q.get(block=False)
            await huskydb_transfer.update_user_status(response)
        except QueueEmpty:
            drained = True


def append_in_q(in_q, users):
    """Put each item of ``users`` on ``in_q``, logging it at debug level first."""
    for pending_user in users:
        log.debug('Got %r, putting it in queue', pending_user)
        in_q.put(pending_user)


class TaskDispatcher(Cron):
    """Cron role that feeds a pool of worker processes with tasks from huskydb.

    Each run of :meth:`dispatch` brings the worker pool back to
    ``target_workers`` processes, tops up the task queue with newly acquired
    tasks and passes the workers' responses back to the huskydb adaptor.
    """

    def __init__(self,
                 app: TransferApp,
                 shard_id: int,
                 huskydb_adaptor: HuskydbTransfer,
                 settings: TaskDispatcherSettings) -> None:
        """Store the collaborators and create empty queues and an empty pool.

        :param app: application object handed to every worker process.
        :param shard_id: shard identifier; stored, but not read by this class.
        :param huskydb_adaptor: adaptor used to acquire tasks and report results.
        :param settings: provides ``worker_cnt``, ``max_queue_size``,
            ``chunk_size`` and ``cron.run_every``.
        """
        self.shard_id = shard_id
        self.settings = settings
        # State
        self.app = app
        self.tasks_q = Queue()    # tasks going to the workers
        self.results_q = Queue()  # responses coming back from the workers
        # The pool starts at full size; target_workers never exceeds max_workers.
        self.target_workers = self.max_workers = self.settings.worker_cnt
        self.huskydb = huskydb_adaptor
        self.workers: Dict[int, psutil.Process] = {}

        super().__init__(job=self.dispatch, run_every=settings.cron.run_every)

    async def start(self):
        """Spawn the initial workers, start the cron and return in-progress tasks."""
        self.workers = init_workers(
            app=self.app,
            in_q=self.tasks_q,
            out_q=self.results_q,
            cnt=self.target_workers,
        )
        await super().start()
        await self.huskydb.return_in_progress_tasks()

    async def stop(self, wait=True):
        """Shut the workers down, stop the cron and flush remaining responses.

        :param wait: forwarded both to ``shutdown_workers`` and to the parent
            ``stop``.
        """
        # NOTE(review): workers are shut down before the cron is stopped; if
        # dispatch can still fire while super().stop() is pending it would
        # respawn workers - confirm against Cron semantics.
        shutdown_workers(self.workers, self.tasks_q, wait=wait)
        await super().stop(wait=wait)
        # Responses produced before the workers went away still have to be recorded.
        await process_out_q(self.huskydb, self.results_q)

    def add_worker(self) -> None:
        """Raise the target pool size by one, up to ``max_workers``.

        No process is started here; the next :meth:`dispatch` run spawns
        workers up to the target.
        """
        if self.target_workers >= self.max_workers:
            return
        self.target_workers += 1

    def remove_worker(self) -> None:
        """Lower the target pool size by one and queue one stop message.

        Does nothing when the target is already zero.
        """
        if self.target_workers <= 0:
            return
        self.target_workers -= 1
        # None is the "stop working" message that shutdown_workers also sends.
        self.tasks_q.put(None)

    def update_max_workers(self, max_workers: int) -> None:
        """Set a new upper bound and shrink the target down to it if needed."""
        self.max_workers = max_workers
        # The range is computed once, so remove_worker changing target_workers
        # inside the loop does not affect the number of iterations.
        for _ in range(max(0, self.target_workers - self.max_workers)):
            self.remove_worker()

    async def dispatch(self):
        """Run one cron iteration: heal the pool, refill tasks, flush responses."""
        self.workers = reanimate_dead_workers(
            self.app,
            self.workers,
            self.results_q,
            self.tasks_q,
            self.target_workers,
        )

        # Keep acquiring chunks until the queue is full enough or huskydb has
        # nothing more to hand out.
        while self.tasks_q.qsize() < self.settings.max_queue_size:
            task_chunk = await self.huskydb.acquire_tasks(self.settings.chunk_size)
            if not task_chunk:
                break
            # Join with a separator so that individual transfer ids stay
            # distinguishable in the log line.
            log.debug('Got tasks: %s', ', '.join(str(u.transfer_id) for u in task_chunk))
            append_in_q(self.tasks_q, task_chunk)

        # TODO :: Process out-q in separate place
        await process_out_q(self.huskydb, self.results_q)
