import os
import sys
import struct
import signal
import logging
import resource
import subprocess
import errno
import platform

import gevent.socket as socket
import msgpack

from ..kernel_util.sys.user import UserPrivileges
from . import fdtools
from .exceptions import ProcStartException


def build_limits(limits):
    """Normalize a mapping of resource limits into ``(soft, hard)`` pairs.

    Args:
        limits: mapping of limit name to limit value. The name is the
            case-insensitive suffix of a ``resource.RLIMIT_*`` constant
            (e.g. ``"nofile"`` for ``resource.RLIMIT_NOFILE``). The value is
            either a single number used for both the soft and the hard limit,
            or a two-element tuple/list ``(soft, hard)``.

    Special values:
        * ``None`` or the interpreter's maximum int means
          ``resource.RLIM_INFINITY``.
        * ``-1`` as the hard limit means "use the current hard limit of this
          process" (queried with ``resource.getrlimit``).
        * ``-1`` as the soft limit means "same as the resolved hard limit".

    Returns:
        dict mapping the original names to resolved ``(soft, hard)`` tuples.

    Raises:
        ValueError: if a name does not match any ``resource.RLIMIT_*``
            constant, or a tuple/list value does not have exactly two items.
    """
    # sys.maxint exists only on Python 2; sys.maxsize is the closest
    # equivalent elsewhere, so the function works on both major versions.
    max_int = getattr(sys, "maxint", sys.maxsize)

    valid = {}
    for key, lim in limits.items():
        try:
            new_key = getattr(resource, "RLIMIT_" + key.upper())
        except AttributeError:
            raise ValueError("Invalid limit name: {}".format(key))

        if isinstance(lim, (tuple, list)):
            if len(lim) != 2:
                raise ValueError(
                    "Limit {!r} must be a single value or a (soft, hard) pair".format(key)
                )
            soft, hard = lim
        else:
            soft = hard = lim

        if hard == max_int or hard is None:
            hard = resource.RLIM_INFINITY
        elif hard == -1:
            # keep whatever hard limit the process currently has
            hard = resource.getrlimit(new_key)[1]

        if soft == max_int or soft is None:
            soft = resource.RLIM_INFINITY
        elif soft == -1:
            soft = hard

        # Keyed by the caller's name, not the numeric RLIMIT_* constant.
        valid[key] = (soft, hard)

    return valid


def setlimits(username, limits):
    """Apply resource limits to the current process.

    ``limits`` is first normalized by ``build_limits`` and every resulting
    ``(soft, hard)`` pair is then installed with ``resource.setrlimit``.

    Args:
        username: accepted for interface compatibility; not used here.
        limits: mapping of limit name to value, as taken by ``build_limits``.
    """
    for name, pair in build_limits(limits).items():
        # build_limits keys its result by name, so resolve the constant again
        rlimit_id = getattr(resource, "RLIMIT_" + name.upper())
        resource.setrlimit(rlimit_id, pair)


def set_cgroup(pid, cgroups):
    """Add ``pid`` to every given cgroup by writing it to the ``tasks`` file.

    Args:
        pid: process id to write into each ``<cgroup>/tasks`` file.
        cgroups: iterable of cgroup directory paths; ``None`` or an empty
            iterable makes this a no-op.

    Cgroup directories that have no regular ``tasks`` file are skipped
    silently.

    Raises:
        Exception: if writing to an existing ``tasks`` file fails. The only
            failure that is ignored is ``EPERM`` on a kernel whose release
            string contains ``'Microsoft'`` (WSL).
    """
    if not cgroups:
        # there is nothing to do
        return

    for cgroup in cgroups:
        tasks_path = os.path.join(cgroup, 'tasks')

        # isfile() is already False for a missing path, no separate
        # exists() check is needed.
        if not os.path.isfile(tasks_path):
            continue

        try:
            with open(tasks_path, 'wb') as fd:
                # The file is opened in binary mode, so write bytes; a text
                # string here raises TypeError on Python 3.
                fd.write('{}\n'.format(pid).encode('ascii'))
        except Exception as ex:
            if isinstance(ex, IOError) and ex.errno == errno.EPERM and 'Microsoft' in platform.release():
                # skip it for WSL
                continue
            raise Exception("unable to set cgroup %s, %s" % (tasks_path, ex))


