import os
import abc
import uuid
import getpass
import logging
import datetime as dt
import contextlib

import gevent
import subprocess32

from sandbox import common

import sandbox.taskbox.statistics as tb_statistics
import sandbox.taskbox.client.service as tb_service


# Module-level logger (used by ProcmanFactory); each Worker also gets its own
# "worker-<tasks_binary_id>" logger.
logger = logging.getLogger(__name__)


class WorkerProc(object):
    """Abstract handle to a single worker process.

    Concrete implementations wrap a process object (e.g. a procman process or
    a ``subprocess32.Popen``) and expose a uniform control surface to Worker.
    """

    __metaclass__ = abc.ABCMeta

    @abc.abstractproperty
    def pid(self):
        """Process id of the worker."""

    @abc.abstractproperty
    def uuid(self):
        """Identifier of the worker process."""

    @abc.abstractmethod
    def is_dead(self):
        """Return True once the worker process has finished."""

    @abc.abstractmethod
    def terminate(self):
        """Ask the worker process to terminate."""

    @abc.abstractmethod
    def get_stderr(self):
        """Return the captured stderr output of the worker process."""


class Worker(object):
    """Handle to one taskbox worker process plus the RPC client used to talk to it.

    Combines a ``WorkerProc`` (process control) with a ``tb_service.Worker``
    client bound to ``socket_path``. When ``system_collector_interval`` is not
    None, a gevent greenlet collects system statistics for the worker's pid,
    using ``is_dead`` as its stop checker.
    """

    def __init__(self, worker_proc, worker_id, tasks_binary_id, socket_path, system_collector_interval):
        self._proc = worker_proc
        self.worker_id = worker_id
        self.tasks_binary_id = tasks_binary_id
        self.socket_path = socket_path
        self.system_collector_interval = system_collector_interval

        # Decremented by unlock_after_use(); presumably incremented by whoever
        # hands this worker out — TODO confirm against callers.
        self.requests_in_progress = 0

        self.logger = logging.getLogger("worker-{}".format(self.tasks_binary_id))
        self._client = tb_service.Worker(self.socket_path, logger=self.logger)

        # stderr of the dead process, cached by check_if_dead(); None until then.
        self._stderr = None
        self.__setup_statistics_collection()
        self.last_call_time = dt.datetime.utcnow()

    def __setup_statistics_collection(self):
        """Start the system statistics greenlet if collection is enabled."""
        self.system_statistics = self.system_statistics_greenlet = None

        interval = self.system_collector_interval
        if interval is None:
            return

        self.system_statistics = common.os.SystemStatistics(pid=self.pid)
        self.system_statistics_greenlet = gevent.spawn(
            tb_statistics.system_statistics_collector,
            statistics=self.system_statistics,
            interval=interval,
            logger=self.logger,
            stop_checker=self.is_dead,
        )

    def __repr__(self):
        short_uuid = self.uuid
        if isinstance(short_uuid, str):
            short_uuid = short_uuid[:8]  # string uuids are truncated to 8 chars
        return "<Worker #{} pid={} uuid={}>".format(self.tasks_binary_id, self.pid, short_uuid)

    @property
    def pid(self):
        """Process id reported by the underlying WorkerProc."""
        return self._proc.pid

    @property
    def uuid(self):
        """Process identifier reported by the underlying WorkerProc."""
        return self._proc.uuid

    def is_dead(self):
        """Return True if the underlying process has finished."""
        return self._proc.is_dead()

    def terminate(self):
        """Terminate the underlying process."""
        self._proc.terminate()

    def check_if_dead(self):
        """Return whether the worker is dead, logging its stderr the first time it is seen dead."""
        dead = self.is_dead()
        if not dead or self._stderr is not None:
            return dead

        self._stderr = self._proc.get_stderr()
        self.logger.warning("Worker is down:\n%s", self._stderr)
        return dead

    @contextlib.contextmanager
    def unlock_after_use(self):
        """Yield this worker; decrement requests_in_progress on exit, even if the body raises."""
        try:
            yield self
        finally:
            self.requests_in_progress = self.requests_in_progress - 1

    def call(self, request_data, request_id):
        """Forward a request to the worker through its client and return the client's result."""
        # NOTE(review): last_call_time is not refreshed here; presumably the
        # caller updates it — confirm.
        client = self._client
        return client.call(request_data, request_id)


