from __future__ import absolute_import, print_function, division

import cPickle
import logging
import msgpack
import socket
import sys
import threading
import time

try:
    import gevent

    try:
        import gevent.coros as coros
    except ImportError:
        import gevent.lock as coros

    import gevent.event
    import gevent.socket

    from gevent.queue import Queue as GQueue, Empty as GQueueEmpty
except ImportError:
    gevent = None

from Queue import Queue, Empty as QueueEmpty

from . import errors, MAGIC
from .socket import Socket, EOF
from .utils import Timer, ConnectionHandler, sid2str

if gevent is not None:
    # Call._get catches the stdlib QueueEmpty for jobs backed by either queue
    # flavour, so gevent's Empty must be that very same class.
    assert GQueueEmpty is QueueEmpty


class BaseReactor(object):
    """Owns one RPC connection and dispatches incoming server messages to per-job queues.

    Concrete subclasses provide the concurrency primitives through ``_eventClass``,
    ``_lockClass`` and ``_socketMaker``, run ``_loop`` in the background and
    implement ``join``.
    """
    HANDSHAKE_RECEIVE_TIMEOUT = 120
    HANDSHAKE_SEND_TIMEOUT = 120

    CONNECT_TIMEOUT = 120
    SOCKET_NODELAY = True
    SOCKET_RECEIVE_BUFFER = 16 * 1024
    SOCKET_SEND_BUFFER = 16 * 1024
    IDLE_TIMEOUT = 10
    PING_TICK_TIME = 30  # ping every 30 seconds
    PING_WAIT_TIME = 60  # wait for ping response for 60 seconds
    RECEIVE_BUFFER = 16 * 1024
    REACTOR_STOP_TIMEOUT_AFTER_EOF = 10

    UID_BITS = 32

    _eventClass = None
    _lockClass = None
    _socketMaker = None

    class AbortError(Exception):
        """ Special class to request reactor abort. """
        pass

    def __init__(self, log):
        super(BaseReactor, self).__init__()

        self.jid = 0  # last job id handed out by `call`
        self.active = self._eventClass()  # set by `_loop`, cleared by `stop`

        self._sock = None
        self._lock = self._lockClass()
        self._jobs = {}  # job id -> job object whose queue receives server messages

        self.log = log.getChild('reactor')

    def _make_socket(self, inet=True):
        """Create an AF_INET (inet=True) or AF_UNIX socket wrapper with the configured options."""
        sock = self._socketMaker(socket.AF_INET if inet else socket.AF_UNIX)
        if inet:
            sock.nodelay = self.SOCKET_NODELAY
        sock.sendBuffer = self.SOCKET_SEND_BUFFER
        sock.receiveBuffer = self.SOCKET_RECEIVE_BUFFER
        return sock

    def connect(self, host, port):
        """Connect, perform the handshake and remember the session id.

        ``port is None`` selects a unix socket at path ``host``.

        :raises errors.HandshakeTimeout: sending the magic or reading the session id timed out.
        :raises errors.HandshakeError: the server closed the connection during the handshake.
        """
        sock = self._make_socket(port is not None)
        try:
            sock.connect(host, port, timeout=self.CONNECT_TIMEOUT)

            h = ConnectionHandler(MAGIC, self.HANDSHAKE_SEND_TIMEOUT, self.HANDSHAKE_RECEIVE_TIMEOUT)
            # First of all, send our magic for validation
            if sock.write(h.get_magic(), timeout=self.CONNECT_TIMEOUT) is None:
                raise errors.HandshakeTimeout('Timeout while sending magic')

            try:
                # The server should respond with session ID
                sid = h.handle(sock)
            except EOF:
                raise errors.HandshakeError('Server closed the connection on handshake.')

            if sid is None:
                raise errors.HandshakeTimeout('Handshake timed out: we were attemping to read')
        except BaseException:
            # The socket is not stored on self yet, so nobody else would close it.
            sock.close()
            raise

        self.sid = sid2str(sid, self.UID_BITS)
        self._sock = sock

    def stop(self):
        """Close the connection (if any) and mark the reactor inactive."""
        if self._sock:
            self._sock.close()
        self.active.clear()

    def start(self):
        """Block until the receive loop has marked the reactor active."""
        self.active.wait()

    def is_active(self):
        return self.active.isSet()

    def _make_job(self, jid):
        return RPCJob(jid)

    def call(self, name, args, kwargs):
        """Register a new job and send it to the server; ``name=None`` sends a job-bound PING.

        On EOF while writing, the receive loop is given REACTOR_STOP_TIMEOUT_AFTER_EOF
        seconds to finish (``join`` is provided by subclasses) and EOF is re-raised.
        """
        with self._lock:
            self.jid += 1
            job = self._make_job(self.jid)
            try:
                self._jobs[job.id] = job
                try:
                    # TODO: timeout
                    if name is None:  # Ping message actually
                        self._sock.write(msgpack.dumps(('PING', job.id)))
                    else:
                        self._sock.write(msgpack.dumps(('CALL', job.id, (name, args, kwargs))))
                except:
                    self._jobs.pop(job.id)
                    raise
            except EOF:
                self.join(timeout=self.REACTOR_STOP_TIMEOUT_AFTER_EOF)
                raise  # pragma: no cover

            return job

    def feedback(self, jid, value):
        """Send ``value`` as feedback for job ``jid``; an ``AbortError`` aborts the whole reactor."""
        if isinstance(value, self.AbortError):
            self.abort(errors.RPCError(
                "[%s.%d] Job requested reactor abort: %s" % (self.sid, jid, value),
                self.sid
            ))
        else:
            self._sock.write(msgpack.dumps(('FEEDBACK', jid, value)))

    def name(self, name):
        self._sock.write(msgpack.dumps(('NAME', name)))

    def ping(self):
        """Send a job-bound PING and wait up to PING_WAIT_TIME seconds for the PONG.

        :raises errors.ProtocolError: the reply was not a PONG.
        Exceptions that ``abort`` delivered to the job queue are re-raised as-is;
        a reply that never arrives surfaces as the queue's ``Empty`` exception.
        """
        job = self.call(None, None, None)
        try:
            msg = job.queue.get(timeout=self.PING_WAIT_TIME)
            if isinstance(msg, Exception):
                # abort() hands the failure reason over through the job queue
                raise msg
            if msg[0] != 'PONG':
                raise errors.ProtocolError('Unexpected message response %r' % (msg, ), self.sid, job.id)
        finally:
            with self._lock:
                # abort() may already have removed the job
                self._jobs.pop(job.id, None)

    def drop(self, jid):
        """Forget job ``jid`` locally and tell the server (best effort); unknown ids are ignored."""
        with self._lock:
            job = self._jobs.pop(jid, None)
            if not job:
                return

            try:
                # TODO: timeout
                self._sock.write(msgpack.dumps(('DROP', jid)))
            except Exception:
                # Best effort: the connection may already be gone.
                pass

    def abort(self, message):
        """Stop the reactor and put ``message`` into the queue of every outstanding job."""
        self.stop()
        # Snapshot the ids; drop()/ping() may pop the same jobs concurrently,
        # so ids that vanished meanwhile are skipped instead of raising KeyError.
        for jid in list(self._jobs):
            job = self._jobs.pop(jid, None)
            if job is not None:
                job.queue.put(message)

    def _loop(self):
        """Receive loop: route server messages to job queues and keep the connection alive.

        Reads until the current deadline expires, then: aborts if a loop PING is still
        unanswered, aborts on idle timeout when there are no jobs, otherwise sends a
        job-less PING. Any exception aborts the reactor and ends the loop.
        """
        ping_sent = False
        last_message = 0

        self.active.set()

        def _tune_timeout(timeout=None):
            # With an argument, start a new deadline `timeout` seconds from now.
            # Then set the socket timeout to the time left before the deadline and
            # report whether any time is left (on expiry the socket timeout is reset to None).
            if timeout is not None:
                _tune_timeout.deadline = time.time() + timeout

            deadline = _tune_timeout.deadline

            tout = max(0, deadline - time.time())

            if tout > 0:
                self._sock.timeout = tout
                return True
            else:
                self._sock.timeout = None
                return False
        _tune_timeout.deadline = 0

        while True:
            if not ping_sent:
                # If we are not waiting any already sent ping -- make all socket ops
                # to freeze for max ping tick time seconds, so we will be able to send ping again
                _tune_timeout(self.PING_TICK_TIME)
            else:
                _tune_timeout(self.PING_WAIT_TIME)

            try:
                if _tune_timeout():
                    # NOTE(review): presumably readMsgpack stops yielding once the socket
                    # timeout expires -- confirm against Socket.readMsgpack.
                    for message in self._sock.readMsgpack(self.RECEIVE_BUFFER):
                        if message[0] == 'PONG' and (len(message) < 2 or not message[1]):
                            # Reply to the loop's own job-less PING; job-bound PONGs
                            # carry a jid and are routed to their job below.
                            ping_sent = False
                            if _tune_timeout(self.PING_TICK_TIME):
                                continue
                            else:
                                break

                        last_message = time.time()

                        if message[0] == 'ERROR':
                            for job in self._jobs.values():
                                job.queue.put(message)
                        elif message[0] in ('REGISTERED', 'STATE', 'FAILED', 'COMPLETE', 'PONG'):
                            jid = message[1]
                            if jid not in self._jobs:
                                self.log.warning(
                                    "[%s.%d] There's no job registered " +
                                    "(probably it was dropped already on client-side).",
                                    self.sid, jid
                                )
                            else:
                                self._jobs[jid].queue.put(message)
                        else:
                            raise errors.ProtocolError('Unknown message %r' % (message, ), self.sid)

                        if not _tune_timeout():
                            break

                # We get here once the read phase has ended (deadline expired or no more messages).
                if ping_sent:
                    # If we sent ping, but have no messages anymore -- server freezed somehow
                    self.abort(errors.RPCError('Server has gone away.', self.sid))
                    return

                if not self._jobs:
                    if time.time() - last_message > self.IDLE_TIMEOUT:
                        self.abort(errors.RPCError('Reactor stopped because of idle timeout.', self.sid))
                        return

                self._sock.write(msgpack.dumps(('PING', )))
                ping_sent = True

            except Exception as ex:
                self.abort(ex)
                return
            except:
                self.abort(errors.RPCError(
                    'reactor stopped with unknown exception of type %s: %s'
                    % sys.exc_info()[:2],
                    self.sid
                ))
                return


