from __future__ import absolute_import

import errno
import socket
import os
import six
import sys
import signal
import string
import threading
import tempfile

from ya.skynet.services.cqudp import cfg
from ya.skynet.services.cqudp.utils import (
    Coverage, AtFork, reconfigure_log,
    run_daemon, short, genuuid,
    poll_select, sleep, bytesio
)
from ya.skynet.library.auth.sign import FileKeysSignManager, ChainSignManager
from ya.skynet.library.auth.sshagent import SshAgent
from ya.skynet.services.cqudp.utils import log as root
from ya.skynet.services.cqudp.exceptions import CQueueRuntimeError, CQueueExecutionError, CQueueExecutionFailure
from ya.skynet.services.cqudp.rpc import gRPCDispatcher
from ya.skynet.services.cqudp.rpc.sessionauth import RemoteSessionSignManager
from ya.skynet.services.cqudp.server.exec_functions import _find_function
from ya.skynet.util.pickle import Pickler, dumps, loads
from ya.skynet.util.errors import saveTraceback, setTraceback, getTraceback
from ya.skynet.util.logging import MessageAdapter
from ya.skynet.util.net.socketstream import SocketStream


# Seconds to wait for the child's startup handshake (ProcessHandle.start) and,
# inside the child, for the task description to arrive (ProcessHandle._slave_loop).
SPAWN_TIMEOUT = 20


