from __future__ import absolute_import, print_function, division

import sys
import time
import errno
import socket
import logging
import msgpack
import threading
import Queue as queue

import gevent
import gevent.lock
import gevent.event
import gevent.queue
import gevent.socket

from .. import utils

from ..joint import utils as joint_utils
from ..joint import errors as joint_errors
from ..joint import socket as joint_socket

# `RPCClientBase.Call._get` catches `queue.Empty` even when the job queue is a
# `gevent.queue.Queue` (see `GeventReactor._make_job`); that only works if both
# modules expose the very same exception class.
assert queue.Empty == gevent.queue.Empty


class BaseReactor(object):
    """
    Shared connection machinery for the client reactors.

    A reactor owns one socket to the RPC server and multiplexes many jobs over it, each
    identified by an integer job id (``jid``). `call`, `feedback`, `drop` and `ping` send
    requests; the receive loop (`_loop`) routes server messages into the per-job queues and
    sends keep-alive pings. Subclasses provide the event/lock/socket primitives and decide
    where `_loop` runs.
    """
    _event_class = None  # factory for the event that `_loop` sets once it is running
    _lock_class = None  # factory for the lock held by call/drop/ping/feedback around `_jobs` and writes
    _socket_maker = None  # called as ``self._socket_maker(family)``; returns a joint socket

    class AbortError(Exception):
        """ Special class to request reactor abort. """
        pass

    def __init__(self, cfg):
        super(BaseReactor, self).__init__()

        self.jid = 0  # last job id handed out by `call`
        self.cfg = cfg
        self.age = None  # age returned by the session handshake in `connect`
        self.sid = None  # session id string, filled in by `connect`
        self.idle_timeout = self.cfg.client.idle_timeout

        self._jobs = {}  # jid -> job object for every job still awaiting server messages
        self._sock = None
        self._lock = self._lock_class()
        self._active = self._event_class()

    def _make_socket(self, inet=True):
        """
        Create an unconnected client socket configured from ``cfg.client``.

        :param inet: True for a host/port connection, False for a UNIX-domain socket.
        """
        # NOTE(review): ``AF_INET | AF_INET6`` is a bitwise OR of two address-family numbers,
        # not "either family"; on Linux (2 | 10) it is simply AF_INET6. Kept unchanged because
        # how the joint socket reaches IPv4 hosts is not visible here - confirm.
        family = socket.AF_INET | socket.AF_INET6 if inet else socket.AF_UNIX
        sock = self._socket_maker(family)
        if inet:
            sock.nodelay = self.cfg.client.socket_nodelay
        sock.send_buffer = self.cfg.client.socket_send_buffer
        sock.receive_buffer = self.cfg.client.socket_receive_buffer
        return sock

    def connect(self, host, port, age):
        """
        Open the connection and perform the protocol handshake.

        On success `age`, `sid` and the reactor socket are set. On any failure the freshly
        created socket is closed before the exception propagates.

        :param host: server host name, or the UNIX socket path when `port` is None.
        :param port: TCP port, or None for a UNIX-domain socket.
        :param age: client age handed to the handshake.
        :raises joint_errors.HandshakeTimeout: greetings could not be sent or no session id arrived.
        :raises joint_errors.HandshakeError: the server closed the connection during the handshake.
        """
        sock = self._make_socket(port is not None)
        connected = False
        try:
            sock.connect(host, port, timeout=self.cfg.client.connect_timeout)

            h = joint_utils.ConnectionHandler(self.cfg, age)
            try:
                # Send our version and magic for validation
                if not h.send_greetings(sock):
                    raise joint_errors.HandshakeTimeout('Timeout while sending greetings')
                # The server should respond with session ID
                self.age, sid = h.handle_session(sock)
            except (joint_socket.EOF, socket.error) as ex:
                if isinstance(ex, socket.error) and ex.errno != errno.ECONNRESET:
                    raise
                raise joint_errors.HandshakeError('Server closed the connection on handshake.')

            if sid is None:
                raise joint_errors.HandshakeTimeout('Handshake timed out: we were attemping to read')

            self.sid = joint_utils.sid2str(sid)
            connected = True
        finally:
            if not connected:
                # Don't leak the descriptor of a connection that never became usable.
                # (try/finally keeps the original exception intact on Python 2 as well.)
                sock.close()

        self._sock = sock

    def stop(self):
        """ Close the socket (if any) and mark the reactor inactive. """
        if self._sock:
            self._sock.close()
        self._active.clear()

    def start(self):
        """ Block until `_loop` has signalled that it is running. """
        self._active.wait()

    @property
    def active(self):
        return self._active.is_set()

    @property
    def sock_fd(self):
        """ The method was created only for debug purpose. """
        return self._sock.sock.fileno()

    def _make_job(self, jid):
        return RPCJob(jid)

    def call(self, name, args, kwargs):
        """
        Register a new job and send its request to the server.

        :param name: remote method name, or None to send a PING tagged with the new job id.
        :return: the job object (from `_make_job`) whose queue receives this job's messages.
        """
        with self._lock:
            self.jid += 1
            job = self._jobs[self.jid] = self._make_job(self.jid)
            try:
                # TODO: timeout
                if name is None:  # Ping message actually
                    self._sock.write(msgpack.dumps(("PING", self.jid)))
                else:
                    self._sock.write(msgpack.dumps(("CALL", self.jid, (name, args, kwargs)), use_bin_type=True))
            except joint_socket.EOF:
                # Give the receive loop a chance to notice the closed connection and finish.
                self.join(timeout=self.cfg.client.reactor_stop_timeout_after_eof)
                self._jobs.pop(self.jid, None)
                raise
            except:
                self._jobs.pop(self.jid, None)
                raise

            return job

    def feedback(self, jid, value):
        """
        Send a FEEDBACK value for job `jid`, or abort the whole reactor when `value` is an
        `AbortError`.
        """
        if isinstance(value, self.AbortError):
            self.abort(joint_errors.RPCError(
                'Job requested reactor abort: %s' % (str(value),), self.sid, jid
            ))
        else:
            # Write under the lock like `call` and `drop`, so requests issued concurrently
            # by different callers do not interleave on the socket.
            with self._lock:
                self._sock.write(msgpack.dumps(("FEEDBACK", jid, value), use_bin_type=True))

    def ping(self):
        """
        Send a PING as a job and wait up to ``cfg.client.ping_wait_time`` for its reply.

        :raises joint_errors.ProtocolError: the reply is not a PONG.
        :raises Exception: whatever `abort` delivered if the reactor stopped meanwhile.
        """
        job = self.call(None, None, None)
        try:
            msg = job.queue.get(timeout=self.cfg.client.ping_wait_time)
            if isinstance(msg, Exception):
                # `abort` puts the reason the reactor stopped into every job queue.
                raise msg
            if msg[0] != 'PONG':
                raise joint_errors.ProtocolError('Unexpected message response %r' % (msg,), self.sid, job.id)
        finally:
            with self._lock:
                # `abort` may already have removed the job; don't mask the real error.
                self._jobs.pop(job.id, None)

    def drop(self, jid):
        """
        Forget job `jid` and tell the server (best effort) to DROP it. When `idle_timeout`
        is None and no jobs remain, the reactor is stopped.
        """
        with self._lock:
            job = self._jobs.pop(jid, None)
            if not job:
                return

            try:
                # TODO: timeout
                self._sock.write(msgpack.dumps(('DROP', jid)))
            except Exception:
                # Best effort: the connection may already be gone.
                pass
            if self.idle_timeout is None and not self._jobs:
                self.stop()

    def abort(self, message):
        """ Stop the reactor and deliver `message` to every pending job's queue. """
        self.stop()
        for jid in list(self._jobs):
            # A concurrent `drop` may have removed the job after we took the snapshot.
            job = self._jobs.pop(jid, None)
            if job is not None:
                job.queue.put(message)

    def _loop(self):
        """
        Receive loop: dispatch server messages to job queues and keep the link alive.

        Returns (always after calling `abort`, which fails all pending jobs) when reading or
        writing fails, when a keep-alive ping gets no answer, or when no jobs are pending and
        `idle_timeout` has elapsed since the last message (at once if `idle_timeout` is None).
        """
        self._active.set()
        last_message, last_ping, ping_sent = 0, time.time(), False

        while True:
            if not ping_sent:
                self._sock.timeout = self.cfg.client.ping_tick_time
            else:
                # A keep-alive ping is outstanding: give the server only `ping_wait_time` to answer.
                self._sock.timeout = self.cfg.client.ping_wait_time

            try:
                for message in self._sock.read_msgpack(self.cfg.client.receive_buffer):
                    last_message, ping_sent = time.time(), False
                    if message[0] == 'PONG' and not message[1]:
                        # Presumably the answer to a keep-alive ``('PING',)``, which carries no job id.
                        self._sock.timeout = self.cfg.client.ping_tick_time
                        continue

                    if message[0] == 'ERROR':
                        # An ERROR message is delivered to every pending job.
                        for job in list(self._jobs.values()):
                            job.queue.put(message)
                    elif message[0] in ('REGISTERED', 'STATE', 'FAILED', 'COMPLETE', 'PONG'):
                        jid = message[1]
                        if jid not in self._jobs:
                            logging.getLogger('rpc.client.reactor').warning(
                                "{%s:%d} There's no job registered " +
                                "(probably it was dropped already on client-side).",
                                self.sid, jid
                            )
                        else:
                            self._jobs[jid].queue.put(message)
                    else:
                        raise joint_errors.ProtocolError('Unknown message %r' % (message,), self.sid)

                    if last_message - last_ping > self.cfg.client.ping_tick_time:
                        self._sock.write(msgpack.dumps(('PING',)))
                        last_ping, ping_sent = time.time(), True

                if ping_sent:
                    self.abort(joint_errors.RPCError('Server has gone away.', self.sid))
                    return

                # `idle_timeout` None (e.g. after `RPCClientBase.disconnect`) means "stop once idle".
                if not self._jobs and (self.idle_timeout is None or
                                       time.time() - last_message > self.idle_timeout):
                    self.abort(joint_errors.RPCError('Reactor stopped because of idle timeout.', self.sid))
                    return
                else:
                    self._sock.write(msgpack.dumps(('PING',)))
                    last_ping, ping_sent = time.time(), True

            except Exception as ex:
                self.abort(ex)
                return
            except:
                self.abort(joint_errors.RPCError(
                    'reactor stopped with unknown exception of type %s: %s'
                    % sys.exc_info()[:2],
                    self.sid
                ))
                return