class ThreadedReactor(BaseReactor):
    """Reactor whose receive loop runs in a daemon ``threading.Thread``."""
    _eventClass = threading.Event
    _lockClass = staticmethod(lambda: threading.RLock(False))
    _socketMaker = lambda _, family: Socket(socket.socket(family, socket.SOCK_STREAM), False)  # noqa

    def __init__(self, log):
        super(ThreadedReactor, self).__init__(log=log)
        self._loopThread = None

    def stop(self):
        """Close the connection and wait up to 10s for the loop thread to exit.

        The wait is skipped when there is no live loop thread, or when ``stop`` is
        called from the loop thread itself (a thread cannot join itself).
        """
        super(ThreadedReactor, self).stop()
        worker = self._loopThread
        if not worker or not worker.isAlive():
            return
        if threading.currentThread() == worker:
            return
        worker.join(timeout=10)
        assert not worker.isAlive()

    def start(self):
        """Spawn the loop thread and block until it reports the reactor active."""
        assert self._loopThread is None
        assert self._sock is not None
        worker = threading.Thread(target=self._loop)
        worker.daemon = True
        self._loopThread = worker
        worker.start()

        super(ThreadedReactor, self).start()

    def join(self, timeout=None):
        """Wait for the loop thread; return True if it has finished."""
        worker = self._loopThread
        worker.join(timeout=timeout)
        return not worker.isAlive()

    def is_active(self):
        """True while the reactor is marked active and its loop thread is alive."""
        if not super(ThreadedReactor, self).is_active():
            return False
        return self._loopThread.isAlive()