class WorkerFactory(object):
    """Base class for strategies that discover and spawn taskbox workers."""

    __metaclass__ = abc.ABCMeta

    def __init__(self, system_collector_interval):
        # Passed on to every Worker created by the factory.
        self.system_collector_interval = system_collector_interval

    def get_worker_env(self):
        """Build the environment for a worker process.

        Copies the current environment, sets ``Y_PYTHON_ENTRY_POINT`` to the
        worker CLI and rewrites the config env var as an absolute path.

        :raises KeyError: if the config env var is not set in the environment
        """
        config_var = common.config.Registry.CONFIG_ENV_VAR
        worker_env = dict(os.environ)
        worker_env["Y_PYTHON_ENTRY_POINT"] = "sandbox.taskbox.worker.cli:main"
        worker_env[config_var] = os.path.abspath(worker_env[config_var])
        return worker_env

    @abc.abstractmethod
    def find_workers(self):
        """Return already-running workers, keyed by tasks binary id."""

    @abc.abstractmethod
    def spawn_worker(self, worker_id, tasks_bin, tasks_binary_id, socket_path):
        """Start a new worker process and return a Worker for it."""


class ProcmanWorkerProc(WorkerProc):
    """WorkerProc backed by a process object obtained from ``api.procman``."""

    def __init__(self, proc):
        self._proc = proc

    def _stat_field(self, name):
        """Read one field from the procman process stat() dict (None if absent)."""
        stat = self._proc.stat()
        return stat.get(name)

    @property
    def tags(self):
        """Tags attached to the procman process."""
        return self._proc.get_tags()

    @property
    def pid(self):
        return self._stat_field("pid")

    @property
    def uuid(self):
        return self._stat_field("uuid")

    def is_dead(self):
        """Return True when procman reports the process as FINISHED."""
        # api.procman is imported at call time rather than at module load.
        import api.procman
        status = self._proc.status()
        return status == api.procman.FINISHED

    def terminate(self):
        self._proc.terminate()

    def get_stderr(self):
        """Return whatever procman reports as the process stderr."""
        return self._proc.stderr()


class ProcmanFactory(WorkerFactory):
    """Spawns workers as procman processes and rediscovers them by tag.

    Each worker process is tagged with WORKER_PROC_TAG plus ``key=value`` tags
    (keys from Tags) carrying its worker id, tasks binary id and socket path.
    find_workers() reads these tags back to rebuild Worker objects.
    """

    WORKER_PROC_TAG = "taskbox_worker"

    class Tags(common.utils.Enum):
        common.utils.Enum.lower_case()

        WID = None  # Worker ID
        TID = None  # Tasks binary ID
        SOCKET = None  # Socket path passed to the worker

    @classmethod
    def _parse_tags(cls, tags):
        """Turn ``key=value`` process tags into a dict, skipping the marker tag and tags without "="."""
        params = {}
        for tag in tags:
            if tag == cls.WORKER_PROC_TAG:
                continue
            key, sep, value = tag.partition("=")
            # A tag without "=" used to make dict() raise ValueError and abort
            # discovery of every worker; it cannot carry a parameter, so skip it.
            if sep:
                params[key] = value
        return params

    def find_workers(self):
        """Rebuild Worker objects for procman processes tagged WORKER_PROC_TAG.

        :return: dict mapping tasks binary id (int) to a list of Worker objects
        :raises KeyError: if a tagged process lacks one of the Tags parameters
        """
        import api.procman
        workers = {}
        for proc in api.procman.ProcMan().find_by_tags([self.WORKER_PROC_TAG]):
            # Get worker properties from process tags
            params = self._parse_tags(proc.get_tags())
            worker_id = params[self.Tags.WID]
            tasks_binary_id = int(params[self.Tags.TID])
            socket_path = params[self.Tags.SOCKET]

            worker = Worker(
                ProcmanWorkerProc(proc), worker_id, tasks_binary_id, socket_path, self.system_collector_interval
            )
            workers.setdefault(tasks_binary_id, []).append(worker)
            logger.info("Worker found: %r", worker)

        return workers

    def spawn_worker(self, worker_id, tasks_bin, tasks_binary_id, socket_path):
        """Start ``tasks_bin`` as a procman process and return a Worker for it.

        The process runs as the current user in the current working directory,
        with the environment from get_worker_env(). Its identifying parameters
        are stored in process tags so find_workers() can rediscover it.
        """
        import api.procman
        proc = api.procman.ProcMan().create(
            [
                tasks_bin,
                "--id", str(worker_id),
                "--tasks-binary-id", str(tasks_binary_id),
                "--socket", socket_path,
            ],
            keeprunning=False,
            liner=True,
            tags=[
                self.WORKER_PROC_TAG,
                "{}={}".format(self.Tags.WID, worker_id),
                "{}={}".format(self.Tags.TID, tasks_binary_id),
                "{}={}".format(self.Tags.SOCKET, socket_path),
            ],
            env=self.get_worker_env(),
            user=getpass.getuser(),
            cwd=os.getcwd(),
        )

        return Worker(
            ProcmanWorkerProc(proc), worker_id, tasks_binary_id, socket_path, self.system_collector_interval
        )