class ThreadedReactor(BaseReactor):
    """ Reactor that runs its receive loop in a daemon OS thread. """

    # On Python 2 `threading.Event` and `threading.RLock` are factory *functions*; stored as
    # plain class attributes they would be bound to the reactor and receive it as their
    # `verbose` argument, so both are wrapped in `staticmethod`.
    _event_class = staticmethod(threading.Event)
    _lock_class = staticmethod(lambda: threading.RLock(False))
    _socket_maker = lambda _, family: joint_socket.Socket(socket.socket(family, socket.SOCK_STREAM), False)

    def __init__(self, *args):
        super(ThreadedReactor, self).__init__(*args)
        self._loop_thread = None

    def stop(self):
        """ Stop the reactor and wait (up to 10s) for the loop thread to exit. """
        super(ThreadedReactor, self).stop()
        loop = self._loop_thread
        # `_loop` calls `stop` itself (through `abort`), and a thread cannot join itself.
        if loop is not None and loop.is_alive() and threading.current_thread() is not loop:
            loop.join(timeout=10)
            assert not loop.is_alive()

    def start(self):
        """ Spawn the loop thread and block until it reports that it is running. """
        assert self._loop_thread is None
        assert self._sock is not None
        self._loop_thread = threading.Thread(target=self._loop)
        self._loop_thread.daemon = True
        self._loop_thread.start()

        super(ThreadedReactor, self).start()

    def join(self, timeout=None):
        """ Wait for the loop thread; return True if it has exited. """
        self._loop_thread.join(timeout=timeout)
        return not self._loop_thread.is_alive()

    @property
    def active(self):
        return super(ThreadedReactor, self).active and self._loop_thread.is_alive()