class ProcessHandle(object):
    """Handle to a task executed in a separate child process.

    The same class is used on both ends of the parent/child socket:

    * parent side: :meth:`fork` spawns the child; :meth:`start`, :meth:`recv`,
      :meth:`join` and :meth:`terminate` talk to and control it;
    * child side: :meth:`fork` (or :meth:`run_with_sock`, when the socket was
      created elsewhere) runs :meth:`_slave_loop`, which executes the task and
      streams its results back via :meth:`send`.

    Every message is a ``(datatype, payload)`` tuple written with
    ``SocketStream.writeBEStr``; the payload is pickled separately by
    :func:`dump_obj`, so receivers have to ``loads`` it themselves.
    """

    def __init__(self, taskid, task, tempdir, loglevel=None, log=None):
        """
        :param taskid: task identifier (exported to the child as ``CQUDP_TASKID``).
        :param task: task description, or ``None`` if the parent will deliver it
            later in an ``'init'`` message.
        :param tempdir: directory installed as ``tempfile.tempdir`` in the child.
        :param loglevel: log level for the child's logger; falsy means ``None``.
        :param log: logger to use; defaults to a per-task ``child`` logger.
        """
        self.taskid = taskid
        self.task = task
        # Set once the task description is available; the slave loop waits on it.
        self.task_ready = threading.Event()
        if task is not None:
            self.task_ready.set()

        self.task_tempdir = tempdir

        self.pid = None
        self.returncode = None
        # waitpid() bookkeeping used by join()/terminate(). Created here, not
        # only in the parent branch of fork(), so those methods also work on a
        # handle whose pid was assigned without calling fork().
        self.pid_lock = threading.Lock()
        self.pid_result = None
        self.loglevel = loglevel or None
        self.log = log or MessageAdapter(
            root().getChild('child'),
            fmt="[%(uuid)s] %(message)s",
            data={'uuid': short(self.taskid)},
        )

    def run_with_sock(self, sock):
        """Child-side entry point for an already connected ``sock``.

        Sends the ``'init'`` handshake and runs the slave loop. Never returns:
        the process exits with status 0 on success and 1 on any exception.
        """
        try:
            self.sock = sock
            self._prepare()
            with Coverage():
                self.stream = SocketStream(self.sock)
                self.send('init', True)
                self._slave_loop()
        except:  # catch everything: the child must never unwind past this point
            self.log.exception("exception in slave: ")
            os._exit(1)
        else:
            os._exit(0)

    def fork(self):
        """Fork a child process connected to this one by a socket pair.

        In the child this never returns: it runs the slave loop and leaves via
        ``os._exit`` (status 0 on success, 1 on error). In the parent it records
        the child's pid and sets up the stream used to talk to it.
        """
        pipes = socket.socketpair()

        with AtFork():
            self.pid = os.fork()

        if not self.pid:
            try:
                # Child: drop the parent's end of the pair. Otherwise this
                # process keeps its own peer open, never sees EOF when the
                # parent goes away, and its writes may eventually block forever.
                pipes[0].close()
                self.sock = pipes[1]
                self._prepare()
                with Coverage():
                    self.stream = SocketStream(self.sock)
                    self._slave_loop()
            except:  # catch everything: the child must never return from fork()
                self.log.exception("exception in slave: ")
                os._exit(1)
            else:
                os._exit(0)
        else:
            # Parent: drop the child's end so the child exiting shows up as EOF
            # on our socket, and the descriptor is not left for the GC to close.
            pipes[1].close()
            self.pid_lock = threading.Lock()
            self.pid_result = None
            self.sock = pipes[0]
            self.stream = SocketStream(self.sock)

            self.log.info('started child process pid %s', self.pid)

    def start(self):
        """Wait for the child's startup handshake.

        :raises CQueueExecutionFailure: if the child sent an error instead of the
            ``('init', True)`` handshake, or did not answer within
            ``SPAWN_TIMEOUT`` seconds. ``Signalled``/``CommunicationError``
            raised by :meth:`recv` propagate unchanged.
        """
        try:
            datatype, handshake = self.recv(SPAWN_TIMEOUT)
            handshake = loads(handshake)
            if datatype != 'init' or handshake is not True:  # exception occurred during initialization
                # The child sent ('result', (None, error)) from its slave loop.
                handshake = handshake[1]
                if isinstance(handshake, CQueueExecutionError):
                    e = handshake.error
                else:
                    e = handshake
                self.log.error("cannot spawn the task: %r", e)
                self.log.error("Traceback: %s", getTraceback(e))

                exc = CQueueExecutionFailure(e)
                setTraceback(exc, getTraceback(e))
                raise exc
        except Timeout:
            self.log.error("spawn timed out")
            raise CQueueExecutionFailure('task process startup timed out')

    def join(self, timeout):
        """Wait for the child to exit.

        :param timeout: seconds to wait; ``None`` waits forever, 0 just polls.
        :return: the exit status, ``-signum`` if the child was killed by a
            signal, or ``None`` if it is still running after ``timeout``.
        """
        if self.returncode is None:
            result = self._eintr_waitpid()
            while self.returncode is None and result[0] != self.pid and (timeout is None or timeout > 0):
                if timeout is None:
                    sleep(0.1)
                else:
                    t = min(0.1, timeout)
                    sleep(t)
                    timeout -= t
                result = self._eintr_waitpid()

            if result[0] != self.pid:
                return None

            if os.WIFSIGNALED(result[1]):
                self.returncode = -os.WTERMSIG(result[1])
            else:
                self.returncode = os.WEXITSTATUS(result[1])

        return self.returncode

    def recv(self, timeout):
        """Read one ``(datatype, payload)`` message from the peer.

        ``payload`` is still the pickled bytes produced by :func:`dump_obj`.

        :raises Signalled: if the child has already died from a signal.
        :raises Timeout: if nothing arrives within ``timeout`` seconds.
        :raises CommunicationError: if reading from the socket fails; the child
            is terminated first.
        """
        self._check_signalled()

        if not self._poll(timeout):
            raise Timeout("process stopped responding")

        try:
            datatype, data = loads(self.stream.readBEStr())
        except socket.error as e:
            self.log.error("connection to the child has been lost: %s", e)
            if e.errno == errno.ECONNRESET:
                # A reset often means the child just died: report a signal
                # death as Signalled rather than as a communication error.
                self.join(0.5)
                self._check_signalled()
            self.terminate()
            raise CommunicationError(e)

        return datatype, data

    def send(self, datatype, data):
        """Send one message; ``data`` is pickled by :func:`dump_obj`
        (which raises ``NonpicklableResult`` if it cannot be loaded back)."""
        self.stream.writeBEStr(dumps((datatype, dump_obj(data))))

    def terminate(self):
        """Send SIGTERM to the child (unless already reaped) and wait briefly.

        Does nothing if no child process was forked. A child that has already
        disappeared (``ESRCH``) is not treated as an error.
        """
        if not self.pid:
            # None: nothing was forked. 0: we are the child after fork(), and
            # os.kill(0, ...) would signal the whole process group.
            return
        try:
            with self.pid_lock:
                if self.pid_result is None:
                    os.kill(self.pid, signal.SIGTERM)
        except EnvironmentError as e:
            if e.errno != errno.ESRCH:
                raise
        else:
            self.join(0.1)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.terminate()

    def _prepare(self):
        """Prepare execution in the child.

        Reconfigures logging, routes gRPC sends through this handle, sets the
        temp dir and exports CQ_*/CQUDP_* environment variables describing the
        task and the socket to the parent.
        """
        reconfigure_log('child', level=self.loglevel)  # TODO provide debug flag if needed
        gRPCDispatcher()._send = self._send_rpc
        tempfile.tempdir = self.task_tempdir
        os.environ['CQ_IMPLEMENTATION'] = 'cqudp'
        os.environ['CQUDP_TASKID'] = self.taskid
        os.environ['CQUDP_LOGLEVEL'] = self.loglevel or ''
        os.environ['CQUDP_TEMPDIR'] = self.task_tempdir
        os.environ['CQUDP_PARENTFD'] = str(self.sock.fileno())
        os.environ['CQUDP_PARENTFAMILY'] = str(self.sock.family)
        os.environ['CQUDP_PARENTTYPE'] = str(self.sock.type)

    def _send_rpc(self, datatype, msg):
        """gRPC transport hook installed by :meth:`_prepare`.

        Only ``datatype[1]`` and ``datatype[2]`` are read here (as rpctype and
        rpcid); presumably ``datatype`` is ``('rpc', rpctype, rpcid)`` — confirm
        against gRPCDispatcher.
        """
        msg = rpc_msg(self.taskid, datatype[1], datatype[2], msg)
        self.send(datatype, msg)

    def _eintr_waitpid(self):
        """Non-blocking ``waitpid`` on the child, retried on EINTR.

        Once the child has been reaped, the result is cached in
        ``pid_result`` and returned on later calls. Returns ``(0, 0)``
        (``os.waitpid`` with ``WNOHANG``) while the child is still running.
        """
        while True:
            try:
                with self.pid_lock:
                    if self.pid_result is not None:
                        return self.pid_result

                    pid_result = os.waitpid(self.pid, os.WNOHANG)
                    if pid_result[0] == self.pid:
                        self.pid_result = pid_result
                    return pid_result
            except EnvironmentError as e:
                if e.errno == errno.EINTR:
                    continue
                raise

    def _check_signalled(self):
        """Raise ``Signalled`` if the child has already been killed by a signal."""
        code = self.join(0)

        if code and code < 0:
            self.log.error("child died with signal %s", -code)
            raise Signalled(-code)

    def _slave_receiver_thread(self):
        """Child-side loop reading messages from the parent.

        ``'rpc'`` messages go to the gRPC dispatcher; ``'init'`` delivers the
        task description and sets ``task_ready``. Returns when the connection
        to the parent is lost.
        """
        while True:
            if not self._poll(1.):
                continue

            try:
                datatype, data = loads(self.stream.readBEStr())
            except socket.error as e:
                if e.errno == errno.EINTR:
                    continue
                # After EOF or a fatal error the socket keeps polling readable,
                # so retrying would spin forever; stop receiving instead.
                self.log.error("connection to the parent has been lost: %s", e)
                return

            if datatype == 'rpc':
                try:
                    initiator, data = loads(data)
                    gRPCDispatcher().dispatch(initiator, data)
                except Exception as e:
                    self.log.exception('exception on dispatching: %s', e)
            elif datatype == 'init':
                try:
                    self.task = loads(data)
                    self.task_ready.set()
                except Exception as e:
                    self.log.exception('exception on settings task: %s', e)
            else:
                self.log.warning('unexpected incoming datatype: %s', datatype)

    def target(self):
        """Set up the task environment and return its result iterator.

        Starts a forwarded SSH agent when ``task['options']['forward_agent']``
        is true. The returned iterator comes from ``_find_function``, which
        receives callbacks that send ``'init'`` and ``'ready'`` notifications.

        :raises RuntimeError: if the SSH agent cannot be created.
        """
        def on_spawn():
            self.send('init', True)

        def on_ready():
            self.send('ready', True)

        try:
            forward_agent = self.task['options'].get('forward_agent', False)
            # TODO: stop ssh_agent
            self.ssh_agent = create_agent() if forward_agent else None
        except Exception as e:
            saveTraceback(e)
            exc = RuntimeError("failed to create SSH agent: {}".format(e))
            setTraceback(exc, getTraceback(e))
            raise exc

        return _find_function(self.task, on_spawn, on_ready)

    def _slave_loop(self):
        """Child-side main loop: run the task and stream its results.

        Each item produced by the task is sent as ``('result', (item, None))``.
        Any exception, including the ``StopIteration`` that ends the iteration,
        is sent as ``('result', (None, CQueueExecutionError(e)))``. A final
        ``'_finish'`` message is always sent.
        """
        try:
            try:
                self.log.info('slave started')
                for envvar in ('PYTHONDONTWRITEBYTECODE', 'PYTHONNOUSERSITE'):
                    if envvar in os.environ:
                        del os.environ[envvar]
                run_daemon(self._slave_receiver_thread)
                if not self.task_ready.wait(timeout=SPAWN_TIMEOUT):
                    raise RuntimeError("failed to get task from taskhandle")
                r = self.target()
                while True:
                    res = next(r)
                    self.send('result', (res, None))
            except Exception as e:
                # StopIteration is the normal end of the task: forward it
                # without logging it as a failure.
                if not isinstance(e, StopIteration):
                    self.log.exception('slave failed: %s', e, exc_info=sys.exc_info())
                    saveTraceback(e)
                e = CQueueExecutionError(e)
                self.send('result', (None, e))
            except KeyboardInterrupt:
                raise
            except BaseException as e:
                self.log.exception("slave failed: %s", e, exc_info=sys.exc_info())
                saveTraceback(e)
                e = CQueueExecutionError(e)
                self.send('result', (None, e))
        except NonpicklableResult as e:
            # Reached when even the error report above could not be pickled.
            saveTraceback(e)
            e.__class__ = CQueueRuntimeError  # FIXME may be we should just expose the type to API
            self.send('result', (None, e))
        finally:
            self.send('_finish', None)
            self.log.info('slave finished')

    def _poll(self, timeout):
        """Return True if the socket becomes readable within ``timeout`` seconds."""
        return bool(poll_select([self.sock.fileno()], [], [], timeout)[0])