def remove_logging_streams():
    """Detach the root logger from its output streams.

    Handlers whose class is named exactly ``StreamHandler`` are removed from
    the root logger. Any other handler that has a ``stream`` attribute stays
    attached but gets that attribute reset to ``None``. Handlers without a
    ``stream`` attribute are left untouched.
    """
    logger = logging.getLogger()
    # Iterate over a snapshot: removeHandler() mutates logger.handlers, and
    # removing from the live list while iterating it skips the handler that
    # directly follows each removed one.
    for handler in list(logger.handlers):
        if not hasattr(handler, 'stream'):
            continue
        if handler.__class__.__name__ == 'StreamHandler':
            logger.removeHandler(handler)
        else:
            handler.stream = None


def runproc(
    log, args, env, cwd, username,
    limits=None, cgroups=None,
    catch_outs=False,
):
    """Start ``args`` as a new process with the requested environment setup.

    Builds a setup hook and hands it to ``procopen`` together with ``args``,
    ``env`` and ``catch_outs``; the return value of ``procopen`` is returned
    unchanged.

    Args:
        log: not used by this function.
        args: argument list forwarded to ``procopen``.
        env: environment forwarded to ``procopen``.
        cwd: directory the setup hook changes into.
        username: user passed to ``setlimits`` and, when not ``None``, to
            ``UserPrivileges``.
        limits: mapping of resource limits for ``setlimits``; ``None`` is
            treated as an empty mapping.
        cgroups: cgroup paths passed to ``set_cgroup``.
        catch_outs: forwarded to ``procopen``.
    """
    effective_limits = {} if limits is None else limits

    def prefunc():
        # Order matters: logging streams, working directory, cgroups and
        # rlimits are all handled before privileges are switched.
        remove_logging_streams()

        os.chdir(cwd)

        set_cgroup(os.getpid(), cgroups)

        setlimits(username, effective_limits)

        if username is None:
            return

        # The context manager is entered and never exited here.
        UserPrivileges(username, store=False, limit=False, modifyGreenlet=False).__enter__()

    return procopen(prefunc, args, env, catch_outs)