class GeventReactor(BaseReactor):
    """ Reactor that runs its receive loop in a gevent greenlet. """
    _event_class = gevent.event.Event
    _lock_class = gevent.lock.RLock
    _socket_maker = lambda _, family: joint_socket.Socket(gevent.socket.socket(family, socket.SOCK_STREAM), True)

    def __init__(self, *args):
        super(GeventReactor, self).__init__(*args)
        self._loop_grn = None

    def stop(self):
        """ Stop the reactor and wait (up to 10s) for the loop greenlet to finish. """
        super(GeventReactor, self).stop()
        loop = self._loop_grn
        # Only a still-running loop needs waiting for, and `_loop` calls `stop` itself
        # (through `abort`) - a greenlet must not join itself.
        if loop is not None and not loop.ready() and gevent.getcurrent() is not loop:
            loop.join(timeout=10)
            assert loop.ready()

    def start(self):
        """ Spawn the loop greenlet and block until it reports that it is running. """
        assert self._loop_grn is None
        assert self._sock is not None
        self._loop_grn = gevent.Greenlet(run=self._loop)
        self._loop_grn.start()

        super(GeventReactor, self).start()

    def _make_job(self, jid):
        # Jobs are waited on from greenlets, so they need a gevent-aware queue.
        return RPCJob(jid, queue_class=gevent.queue.Queue)

    def join(self, timeout=None):
        """ Wait for the loop greenlet; return True if it has finished. """
        self._loop_grn.join(timeout=timeout)
        return self._loop_grn.ready()

    @property
    def active(self):
        return super(GeventReactor, self).active and not self._loop_grn.ready()


