import collections
import os
import pwd
import shlex
import signal
import subprocess
import types
from contextlib import contextmanager
from functools import partial

from ...exceptions import CQueueExecutionFailure
from ...utils import monotime

from ya.skynet.library.tasks.shell import findShell


# Output labels (stdout / stderr).
STDOUT, STDERR = 1, 2


def start_cmd(cmd, extra_env=None, in_rtc_container=False):
    """
    Start ``cmd`` through the current user's default shell.

    :param cmd: command to run; it is interpreted by the shell (``sh -c``).
    :param extra_env: optional mapping merged over a copy of ``os.environ``.
        ``YP_TOKEN`` is removed from that copy before ``extra_env`` is applied.
    :param in_rtc_container: forwarded to ``_user_default_shell`` so that a
        ``false``/``nologin`` login shell can be replaced.
    :return: a ``subprocess.Popen`` whose ``stdout`` and ``stderr`` are pipes.
    """
    sh = _user_default_shell(in_rtc_container=in_rtc_container)
    env = os.environ.copy()
    env.pop('YP_TOKEN', None)
    if sh:
        # Environment values must be strings; do not export an empty/None shell.
        env['SHELL'] = sh
    if extra_env:
        env.update(extra_env)

    args, executable, use_shell = _shell_popen_spec(cmd, sh)
    return subprocess.Popen(
        args,
        bufsize=-1,
        env=env,
        shell=use_shell,
        executable=executable,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        close_fds=True,
        # The interpreter ignores SIGPIPE and the child would inherit that;
        # restore the default disposition. This is redundant with Python 3's
        # default restore_signals=True but needed on Python 2.
        preexec_fn=lambda: signal.signal(signal.SIGPIPE, signal.SIG_DFL),
    )


def _shell_popen_spec(cmd, sh):
    """
    Build the ``(args, executable, shell)`` arguments for ``Popen``.

    When ``sh`` is a path to an existing regular file, it is used as the
    ``executable`` with ``shell=True``. Otherwise ``sh`` is treated as a
    shell command line that may carry arguments, such as the
    ``'/portoshell_utils/busybox sh'`` fallback. Such a value cannot be passed
    as ``executable``, because Popen would try to exec the whole string as a
    single path. In that case the argv ``<sh words> -c <cmd>`` is built
    explicitly and run without ``shell=True``. A list/tuple ``cmd`` is expanded
    after ``-c``, mirroring Popen's own ``shell=True`` handling.

    A falsy ``sh`` leaves the choice of shell to Popen (``/bin/sh``).
    """
    if not sh or os.path.isfile(sh):
        return cmd, sh or None, True
    argv = shlex.split(sh) + ['-c']
    if isinstance(cmd, (list, tuple)):
        return argv + list(cmd), None, False
    return argv + [cmd], None, False


def _user_default_shell(in_rtc_container=False):
    """
    Return the shell to run commands with for the current user.

    The login shell from the passwd database is used when the current uid has
    an entry and that shell is an existing regular file; otherwise
    ``findShell()`` decides.

    With ``in_rtc_container`` set, a shell whose basename is ``false`` or
    ``nologin`` is swapped for the first existing bash/sh binary from a fixed
    list. As a last resort it becomes ``'/portoshell_utils/busybox sh'``. Note
    that this fallback value is the busybox binary followed by ``sh``, not a
    single path.
    """
    try:
        shell = pwd.getpwuid(os.getuid()).pw_shell
    except KeyError:
        shell = findShell()
    else:
        # isfile() is also False for missing paths, so it covers existence.
        if not os.path.isfile(shell):
            shell = findShell()

    blocked = ('false', 'nologin')
    if not (in_rtc_container and shell and os.path.basename(shell) in blocked):
        return shell

    for path in ('/usr/local/bin/bash', '/usr/bin/bash', '/bin/bash', '/bin/sh'):
        if os.path.isfile(path):
            return path

    # TODO should we also mount /portoshell_utils/ as well?
    return '/portoshell_utils/busybox sh'