# FIXME: don't copy-paste from client.client
def create_signer():
    """Build the sign manager that backs the forwarded SSH agent.

    Chains a file-based sign manager, configured from ``cfg.client.Auth`` and
    loaded eagerly, with a ``RemoteSessionSignManager``.
    """
    auth_cfg = cfg.client.Auth
    file_keys = FileKeysSignManager(
        commonKeyDirs=auth_cfg.CommonKeyDirs,
        userKeyDirs=auth_cfg.UserKeyDirs,
        keyFiles=auth_cfg.KeyFiles,
    )
    file_keys.load()

    return ChainSignManager([file_keys, RemoteSessionSignManager()])


def create_agent():
    """Start an SSH agent backed by :func:`create_signer`.

    Exports the agent's socket path as ``SSH_AUTH_SOCK`` for this process
    and its children.

    :return: the started ``SshAgent``.
    """
    agent = SshAgent(create_signer())
    agent.start()
    os.environ['SSH_AUTH_SOCK'] = agent.socketPath
    return agent


def fix_main_module_name(obj):
    """Pickler ``persistent_id`` hook that renames ``__new__main__`` to ``__main__``.

    If ``obj`` has a ``__class__`` whose ``__module__`` is ``'__new__main__'``,
    that class is patched in place. Objects without ``__class__`` (e.g.
    old-style classes on Python 2) get their own ``__module__`` patched instead.
    Always returns ``None``, so the object itself is pickled normally.
    """
    klass = getattr(obj, '__class__', None)
    target = obj if klass is None else klass
    if target.__module__ == '__new__main__':
        target.__module__ = '__main__'
    return None


