import os
import six
import time
import logging
import threading
from weakref import WeakValueDictionary

from ya.skynet.util.pickle import Pickler, loads
from ya.skynet.util.sys.user import getUserName
from .. import eggs
from ..rpc import RPCDispatcher
from ..poll import ResultsQueue, Selectable, Poll
from ..rpc.pipe import Pipe
from ..mocksoul_rpc.client import RPCClient
from ..mocksoul_rpc.errors import RPCError, ProtocolError
from ..utils import FdHolder, genuuid, poll_select, bytesio, monotime


class RpcWrapper(object):
    """Retry wrapper around an RPC client.

    ``call`` retries on ``ProtocolError``: it logs each failure, sleeps
    ``retry_delay`` seconds and reconnects the wrapped client before the next
    attempt. ``RPCError`` is logged and re-raised immediately. ``stop`` and
    ``join`` are delegated to the wrapped client unchanged.

    :param slave: wrapped client; must provide ``call``, ``connect``,
        ``stop`` and ``join``.
    :param logger: logger used to report failed calls.
    :param gevent: when true, wait between retries with ``gevent.sleep``
        instead of ``time.sleep``.
    :param retries: total number of attempts per call (clamped to >= 1).
    :param retry_delay: seconds to wait before reconnecting for a retry.
    """

    def __init__(self, slave, logger, gevent, retries=5, retry_delay=0.5):
        self.slave = slave
        self.log = logger
        self.gevent = gevent
        # At least one attempt, otherwise call() would have nothing to raise.
        self.retries = max(1, int(retries))
        self.retry_delay = retry_delay
        if gevent:
            import gevent as g
            self.sleep = g.sleep
        else:
            self.sleep = time.sleep

    def call(self, *args, **kwargs):
        """Invoke ``slave.call(*args, **kwargs)``, retrying on protocol errors.

        :returns: whatever ``slave.call`` returns (callers use it as a job
            object and call ``wait`` on it).
        :raises RPCError: immediately, when the server reports an error.
        :raises ProtocolError: the last one seen, once every attempt failed.
        """
        last_error = None
        for attempt in range(self.retries):
            if attempt:
                # Give the daemon a moment, then re-establish the connection
                # before retrying.
                self.sleep(self.retry_delay)
                self.slave.connect()
            try:
                # TODO job.wait or job.__iter__ also can fail due to disconnect
                return self.slave.call(*args, **kwargs)
            except ProtocolError as e:
                last_error = e
                self.log.warning("call failed with exception: %s", e)
            except RPCError as e:
                self.log.warning("call failed with server exception: %s", e)
                raise
        # Surface the failure here instead of returning None, which callers
        # would only trip over later as an AttributeError on job.wait().
        raise last_error

    def stop(self):
        """Stop the wrapped client and return its result."""
        return self.slave.stop()

    def join(self):
        """Join the wrapped client and return its result."""
        return self.slave.join()