class RPCJob(object):
    """
    Client-side state of one remote job.

    :param jid: job id assigned by the reactor.
    :param queue_class: factory for the queue that receives this job's server messages.
    """

    def __init__(self, jid, queue_class=queue.Queue):
        self.registered = False  # becomes True once a REGISTERED message has been seen
        self.queue = queue_class()
        self.id = jid


class RPCClientBase(object):
    """
    Client-side entry point for remote calls.

    Reactors (one connection each) are cached in the class attribute `reactors`, keyed by
    ``"host:port"`` or ``"unix:<path>"``; clients with the same key reuse one reactor while
    it is active. Concrete subclasses supply `_reactor_class` and `_reactors_lock`.

    NOTE(review): `reactors` is defined here, so subclasses that don't override it share one
    cache while each uses its own `_reactors_lock` - confirm that is intended.
    """
    AGE = 1  # Client age (or epoch or version) to be reported to server

    class Call(object):
        """
        Handle for one remote call.

        Iterate it (optionally via `iter(timeout)`) to receive intermediate STATE values,
        use `generator` to answer each of them with feedback, or `wait` for the final result.
        """

        def __init__(self, reactor, job, reg_timeout):
            self.reactor, self.job = reactor, job
            # `timeout` is the caller's remaining wait budget (updated after every message);
            # `regTimeout` bounds the wait for the server's REGISTERED acknowledgement.
            self.timeout, self.regTimeout = None, reg_timeout
            self.lastResult = None  # terminal message seen while iterating, consumed by `wait`

        def __del__(self):
            # Cleanup for forgetful users (those who did not call `wait` actually)
            self.reactor.drop(self.job.id)

        @property
        def sid(self):
            return self.reactor.sid

        @property
        def jid(self):
            return self.job.id

        def __iter__(self):
            return self

        def iter(self, timeout=None):
            """ Iterate STATE values with a total wait budget of `timeout` seconds. """
            self.timeout = timeout
            return self

        def next(self):
            """ Return the next STATE value; stop (remembering the message) on anything else. """
            data = self._get()
            if data[0] == 'STATE':
                return data[2]
            else:
                self.lastResult = data
                raise StopIteration()

        __next__ = next  # Python 3 iterator protocol

        def send(self, feedback):
            self.reactor.feedback(self.job.id, feedback)

        @property
        def generator(self):
            """ Generator of STATE values; the value sent back for each one becomes job feedback. """
            for value in self:
                fb = yield value
                self.send(fb)

        def wait(self, timeout=None):
            """
            Wait for the COMPLETE message and return its result; the job is dropped afterwards.

            :raises joint_errors.ProtocolError: on an unexpected terminal message.
            """
            self.timeout = timeout

            try:
                while True:
                    if self.lastResult:
                        data, self.lastResult = self.lastResult, None
                    else:
                        data = self._get()

                    if data[0] == 'STATE':
                        continue
                    elif data[0] == 'COMPLETE':
                        return data[2]
                    else:
                        raise joint_errors.ProtocolError(
                            'Bad data %r' % str(data), self.reactor.sid, self.job.id
                        )
            finally:
                self.reactor.drop(self.job.id)

        def _get(self):
            """
            Fetch the next meaningful message for this job.

            REGISTERED acknowledgements are consumed here; queued exceptions, ERROR and
            FAILED messages are raised.

            :raises joint_errors.CallTimeout: no message arrived within the allowed time.
            """
            try:
                time_left = self.timeout
                if not self.job.registered:
                    time_left = min(self.regTimeout, time_left) if time_left else self.regTimeout
                # NOTE(review): a negative budget becomes "wait forever" here; confirm that is
                # intended if `utils.Timer.left` can go negative once the deadline has passed.
                if time_left is not None and time_left < 0:
                    time_left = None

                with utils.Timer(self.timeout) as timer:
                    data = self.job.queue.get(timeout=time_left)
                    self.timeout = timer.left
            except queue.Empty:
                if not self.job.registered:
                    raise joint_errors.CallTimeout(
                        'Job registration timed out (we wait %d secs)' % self.regTimeout,
                        self.reactor.sid, self.job.id
                    )
                else:
                    raise joint_errors.CallTimeout(
                        'Job completion wait timed out', self.reactor.sid, self.job.id
                    )

            if isinstance(data, Exception):
                raise data
            elif data[0] == 'ERROR':
                raise joint_errors.CallError(
                    'Server-side error: %s' % (data[1], ), self.reactor.sid, self.job.id
                )
            elif data[0] == 'REGISTERED':
                self.job.registered = True
                if self.job.id != data[1]:
                    # NOTE(review): no reactor class in this file defines `changeJobUid`; the
                    # reactor routes REGISTERED only to the job with the matching id, so this
                    # branch is not expected to run.
                    self.reactor.changeJobUid(self.job, data[1])
                return self._get()
            elif data[0] == 'FAILED':
                mod_name = data[2]
                cls_name = data[3]
                try:
                    # Try to reconstruct the original exception class if its visible.
                    cls = getattr(sys.modules[mod_name], cls_name)
                    ex_cls = type(cls_name, (cls, joint_errors.ServerError), dict(
                        __init__=joint_errors.ServerError.__init__
                    ))
                except (KeyError, AttributeError, TypeError):
                    ex_cls = joint_errors.ServerError
                raise ex_cls(*data[4], sid=self.reactor.sid, jid=self.job.id, module=mod_name, cls=cls_name, tb=data[5])
            elif not self.job.registered:
                self.send(BaseReactor.AbortError('Packets sequence broken.'))
                raise joint_errors.ProtocolError(
                    'Expected registration rpc message, got %r instead' % (data[0], ),
                    self.reactor.sid, self.job.id
                )

            return data

    reactors = {}
    _reactors_lock = None
    _reactor_class = None

    def __init__(self, cfg, host, port):
        """
        :param host: server host, or the UNIX socket path when `port` is None.
        :param port: TCP port, or None for a UNIX-domain socket.
        """
        self.log = logging.getLogger('rpc.client')
        self.cfg = cfg
        self.__host = host
        self.__port = port

        if self.__port is not None:
            self._reactor_key = '%s:%d' % (self.__host, self.__port)
        else:
            self._reactor_key = 'unix:%s' % (self.__host,)

    def __make_reactor(self):
        # The reactor is published in the cache before connecting; a failed connect leaves
        # an inactive reactor there, which `connect` replaces on the next attempt.
        reactor = self.__class__.reactors[self._reactor_key] = self._reactor_class(self.cfg)
        self.log.debug('New reactor %r spawned.', self._reactor_key)

        reactor.connect(self.__host, self.__port, self.AGE)
        reactor.start()
        self.log.info('Reactor %r of age %r connected on fd %r', self._reactor_key, reactor.age, reactor.sock_fd)
        return reactor

    @property
    def host(self):
        return self.__host

    @property
    def port(self):
        return self.__port

    def connect(self):
        """ Return the cached active reactor for this address, creating one if needed. """
        reactor = self.__class__.reactors.get(self._reactor_key)
        if not reactor or not reactor.active:
            with self.__class__._reactors_lock:
                # Re-check under the lock: another caller may have connected meanwhile.
                reactor = self.__class__.reactors.get(self._reactor_key)
                if not reactor or not reactor.active:
                    reactor = self.__make_reactor()
        return reactor

    def disconnect(self):
        """ Evict this address's reactor; with `idle_timeout` None it stops once it has no jobs. """
        reactor = self.__class__.reactors.pop(self._reactor_key, None)
        if reactor:
            reactor.idle_timeout = None

    def ping(self):
        self.connect().ping()

    def call(self, name, *args, **kwargs):
        """ Start remote call `name(*args, **kwargs)` and return its `Call` handle. """
        reactor = self.connect()
        return RPCClientBase.Call(reactor, reactor.call(name, args, kwargs), self.cfg.client.job_registration_timeout)

    def stop(self):
        reactor = self.__class__.reactors.get(self._reactor_key)
        if reactor:
            reactor.stop()