class PidIteratorWrapper(object):
    """
    Iterator proxy that stops a forked child from continuing the iteration.

    Every callable attribute fetched through the wrapper, including
    ``next``/``__next__``, is wrapped. Once such a call returns or raises, the
    current pid is compared with the pid recorded at construction time. If
    they differ, the wrapped code called ``os.fork()`` and the child is now
    returning through this iterator, so the child is terminated immediately
    with ``os._exit(0)``. Non-callable attributes are returned unchanged.
    """

    def __init__(self, it=None):
        self._pid = os.getpid()
        self._it = it

    def set_object(self, it):
        """
        Set the wrapped object to ``it``.

        If the currently wrapped object itself has ``set_object`` (for example
        a nested wrapper), ``it`` is handed down to it instead of replacing it.
        """
        if self._it is not None and hasattr(self._it, 'set_object'):
            self._it.set_object(it)
        else:
            self._it = it

    def __iter__(self):
        return self

    def __funcWrapper(self, fun, *args, **kwargs):
        try:
            return fun(*args, **kwargs)
        finally:
            if os.getpid() != self._pid:
                # A forked child must not keep running the parent's
                # iteration: exit at once, without running cleanup handlers
                # or flushing stdio buffers (os._exit semantics).
                os._exit(0)

    def __getattr__(self, arg):
        if arg == '_it':
            # ``_it`` is not set yet (e.g. the instance was created without
            # ``__init__``); fail instead of recursing through ``self._it``.
            raise AttributeError(arg)
        res = getattr(self._it, arg)
        # Builtin callable(): collections.Callable no longer exists on 3.10+.
        if not callable(res):
            return res
        return partial(self.__funcWrapper, res)

    def next(self):
        return self.__getattr__('next')()

    def __next__(self):
        return self.__getattr__('__next__')()


class Wrapper(object):
    """
    Proxy around ``obj`` that forwards attribute access and times calls.

    Calling the wrapper returns a generator of ``(value, elapsed)`` pairs.
    A plain return value of ``obj`` yields a single pair. A generator returned
    by ``obj`` yields one pair per item it produces. ``elapsed`` is whatever
    ``Stopwatch.pop()`` reports for that step.
    """

    def __init__(self, obj, on_ready=None):
        self.obj = obj
        self.on_ready = on_ready

    def __str__(self):
        # Use the builtins rather than ``self.obj.__str__()``: the latter is an
        # unbound descriptor call (TypeError) when ``obj`` is a class.
        return str(self.obj)

    def __repr__(self):
        return repr(self.obj)

    def __getattr__(self, name):
        if name == 'obj':
            # ``obj`` has not been assigned yet; fail instead of recursing
            # through ``self.obj`` forever.
            raise AttributeError(name)
        return getattr(self.obj, name)

    def __call__(self, *args, **kwargs):
        """
        Call ``obj`` and yield ``(value, elapsed)`` pairs.

        When ``on_ready`` is truthy, it is passed to ``obj`` as the
        ``on_ready`` keyword argument, overriding any caller-supplied value.
        Nothing runs until the returned generator is iterated.
        """
        if self.on_ready:
            kwargs['on_ready'] = self.on_ready
        watch = Stopwatch()

        with watch.measure():
            r = self.obj(*args, **kwargs)

        if not isinstance(r, types.GeneratorType):
            yield r, watch.pop()
            return

        while True:
            # Keep the try around next() only, so nothing thrown into this
            # generator at the yield is mistaken for exhaustion of ``r``.
            try:
                with watch.measure():
                    v = next(r)
            except StopIteration:
                break
            yield v, watch.pop()

    def check(self):
        """
        Ensure ``obj`` exposes ``__call__`` or ``run``.

        :raises CQueueExecutionFailure: if neither attribute is present/truthy.
        """
        fn = getattr(self.obj, '__call__', None) or getattr(self.obj, 'run', None)

        if not fn:
            raise CQueueExecutionFailure('remote object has no `run` or `__call__` method')


class Stopwatch(object):
    """
    Accumulating timer driven by ``monotime()``.

    Each ``with measure():`` section adds its duration to ``total``.
    ``pop()`` returns the accumulated value and resets it to zero.
    """

    def __init__(self):
        self.total = 0

    @contextmanager
    def measure(self):
        start = monotime()
        try:
            yield
        finally:
            # Accumulate (instead of overwriting) so ``total`` covers every
            # section since the last pop(), and record the time even when the
            # measured block raises.
            self.total += monotime() - start

    def pop(self):
        """Return the accumulated time and reset the accumulator to zero."""
        r = self.total
        self.total = 0
        return r