if gevent:
    class GeventReactor(BaseReactor):
        """Reactor whose receive loop runs in a gevent greenlet."""
        _eventClass = gevent.event.Event
        _lockClass = coros.RLock
        _socketMaker = lambda _, family: Socket(gevent.socket.socket(family, socket.SOCK_STREAM), True)  # noqa

        def __init__(self, log):
            super(GeventReactor, self).__init__(log=log)
            self._loopGrn = None

        def stop(self):
            """Close the connection and wait up to 10s for the loop greenlet to finish.

            The wait is skipped when the loop greenlet has already finished, or when
            ``stop`` runs inside the loop greenlet itself (e.g. via ``abort``).
            """
            super(GeventReactor, self).stop()
            grn = self._loopGrn
            # ``ready()`` is True once the greenlet has finished, so only a still
            # running loop needs joining. Compare with None explicitly: a Greenlet's
            # truthiness reflects its run state, not its existence.
            if grn is not None and not grn.ready() and gevent.getcurrent() is not grn:
                grn.join(timeout=10)
                assert grn.ready()

        def start(self):
            """Spawn the loop greenlet and block until it reports the reactor active."""
            assert self._loopGrn is None
            assert self._sock is not None
            self._loopGrn = gevent.Greenlet(run=self._loop)
            self._loopGrn.start()

            super(GeventReactor, self).start()

        def _make_job(self, jid):
            return RPCJob(jid, queue_class=GQueue)

        def join(self, timeout=None):
            """Wait for the loop greenlet; return True if it has finished."""
            self._loopGrn.join(timeout=timeout)
            return self._loopGrn.ready()

        def is_active(self):
            return super(GeventReactor, self).is_active() and not self._loopGrn.ready()