class RPCClient(RPCClientBase):
    """ `RPCClientBase` flavour whose reactors run their loop in OS threads. """
    _reactor_class = ThreadedReactor
    _reactors_lock = threading.Lock()


class RPCClientGevent(RPCClientBase):
    """ `RPCClientBase` flavour whose reactors run their loop in gevent greenlets. """
    _reactor_class = GeventReactor
    _reactors_lock = gevent.lock.Semaphore()


class BaseServiceClient(object):
    """ Wrapper for the RPC client. It retries and restores the connection implicitly. """

    # Passed to `utils.progressive_yielder` by `_reactor`; presumably the total number of
    # seconds to keep retrying while the service is not up yet - TODO confirm.
    SERVICE_APPEARANCE_WAIT = 60

    def __init__(self, logger):
        """
        Constructor. Does not establish the actual connection to the service immediately -
        it will be established on demand.

        :param logger: Logger to be used for logging
        """

        self._sid = None  # session id of the reactor used by the most recent call
        self.logger = logger
        self._config = None
        self._srv = None  # RPC client object; assumed to be assigned by subclasses - TODO confirm

    def __getstate__(self):
        """ This method is created to avoid pickling problems on serializing task object on XMLRPC calls. """
        return None

    @property
    def _reactor(self):
        """
        Connect through `_srv`, retrying while the service is not available yet.

        Handshake errors and ENOENT/ECONNREFUSED/EAGAIN are retried while
        `utils.progressive_yielder` keeps yielding; a final attempt is then made whose
        errors propagate to the caller.
        """
        for _ in utils.progressive_yielder(.1, 1, self.SERVICE_APPEARANCE_WAIT, False):
            try:
                return self._srv.connect()
            except joint_errors.HandshakeError:
                pass
            except EnvironmentError as ex:
                if ex.errno not in (errno.ENOENT, errno.ECONNREFUSED, errno.EAGAIN):
                    raise
        return self._srv.connect()

    def __enter__(self):
        # NOTE(review): `__async` is set here and in `__exit__` but never read in this class.
        self.__async = True
        return self

    def __exit__(self, *args):
        self.__async = False

    def __call__(self, method, *args, **kwargs):
        """
        Call remote `method` and wait for its result, reconnecting and retrying whenever the
        server raises `joint_errors.Reconnect`.

        Pass ``__censored=True`` to keep the arguments out of the debug log. With `method`
        None only the connection is established and None is returned.
        """
        reactor = self._reactor
        censored = kwargs.pop("__censored", None)
        try:
            self._sid = reactor.sid
            if method is None:  # Client asked to just establish a connection.
                return
            c = self._srv.call(method, *args, **kwargs)
            if censored:
                self.logger.debug("{%s:%s} Calling remote method %s(...)", c.sid, c.jid, method)
            else:
                self.logger.debug("{%s:%s} Calling remote method %s(%r, %r)", c.sid, c.jid, method, args, kwargs)
            return c.wait()
        except joint_errors.Reconnect:
            # Log the session we are leaving before forgetting it.
            self.logger.info("{%s} Server asked to reconnect it.", self._sid)
            self._sid = None
        self._srv.disconnect()
        if censored is not None:
            # Hand the flag back so the retry keeps the arguments out of the log as well.
            kwargs["__censored"] = censored
        return self(method, *args, **kwargs)

    def ping(self, obj):
        return self("ping", obj)

    def shutdown(self):
        return self("shutdown")