class SubprocessWorkerProc(WorkerProc):
    """WorkerProc backed by a ``subprocess32.Popen`` object.

    ``proc_id`` is an externally generated identifier that is exposed as ``uuid``.
    """

    def __init__(self, proc, proc_id):
        self._proc = proc
        self._uuid = proc_id

        # Output from communicate(), filled in lazily by get_stderr().
        self._stdout = self._stderr = None

    @property
    def pid(self):
        return self._proc.pid

    @property
    def uuid(self):
        return self._uuid

    def is_dead(self):
        """Return True once the child has exited (poll() returned an exit code)."""
        exit_code = self._proc.poll()
        return exit_code is not None

    def terminate(self):
        self._proc.terminate()

    def get_stderr(self):
        """Return the stderr of the finished child, collecting it on first call.

        :raises Exception: if the child is still running
        """
        if self.is_dead():
            if self._stderr is None:
                self._stdout, self._stderr = self._proc.communicate()
            return self._stderr
        raise Exception("Can't get stderr from an alive process")


class SubprocessFactory(WorkerFactory):
    """Spawns workers as plain child processes via subprocess32.

    find_workers() always returns an empty dict, so previously spawned
    subprocess workers are not rediscovered.
    """

    def find_workers(self):
        """Return no workers (see class docstring)."""
        return {}

    def spawn_worker(self, worker_id, tasks_bin, tasks_binary_id, socket_path):
        """Start ``tasks_bin`` as a child process and return a Worker for it.

        stdout is discarded. stderr is piped so that
        SubprocessWorkerProc.get_stderr() can read it after the child exits.
        A random uuid4 string is used as the process uuid.
        """
        # The devnull handle is only needed while Popen hands it to the child,
        # which receives its own duplicate. Previously the parent's handle was
        # never closed, leaking one file descriptor per spawned worker.
        with open(os.devnull, "wb") as devnull:
            proc = subprocess32.Popen(
                [
                    tasks_bin,
                    "--id", str(worker_id),
                    "--tasks-binary-id", str(tasks_binary_id),
                    "--socket", socket_path,
                ],
                env=self.get_worker_env(),
                # TODO: spawn thread for reading stderr,
                # process can block if it would write a lot
                stderr=subprocess32.PIPE,
                # Just ignore stdout for now
                stdout=devnull,
            )
        proc_id = str(uuid.uuid4())

        return Worker(
            SubprocessWorkerProc(proc, proc_id), worker_id, tasks_binary_id, socket_path, self.system_collector_interval
        )