class DaemonSession(Selectable):
    """Client-side handle for one task submitted through a DaemonClient.

    Result messages come out of :meth:`poll` / :meth:`wait` as
    ``(host, result, error)`` triples; messages with any other label are
    passed to ``self.rpc_dispatcher``. Results reach the session either
    through :meth:`dispatch` (which queues them in ``self.results``) when the
    client polls in the background, or by fetching directly from the client
    inside :meth:`poll` otherwise.
    """

    _event_fd = FdHolder('event_fd')
    _event_notify_fd = FdHolder('event_notify_fd')

    def __init__(self, uuid, hosts, client, is_iter=False):
        super(DaemonSession, self).__init__()
        # Both ends of the wake-up pipe are created lazily by _get_event_fd().
        self._event_fd = None
        self._event_notify_fd = None
        self.uuid = uuid
        self.client = client
        self.hosts = hosts
        self.rpc_dispatcher = RPCDispatcher()
        self.yield_stopiter = is_iter
        # The result queue has to match the client's concurrency model.
        if client.gevent:
            import gevent.queue as queue_impl
        else:
            queue_impl = six.moves.queue
        self.results = queue_impl.Queue()
        self.done = False

    def _get_event_fd(self):
        """Return the read end of the wake-up pipe, creating the pipe on first use."""
        if self._event_fd is None:
            read_end, write_end = os.pipe()
            self._event_fd = read_end
            self._event_notify_fd = write_end
        return self._event_fd

    def _notify(self):
        """Write a single byte into the wake-up pipe if it exists."""
        notify_fd = self._event_notify_fd
        if notify_fd is not None:
            os.write(notify_fd, b'1')

    def _is_data_ready(self):
        """Report whether at least one result is waiting in the queue."""
        return self.results.qsize() > 0

    def get_host_by_id(self, hostid):
        """Map a numeric host id back to the host it was created from."""
        return self.hosts[hostid]

    @property
    def taskid(self):
        """Alias of ``uuid``."""
        return self.uuid

    @property
    def id(self):
        """Alias of ``uuid``."""
        return self.uuid

    @property
    def running(self):
        """False once a terminal result (one carrying an error) was seen."""
        return not self.done

    def remoteObject(self):
        raise NotImplementedError

    def link(self, receiver=None):
        raise NotImplementedError

    @property
    def participantsFoundHandler(self):
        raise NotImplementedError

    @property
    def resultEventsHandler(self):
        raise NotImplementedError

    def dump(self, task):
        """Pickle *task*, binding this session into objects that ask for it.

        Any object reached during pickling whose ``needs_session`` attribute
        is exactly ``True`` gets ``_set_session(self)`` called on it. It is
        still pickled normally, because the hook always returns ``None``.
        """
        def bind_session(obj):
            if getattr(obj, 'needs_session', None) is True:
                obj._set_session(self)
            return None

        buf = bytesio()
        pickler = Pickler(buf)
        pickler.persistent_id = bind_session
        pickler.dump(task)
        return buf.getvalue()

    def rpc_send(self, typ, uuid, msg, addrs, tree_data=None):
        """Ask the daemon to deliver an RPC message; RuntimeError if not acked in 10s."""
        job = self.client.rpc.call(
            'rpc_send',
            taskid=self.uuid,
            typ=typ,
            uuid=uuid,
            msg=msg,
            addrs=addrs,
            tree_data=tree_data,
        )
        if job.wait(10):
            return
        raise RuntimeError("Cannot send rpc")  # FIXME

    def _take_queued_results(self, block, timeout):
        """Drain everything queued by dispatch(); if empty and *block*, wait for one."""
        collected = []
        while not self.results.empty():
            collected.append(self.results.get_nowait())
        if block and not collected:
            try:
                collected.append(self.results.get(block, timeout))
            except six.moves.queue.Empty:
                pass
        return collected

    def poll(self, block=True, timeout=None):
        """Yield ``(host, result, error)`` for each result received in one round.

        Non-result messages are routed to ``rpc_dispatcher``. A result with a
        non-None error marks the session done. A ``StopIteration`` error is
        swallowed unless the session was created in iterator mode.
        """
        if self.done:
            return

        if self.client.background_poll:
            incoming = self._take_queued_results(block, timeout)
        else:
            incoming = self.client._fetch([self.uuid], block=block, timeout=timeout)

        result_label = six.b('result')
        for _, (hostid, addr), label, data in incoming:
            if label != result_label:
                self.rpc_dispatcher.dispatch((hostid, addr), data)
                continue
            res, err = data
            if err is not None:
                self.done = True
                if not self.yield_stopiter and isinstance(err, StopIteration):
                    continue
            yield self.get_host_by_id(hostid), res, err

    def wait(self, block=True, timeout=None):
        """Keep polling until the session is done or *timeout* seconds pass.

        With ``block=False`` only a single non-blocking poll round is made.
        """
        if block:
            remaining = timeout
            while not self.done and (remaining is None or remaining > 0):
                started = monotime()
                for item in self.poll(block=block, timeout=remaining):
                    yield item
                if remaining is not None:
                    remaining -= monotime() - started
        else:
            for item in self.poll(block=False):
                yield item

    def stop(self):
        """Ask the daemon to stop this task, waiting up to 10s for the reply."""
        self.client.rpc.call(
            'stop_task',
            taskid=self.uuid,
        ).wait(10)  # FIXME

    shutdown = stop

    def dispatch(self, addr, label, data):
        """Entry point for the client's background poller.

        Results are queued for poll() and the wake-up pipe is signalled;
        any other label goes straight to ``rpc_dispatcher``.
        """
        if label != six.b('result'):
            self.rpc_dispatcher.dispatch(addr, data)
            return
        self.results.put((None, addr, label, data))
        self._notify()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.stop()

    def __del__(self):
        # Release both pipe ends; the second is released even if the first fails.
        try:
            del self._event_fd
        finally:
            del self._event_notify_fd