def procopen(prefunc, args, env, catch_outs):
    """Fork, run ``prefunc`` in the child, exec ``args`` and confirm startup.

    The child reports the outcome of ``prefunc`` over a status pipe as a
    length-prefixed msgpack message: a 4-byte network-order unsigned length
    followed by either ``(True,)`` or ``(False, <traceback text>)``. The
    parent waits for that message before returning.

    Args:
        prefunc: zero-argument callable executed in the child before exec.
        args: argument list handed to ``child``.
        env: environment handed to ``child``.
        catch_outs: when true, two extra pipes are created for the child's
            stdout/stderr and their read ends are returned.

    Returns:
        The child's pid, or ``(pid, out_rpipe, err_rpipe)`` when
        ``catch_outs`` is true.

    Raises:
        ProcStartException: on any ``Exception`` while waiting for or decoding
            the status message, including a failure reported by the child.
            A SIGKILL to process group ``pid`` is attempted first. Errors from
            ``fdtools.pipe()`` or ``os.fork()`` happen outside the ``try`` and
            propagate unchanged.
    """
    # Status pipe: child writes the prefunc result, parent reads it.
    rpipe, wpipe = fdtools.pipe()

    if catch_outs:
        out_rpipe, out_wpipe = fdtools.pipe()
        err_rpipe, err_wpipe = fdtools.pipe()
    else:
        out_rpipe, out_wpipe, err_rpipe, err_wpipe = None, None, None, None

    # wpipe/prefunc are bound as default arguments, so this function keeps its
    # own reference to them even after the local names are deleted below.
    def _safe_prefunc(wpipe=wpipe, prefunc=prefunc):
        """Run prefunc and write its outcome to the status pipe."""
        try:
            prefunc()
        except Exception:
            import traceback
            # Failure: send (False, traceback) and terminate the child
            # immediately with exit status 1.
            msg = msgpack.dumps((False, traceback.format_exc()))
            os.write(wpipe, struct.pack('!I', len(msg)))
            os.write(wpipe, msg)
            os._exit(1)
        else:
            # Success: send (True,) and return so the caller can go on to exec.
            msg = msgpack.dumps((True, ))
            os.write(wpipe, struct.pack('!I', len(msg)))
            os.write(wpipe, msg)

    pid = os.fork()

    if not pid:
        # Child process. Schedule SIGALRM in 300 seconds; child() cancels it
        # with alarm(0) right before exec (presumably a watchdog for a stuck
        # setup phase).
        signal.alarm(300)
        # Drop the child's references to the read ends.
        # NOTE(review): this only closes the descriptors if the objects
        # returned by fdtools.pipe() close themselves when released - confirm.
        del rpipe
        del out_rpipe
        del err_rpipe
        try:
            child(_safe_prefunc, args, env, out_wpipe, err_wpipe)
        finally:
            # Only reached if child() returns or raises, i.e. exec did not
            # replace the process image.
            os._exit(-1)

    # Parent process: drop the local references to the write ends.
    # NOTE(review): _safe_prefunc still references wpipe through its default
    # argument until this function returns - verify the parent's write end is
    # really closed at this point.
    del wpipe
    del out_wpipe
    del err_wpipe

    try:
        # Wait (with a 120.0 timeout argument) for the status pipe to become
        # readable, then read the 4-byte length prefix.
        socket.wait_read(rpipe, 120.0)
        msglen_raw = os.read(rpipe, 4)

        if not msglen_raw or len(msglen_raw) != 4:
            raise Exception("failed to read message length")

        msglen = struct.unpack('!I', msglen_raw)[0]

        # Read until the whole payload has arrived or the pipe reports EOF.
        buf = ''
        while len(buf) < msglen:
            socket.wait_read(rpipe, 120.0)
            data = os.read(rpipe, msglen - len(buf))
            if not data:
                break
            buf += data

        if len(buf) != msglen:
            raise Exception("died before we received message")

        msg = msgpack.loads(buf)
        if not msg[0]:
            # Child reported a prefunc failure; msg[1] is its traceback text.
            raise Exception(msg[1])

        return pid if not catch_outs else (pid, out_rpipe, err_rpipe)
    except Exception as e:
        # Any failure above (including the ones raised deliberately) ends up
        # here: try to kill the child and report a start failure.
        # NOTE(review): killpg() targets the process *group* with id ``pid``;
        # nothing in this function makes the child a group leader - confirm
        # that happens elsewhere, otherwise the error is silently ignored.
        try:
            os.killpg(pid, signal.SIGKILL)
        except EnvironmentError:
            pass

        raise ProcStartException(pid, "Unable to communicate with forked child: %s" % (e,))


def child(prefunc, args, env, out_wpipe, err_wpipe):
    """Prepare standard descriptors and replace the process with ``args``.

    Steps, in order: open ``os.devnull`` for the descriptors that need it,
    call ``prefunc``, wire descriptors 0/1/2, close every descriptor from 3
    up to ``subprocess.MAXFD``, reset SIGHUP to its default action, cancel
    any pending alarm, and finally ``os.execve(args[0], args, env)``.

    Args:
        prefunc: zero-argument callable invoked before the descriptors are
            rewired.
        args: argument list; ``args[0]`` is the path that gets executed.
        env: environment mapping for the new program.
        out_wpipe: descriptor to use as stdout, or ``None`` for the null
            device.
        err_wpipe: descriptor to use as stderr, or ``None`` for the null
            device.
    """
    null_in = os.open(os.devnull, os.O_RDONLY)
    # A writable null device is only needed when at least one output pipe
    # was not supplied.
    if None in (out_wpipe, err_wpipe):
        null_out = os.open(os.devnull, os.O_WRONLY)
    else:
        null_out = None

    prefunc()

    stdout_source = null_out if out_wpipe is None else out_wpipe
    stderr_source = null_out if err_wpipe is None else err_wpipe

    # dup our pipe descriptors to stdin/stdout/stderr
    for source_fd, std_fd in ((null_in, 0), (stdout_source, 1), (stderr_source, 2)):
        os.dup2(source_fd, std_fd)

    # everything above the three standard descriptors is closed before exec
    os.closerange(3, subprocess.MAXFD)

    # restore signal handlers
    signal.signal(signal.SIGHUP, signal.SIG_DFL)
    signal.alarm(0)

    os.execve(args[0], args, env)