class RPCJob(object):
    """Client-side record of one RPC job: id, inbound message queue and registration flag."""

    def __init__(self, jid, queue_class=Queue):
        self.registered = False
        self.queue = queue_class()
        self.id = jid


class RPCClientBase(object):
    """RPC client that shares one reactor per ``host:port`` (or ``unix:<path>``) key.

    Concrete subclasses set ``_reactorClass`` and ``_reactorsLock``.
    """
    JOB_REGISTRATION_TIMEOUT = 20

    class Call(object):
        """Handle for one in-flight job: iterate it for STATE values or ``wait`` for the result."""

        def __init__(self, reactor, job, reg_timeout):
            self.reactor, self.job = reactor, job
            self.timeout, self.reg_timeout = None, reg_timeout
            self.lastResult = None  # terminal message seen by `next`, consumed by `wait`

        def __del__(self):
            # Cleanup for forgetful users (those who never called `wait`)
            self.reactor.drop(self.job.id)

        @property
        def sid(self):
            return self.reactor.sid

        @property
        def jid(self):
            return self.job.id

        def __iter__(self):
            return self

        def iter(self, timeout=None):
            """Set the wait timeout and return self for iteration."""
            self.timeout = timeout
            return self

        def next(self):
            """Return the payload of the next STATE message; stop at any other message."""
            data = self._get()
            if data[0] == 'STATE':
                return data[2]
            else:
                self.lastResult = data
                raise StopIteration()

        __next__ = next  # iterator protocol name used by Python 3

        def send(self, feedback):
            self.reactor.feedback(self.job.id, feedback)

        @property
        def generator(self):
            """Generator over STATE values; values sent into it are forwarded as feedback."""
            for value in self:
                fb = yield value
                self.send(fb)

        def wait(self, timeout=None):
            """Skip STATE messages and return the COMPLETE payload; the job is always dropped.

            :raises errors.ProtocolError: a terminal message other than COMPLETE arrived.
            """
            self.timeout = timeout

            try:
                while True:
                    if self.lastResult:
                        data, self.lastResult = self.lastResult, None
                    else:
                        data = self._get()

                    if data[0] == 'STATE':
                        continue
                    elif data[0] == 'COMPLETE':
                        return data[2]
                    else:
                        raise errors.ProtocolError('Bad data %r' % str(data), self.reactor.sid, self.job.id)
            finally:
                self.reactor.drop(self.job.id)

        def _get(self):
            """Fetch the next job message, translating errors and handling REGISTERED.

            :raises errors.Timeout: registration did not arrive within reg_timeout.
            :raises errors.CallTimeout: a registered job produced nothing in time.
            :raises errors.CallError / errors.CallFail / errors.ProtocolError: see branches below.
            """
            try:
                time_left = self.timeout
                if not self.job.registered:
                    time_left = min(self.reg_timeout, time_left) if time_left else self.reg_timeout
                # NOTE(review): a negative timeout is treated as "wait forever";
                # confirm that is intended when the Timer counter has run out.
                if time_left is not None and time_left < 0:
                    time_left = None

                with Timer(self.timeout) as timer:
                    if time_left is None:
                        # queue.get without timeout will break Ctrl-C handling
                        while True:
                            try:
                                data = self.job.queue.get(timeout=60)
                            except QueueEmpty:
                                continue
                            else:
                                break
                    else:
                        data = self.job.queue.get(timeout=time_left)
                    self.timeout = timer.counter
            except QueueEmpty:
                if not self.job.registered:
                    raise errors.Timeout(
                        'Job registration timed out (we wait %d secs)' % self.reg_timeout,
                        self.reactor.sid, self.job.id
                    )
                else:
                    raise errors.CallTimeout('Job completion wait timed out', self.reactor.sid, self.job.id)

            if isinstance(data, Exception):
                # Exceptions are delivered here by the reactor's abort()
                if isinstance(data, EOF):
                    raise errors.ProtocolError('Server closed connection: %s' % (data, ))
                raise data
            elif data[0] == 'ERROR':
                raise errors.CallError('Server-side error: %s' % (data[1], ), self.reactor.sid, self.job.id)
            elif data[0] == 'REGISTERED':
                self.job.registered = True
                if self.job.id != data[1]:
                    # NOTE(review): none of the reactor classes in this module define
                    # `changeJobUid`; this branch raises AttributeError unless a
                    # custom reactor provides it.
                    self.reactor.changeJobUid(self.job, data[1])
                return self._get()
            elif not self.job.registered:
                self.send(BaseReactor.AbortError('Packets sequence broken.'))
                raise errors.ProtocolError(
                    'Expected registration rpc message, got %r instead' % (data[0], ),
                    self.reactor.sid, self.job.id
                )
            elif data[0] == 'FAILED':
                # SECURITY: unpickling server-supplied bytes executes arbitrary code
                # if the server is untrusted or the channel is compromised.
                try:
                    exc = cPickle.loads(data[2])
                except Exception as ex:
                    exc = ex
                raise errors.CallFail('Server-side error', self.reactor.sid, self.job.id, exc)

            return data

    reactors = {}
    _reactorsLock = None
    _reactorClass = None

    def __init__(self, host, port, logger=None):
        if not logger:
            self.log = logging.getLogger('rpc.client')
        else:
            self.log = logger.getChild('rpc.client')

        self.__host = host
        self.__port = port

        if self.__port is not None:
            self._reactorKey = '%s:%d' % (self.__host, self.__port)
        else:
            self._reactorKey = 'unix:%s' % (self.__host, )

    def __make_reactor(self):
        """Create, connect and start a reactor, then register it under this client's key."""
        reactor = self._reactorClass(log=self.log)
        reactor.connect(self.__host, self.__port)
        reactor.start()
        # Register only a reactor that came up, so a failed connect does not leave
        # a dead reactor behind in the registry.
        self.__class__.reactors[self._reactorKey] = reactor
        return reactor

    @property
    def host(self):
        return self.__host

    @property
    def port(self):
        return self.__port

    def connect(self):
        """Return an active reactor for this key, creating one under the class lock if needed."""
        reactor = self.__class__.reactors.get(self._reactorKey)
        if not reactor or not reactor.is_active():
            with self.__class__._reactorsLock:
                # Re-check: another caller may have created it while we waited.
                reactor = self.__class__.reactors.get(self._reactorKey)
                if not reactor or not reactor.is_active():
                    reactor = self.__make_reactor()
        return reactor

    def ping(self):
        self.connect().ping()

    def call(self, name, *args, **kwargs):
        """Start remote call ``name`` and return a ``Call`` handle for it."""
        reactor = self.connect()
        return RPCClientBase.Call(reactor, reactor.call(name, args, kwargs), self.JOB_REGISTRATION_TIMEOUT)

    def name(self, name):
        reactor = self.connect()
        reactor.name(name)

    def stop(self):
        reactor = self.__class__.reactors.get(self._reactorKey)
        if reactor:
            reactor.stop()


class RPCClient(RPCClientBase):
    """RPC client whose reactors run their receive loop in a daemon thread."""
    # Own registry: without it the base-class dict is shared with the gevent
    # client, which could hand this client a gevent reactor (and vice versa)
    # while each side guards the dict with a different lock.
    reactors = {}
    _reactorsLock = threading.Lock()
    _reactorClass = ThreadedReactor


if gevent:
    class RPCClientGevent(RPCClientBase):
        """RPC client whose reactors run their receive loop in a gevent greenlet."""
        # Own registry: without it the base-class dict is shared with the threaded
        # client, which could hand this client a thread-based reactor (and vice
        # versa) while each side guards the dict with a different lock.
        reactors = {}
        _reactorsLock = coros.Semaphore()
        _reactorClass = GeventReactor