class DaemonClient(object):
    """Thin client that submits tasks to a local daemon over RPC.

    ``run`` / ``iter`` start a task and return a ``DaemonSession`` for it.
    With ``background_poll``, a worker fetches incoming messages and routes
    each one to the owning session via ``DaemonSession.dispatch``. The worker
    is a greenlet when ``gevent`` is true and a daemon thread otherwise.
    Without ``background_poll``, each session fetches its own messages through
    :meth:`_fetch`.

    :param path: passed unchanged to the RPC client (presumably the daemon's
        socket path -- confirm against the RPC client).
    :param background_poll: start the background fetch worker.
    :param gevent: use the gevent RPC client, queues, sleep and select.
    :param kwargs: extra options. Only ``netlibus`` (default ``True``) is read
        here; it is forwarded with every ``run_task`` call.
    """

    # Back-off after a failed background fetch, so that an unreachable daemon
    # does not turn the worker into a busy loop.
    _fetch_error_delay = 0.5

    def __init__(self,
                 path,
                 background_poll=True,
                 gevent=False,
                 *args, **kwargs):
        self.gevent = gevent
        self.args = args
        self.kwargs = kwargs
        self.running = True
        self.worker = None
        logger = logging.getLogger('cqudp.thinclient')
        if not logger.handlers:
            logger.addHandler(logging.NullHandler())
        self.log = logger
        if gevent:
            from ..mocksoul_rpc.gevent_client import RPCClientGevent
            rpc = RPCClientGevent(path, None, logger=logger)
        else:
            rpc = RPCClient(path, None, logger=logger)

        self.rpc = RpcWrapper(rpc, logger=logger, gevent=gevent)

        self.sessions = WeakValueDictionary()
        self.background_poll = background_poll
        if background_poll:
            if gevent:
                import gevent as g
                self.worker = g.spawn(self._process_new)
            else:
                self.worker = threading.Thread(target=self._process_new)
                self.worker.daemon = True
                self.worker.start()

    @property
    def signer(self):
        return []

    def createPipe(self):
        """Build a Pipe whose results queue uses a select matching the client mode."""
        if self.gevent:
            import gevent.select
            return Pipe(ResultsQueue(select_function=gevent.select.select))
        else:
            return Pipe(ResultsQueue(select_function=poll_select))

    def createPoll(self, iterable=None):
        """Build a Poll, using gevent's select when running under gevent."""
        if self.gevent:
            import gevent.select
            return Poll(iterable, select_function=gevent.select.select)
        else:
            return Poll(iterable)

    def createQueue(self, *args, **kwargs):
        raise NotImplementedError

    def ping(self, hosts, *args, **kwargs):
        raise NotImplementedError

    def _run(self, hosts, task, iter):
        """Submit *task* to *hosts* and register the resulting session.

        :raises RuntimeError: if the daemon does not acknowledge the task
            with its uuid within 10 seconds.
        """
        description = str(task)
        modules_list = getattr(task, 'marshaledModules', None)
        egg = eggs.create_egg(modules_list)
        user = getattr(task, 'osUser', getUserName())
        uuid = six.b(genuuid())
        # Hosts are addressed by their position in the caller's sequence.
        hosts = dict(enumerate(hosts))
        session = DaemonSession(uuid, hosts, self, is_iter=iter)
        task = session.dump(task)
        socket_path = os.getenv('SSH_AUTH_SOCK')

        started = self.rpc.call(
            'run_task',
            uuid=uuid,
            task=task,
            description=description,
            hosts=hosts,
            egg=egg,
            user=user,
            socket_path=socket_path,
            iter=iter,
            netlibus=self.kwargs.get('netlibus', True),
        ).wait(10)
        if started != uuid:
            raise RuntimeError("Cannot start the job")  # FIXME

        self.sessions[uuid] = session
        return session

    def run(self, hosts, task, params=None):
        # TODO params
        return self._run(hosts, task, False)

    def iter(self, hosts, task, params=None):
        # TODO params
        return self._run(hosts, task, True)

    def iterFull(self, *args, **kwargs):
        raise NotImplementedError

    def run_in_porto(self, *args, **kwargs):
        raise NotImplementedError

    def stats(self, *args, **kwargs):
        raise NotImplementedError

    def reloadServerKeys(self, *args, **kwargs):
        raise NotImplementedError

    def run_shell(self, *args, **kwargs):
        raise NotImplementedError

    def _fetch(self, uuids=None, block=False, timeout=None):
        """Yield ``(uuid, (hostid, host), label, data)`` for incoming messages.

        *uuids* defaults to all registered sessions. ``host`` is normalised
        to a tuple unless it is a string, and ``data`` is unpickled.
        """
        for uuid, pathitem, label, data in self.rpc.call(
            'take_incoming',
            uuids=uuids or list(self.sessions.keys()),
            block=block,
            timeout=timeout,
        ):
            hostid, host = pathitem[:2]
            host = host if isinstance(host, six.string_types) else tuple(host)
            data = loads(data)
            yield uuid, (hostid, host), label, data

    def _process_new(self):
        """Background worker: route incoming messages to sessions until shutdown.

        Errors are logged instead of ending the loop. A dead worker would
        leave every session waiting forever for results.
        """
        while self.running:
            try:
                for uuid, addr, label, data in self._fetch(
                        list(self.sessions.keys()), block=True, timeout=0.5):
                    session = self.sessions.get(uuid)
                    if session is None:
                        continue
                    try:
                        session.dispatch(addr, label, data)
                    except Exception:
                        # One bad message must not drop the rest of the batch.
                        self.log.exception("failed to dispatch message for session %r", uuid)
            except Exception:
                if not self.running:
                    # Expected while shutdown() is tearing the RPC client down.
                    break
                self.log.exception("background fetch failed, retrying")
                self.rpc.sleep(self._fetch_error_delay)

    def shutdown(self):
        """Stop all sessions, the RPC client and the gevent worker.

        Only the first call does any work. A session that fails to stop is
        logged, so the RPC client and the worker are still torn down.
        """
        if not self.running:
            return
        self.running = False
        sessions, self.sessions = list(self.sessions.values()), WeakValueDictionary()
        for session in sessions:
            try:
                session.stop()
            except Exception:
                self.log.exception("failed to stop session %r", session)
        try:
            self.rpc.stop()
        finally:
            if self.gevent and self.worker is not None:
                self.worker.kill()
                self.worker = None

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.shutdown()

    def __del__(self):
        # __init__ may have raised before every attribute was assigned (for
        # example when the RPC client could not be created). 'sessions' is set
        # after 'log' and 'rpc', so its absence means nothing needs stopping.
        if getattr(self, 'sessions', None) is None:
            return
        self.shutdown()