def dump_obj(obj, preview_limit=1024):
    """Pickle ``obj`` for the peer and verify that it can be loaded back.

    Pickling uses :func:`fix_main_module_name` as the ``persistent_id`` hook.

    :param obj: object to serialize.
    :param preview_limit: maximum number of printable characters of the pickled
        data quoted in the error message when the check fails (``None`` means
        no limit).
    :return: the pickled bytes.
    :raises NonpicklableResult: if the pickled data cannot be read back with
        :func:`loads`.
    """
    buf = bytesio()
    pickler = Pickler(buf)
    pickler.persistent_id = fix_main_module_name
    pickler.dump(obj)
    result = buf.getvalue()

    # we need to check if we'll be able to unpickle the object at remote side
    try:
        loads(result)
    except Exception as e:
        # Quote only printable characters, as one string capped in length.
        # Previously this was a list of single characters covering the whole
        # payload, which made the message several times larger than the data.
        # six.iterbytes yields ints on both Python 2 and 3.
        preview = ''.join(
            ch for ch in map(chr, six.iterbytes(result)) if ch in string.printable
        )[:preview_limit]
        raise NonpicklableResult("Cannot unpickle remote result: {}({!r}). Data preview: {!r}".format(e.__class__, e, preview))

    return result


def rpc_msg(uuid, rpctype, rpcid, data):
    """Wrap an RPC payload into a message envelope for task ``uuid``.

    Each call gets a freshly generated message id (``genuuid()``) in the
    ``'uuid'`` field; the task id goes into ``'taskid'``.
    """
    return dict(
        uuid=genuuid(),
        taskid=uuid,
        data=data,
        type='rpc',
        rpctype=rpctype,
        rpcid=rpcid,
    )


class Signalled(CQueueRuntimeError):
    """The child process was terminated by signal ``sig`` (stored as ``.signal``)."""

    allow_unpickle = True

    def __init__(self, sig=0):
        message = 'exit by signal ' + str(sig)
        super(Signalled, self).__init__(message)
        self.signal = sig


class CommunicationError(CQueueRuntimeError):
    """Reading from the socket to the child process failed; ``exc`` is the cause."""

    allow_unpickle = True

    def __init__(self, exc):
        text = 'connection to child has been lost: {}'.format(exc)
        super(CommunicationError, self).__init__(text)


class Timeout(CQueueRuntimeError):
    """The child process did not respond within the allowed time."""

    allow_unpickle = True

    def __init__(self, reason='timed out'):
        base_init = super(Timeout, self).__init__
        base_init(reason)


class NonpicklableResult(CQueueRuntimeError):
    """Raised by ``dump_obj`` when pickled data cannot be loaded back."""

    allow_unpickle = True
