from __future__ import division, absolute_import

import gevent
import importlib
import os
import traceback
from .connectors import ConductorHTTP
from .installers import DebianInstaller, RedhatInstaller
from .installers.errors import IgnorableError
from .zk import ZK


class Supervisor(object):
    """Fetch tasks from the conductor and run them with the configured installer.

    Once started, two greenlets run: one calls :meth:`process` every
    ``refreshInterval`` seconds, the other watches the parent PID and asks the
    first loop to stop when it becomes 1.
    """

    def __init__(self, ctx):
        """
        :param ctx: agent context; must expose ``cfg`` (config mapping that
            contains at least a ``connector`` section) and ``log``.
        """
        self.ctx = ctx
        self.cfg = self.ctx.cfg

        # ZooKeeper is optional; without it the lock/unlock calls in
        # process() are skipped entirely.
        if 'zookeeper' in self.cfg:
            self.zk = ZK(ctx, self.cfg['zookeeper'])
        else:
            self.zk = None

        self.connector = ConductorHTTP(ctx=ctx)
        installer = self.getInstallerClass()
        self.installer = installer(ctx=ctx)
        # Seconds between conductor polls (also used for the parent-PID check).
        self.refreshInterval = self.cfg['connector'].get('refreshInterval', 20)
        # Known tasks keyed by task id; lets attempt counters survive refreshes.
        self.tasks = {}
        self._runGrn = None
        self._parentPollGrn = None
        self._shouldStop = False

    def getInstallerClass(self):
        """Return the installer class to instantiate.

        ``installer.type`` in the config selects it: ``debian`` and ``redhat``
        map to the built-in installers; any other value is treated as a dotted
        path ``package.module.ClassName`` and imported. Without that setting,
        the Redhat installer is used when ``/usr/bin/yum`` exists and
        ``/usr/bin/apt-get`` does not, otherwise the Debian installer.

        :raises ValueError: if a custom type is not a dotted path.
        :raises ImportError: if the module part cannot be imported.
        :raises AttributeError: if the module lacks the named attribute.
        """
        if 'installer' in self.cfg and 'type' in self.cfg['installer']:
            t = self.cfg['installer']['type']
            if t == 'debian':
                self.ctx.log.info('Using debian installer from config')
                return DebianInstaller
            elif t == 'redhat':
                self.ctx.log.info('Using redhat installer from config')
                return RedhatInstaller

            self.ctx.log.info('Using {} as installer from config'.format(t))
            # 'pkg.mod.Cls' -> ('pkg.mod', '.', 'Cls'); no dot -> ('', '', t).
            moduleName, _, className = t.rpartition('.')
            if not moduleName:
                raise ValueError(
                    'Installer type {!r} must be "debian", "redhat" or a dotted '
                    'path like "package.module.ClassName"'.format(t))
            mod = importlib.import_module(moduleName)
            return getattr(mod, className)

        # fallback: guess from the package managers present on this host
        if os.path.exists('/usr/bin/yum') and not os.path.exists('/usr/bin/apt-get'):
            self.ctx.log.info('Using a yum-based installer')
            return RedhatInstaller
        self.ctx.log.info('Using a default apt installer')
        return DebianInstaller

    def acceptTask(self, task):
        """Register ``task``, carrying over attempt bookkeeping from any
        previously known task with the same id."""
        if task.id in self.tasks:
            oldTask = self.tasks[task.id]
            task.attempts = oldTask.attempts
            task.lastAttempt = oldTask.lastAttempt
        self.tasks[task.id] = task

    def forgetTask(self, task):
        """Drop ``task`` from the known-task registry.

        Unknown ids are ignored. Nothing in :meth:`process` de-duplicates its
        work queue, so the same id may be forgotten more than once in one
        pass; raising ``KeyError`` there would either be misreported as a task
        failure or escape the error handler and end the run loop.
        """
        self.tasks.pop(task.id, None)

    def start(self):
        """Spawn the polling and parent-watch greenlets; returns ``self``.

        May only be called once.
        """
        assert self._runGrn is None
        self._runGrn = gevent.spawn(self._runLoop)
        self._parentPollGrn = gevent.spawn(self._parentPollLoop)
        return self

    def stop(self):
        """Kill both greenlets with ``GreenletExit``; returns ``self``."""
        assert self._runGrn is not None
        assert self._parentPollGrn is not None
        self._runGrn.kill(gevent.GreenletExit)
        self._parentPollGrn.kill(gevent.GreenletExit)
        return self

    def join(self):
        """Block until the polling greenlet finishes."""
        self._runGrn.join()

    def _parentPollLoop(self):
        """Request a stop once the parent PID is 1, then exit.

        A parent PID of 1 usually means the original parent has exited.
        """
        while True:
            if os.getppid() == 1:
                self.ctx.log.warn('Parent PID is 1, so we should exit')
                self._shouldStop = True
                return
            gevent.sleep(self.refreshInterval)

    def _runLoop(self):
        """Call :meth:`process` every ``refreshInterval`` seconds until a stop
        is requested (the flag is only checked between passes)."""
        while not self._shouldStop:
            self.process()
            gevent.sleep(self.refreshInterval)

    def _setStatus(self, task, newStatus, log):
        """Report ``newStatus`` for ``task`` to the conductor.

        If the report fails and the task has no pending status yet, the status
        is remembered on the task so a later :meth:`process` pass re-posts it;
        a successful report clears any remembered status.

        :type task: conductor.agent.task.Task
        :type newStatus: str
        :type log: str
        :rtype: (bool, list[conductor.agent.task.Task])
        :returns: the connector's success flag and the tasks it handed back.
        """
        # NOTE(review): ``result`` is tested for truthiness and then read via
        # ``.result``/``.tasks``, so the connector must return an object (never
        # None) whose truth value reflects success -- confirm in ConductorHTTP.
        result = self.connector.status(task, newStatus, log)
        if not result and not task.hasRememberedStatus():
            task.rememberStatus(newStatus, log)
        if result and task.hasRememberedStatus():
            task.clearRememberedStatus()
        return result.result, result.tasks

    def process(self):
        """Run one polling pass over the tasks assigned to this agent.

        Every task returned by the conductor is registered, then the queue is
        worked through in order:

        * a task with an undelivered (remembered) status only gets that status
          re-posted, and is forgotten if it was ``done``/``failed``;
        * a task with no work, outside its schedule, or attempted less than
          ``cfg['attemptInterval']`` seconds ago is skipped;
        * otherwise the installer runs it and ``done`` is reported; on an
          exception ``error`` is reported, or ``failed`` once
          ``cfg['attemptCount']`` attempts are used (``IgnorableError`` does
          not consume an attempt).

        Tasks handed back with a ``done`` report or a successful re-post are
        registered and processed ahead of the rest of the queue.
        """
        try:
            tasks = self.connector.refresh()
            self.ctx.log.info("Got %d tasks for me", len(tasks))
        except Exception:
            # Deliberately not a bare ``except:``: that would also trap
            # gevent.GreenletExit (a BaseException), so a stop() landing while
            # refresh() is blocked on I/O would be swallowed and this greenlet
            # would keep running.
            msg = "Error getting tasks: %s" % traceback.format_exc()
            self.ctx.log.error(msg)
            return

        for task in tasks:
            self.acceptTask(task)

        while tasks:
            task = tasks.pop(0)

            if task.hasRememberedStatus():
                # Read the pending status before posting: a successful post
                # calls clearRememberedStatus(), which may reset these fields.
                status = task.newStatus
                statusLog = task.newStatusLog
                self.ctx.log.debug("Task %s has remembered status %s", task.task, status)
                ok, new_tasks = self._setStatus(task,
                                                status,
                                                "WARN: This is status repost\n" + statusLog)
                if ok:
                    if status in ('done', 'failed'):
                        self.forgetTask(task)
                    for t in new_tasks:
                        self.acceptTask(t)
                    tasks = new_tasks + tasks
                    self.ctx.log.debug("Successfully reposted status for task %s, got %d new tasks",
                                       task.task,
                                       len(new_tasks))
                # Only the report is retried here; the task is not re-run.
                continue
            if not task.has_work():
                self.ctx.log.debug("Skipped task %s. Nothing to do for status %s", task.task, task.status)
                continue
            if not task.can_run_now():
                self.ctx.log.debug("Skipping task %s because of the schedule", task.task)
                continue
            if task.secondsSinceLastAttempt() < self.cfg['attemptInterval']:
                self.ctx.log.debug("Skipped task %s. It's time hasn't come yet (last of %d attempts was %ds ago)",
                                   task.task, task.attempts, task.secondsSinceLastAttempt())
                continue

            if self.zk and task.max_parallel:
                # NOTE(review): presumably waits until fewer than max_parallel
                # holders exist for this task name -- confirm in ZK.lock.
                self.zk.lock(task.task, task.max_parallel)

            logs = []
            new_tasks = []
            try:
                self.installer.doTask(task, logs)
                ok, new_tasks = self._setStatus(task, 'done', "\n".join(logs))
                if ok:
                    self.forgetTask(task)
                if new_tasks:
                    for new_task in new_tasks:
                        self.acceptTask(new_task)
                    # Follow-ups run before the remainder of this pass's queue.
                    tasks = new_tasks + tasks
                    self.ctx.log.info("Adding %d tasks after status update", len(new_tasks))
            except Exception as e:
                msg = "Error running task %s: %s" % (task.task, traceback.format_exc())
                logs.append(msg)
                self.ctx.log.error(msg)
                # IgnorableError is reported but does not count as an attempt.
                if not isinstance(e, IgnorableError):
                    task.incrementAttempts()
                if task.attempts >= self.cfg['attemptCount']:
                    msg = "Task %s could not be done in %d attempts. We'll give up." % (task.task, task.attempts)
                    logs.append(msg)
                    self.ctx.log.error(msg)
                    ok, _ = self._setStatus(task, 'failed', "\n".join(logs))
                    if ok:
                        self.forgetTask(task)
                else:
                    self._setStatus(task, 'error', "\n".join(logs))

            if self.zk:
                # Keep the lock when the conductor handed back a follow-up for
                # the same task name (queued next in this pass); else release.
                # NOTE(review): this runs even when no lock was taken above
                # (max_parallel unset), presumably so a follow-up releases the
                # lock its predecessor kept -- confirm ZK.unlock tolerates an
                # un-held lock. A follow-up that is then skipped by the checks
                # above never reaches this unlock.
                if not any(t.task == task.task for t in new_tasks):
                    self.zk.unlock(task.task)
