from __future__ import absolute_import, print_function, division

import gevent
try:
    import gevent.coros as coros
except ImportError:
    import gevent.lock as coros
import gevent.event
import gevent.socket
import logging
import msgpack
import random
import socket
import struct
import threading
import time

import six

from gevent.queue import Queue as GQueue, Empty as GQueueEmpty
from six.moves.queue import Queue, Empty as QueueEmpty


from kernel.util import daemonthr

from ..daemon.config import loadConfig
from ..utils.socket import Socket, EOF


# ``RPCClientBase.call`` catches ``QueueEmpty`` when reading job queues, and
# gevent reactors hand out gevent queues (``GQueue``); that only works if both
# libraries raise the very same Empty class. (Skipped under ``python -O``.)
assert QueueEmpty == GQueueEmpty


class RPCError(Exception):
    """Root of every exception type defined by this RPC client module."""


class RPCConnectionError(RPCError, EnvironmentError):
    """Failure of the transport connection to the RPC server.

    Also derives from ``EnvironmentError``, so ``except EnvironmentError``
    handlers catch it too.
    """


class RPCSocketError(RPCError):
    """Connection failure that collects one line per failed attempt.

    Lines are appended with :meth:`add`; ``str()`` returns the accumulated
    text.
    """

    def __init__(self):
        super(RPCSocketError, self).__init__()
        # Own accumulator: ``BaseException.message`` does not exist on
        # Python 3 (and is deprecated on Python 2), so ``+=`` on the inherited
        # attribute raised AttributeError there.
        self.message = ''
        self.family_names = {
            socket.AF_INET6: 'inet6',
            socket.AF_INET: 'inet4',
            socket.AF_UNIX: 'unix',
        }

    def add(self, sock_family, err):
        """Record *err* as the failure for address family *sock_family*."""
        # Families without a friendly name are shown by their raw value.
        name = self.family_names.get(sock_family, sock_family)
        self.message += 'family(%s): %r\n' % (name, err)

    def __str__(self):
        return str(self.message)


class ConnectionTimeout(RPCConnectionError):
    """A connection-level operation did not finish in time."""


class HandshakeError(RPCConnectionError):
    """The handshake with the RPC server failed."""


class HandshakeTimeout(HandshakeError, ConnectionTimeout):
    """The handshake with the RPC server did not finish in time."""


class CallError(RPCError):
    """An RPC call failed."""


class CallProtocolError(RPCConnectionError):
    """The server sent a message that does not fit the call protocol."""


class CallTimeout(CallError, ConnectionTimeout):
    """An RPC call did not finish in time."""


class BaseReactor(object):
    """Shared connection/multiplexing logic for the RPC client reactors.

    A reactor owns one socket to the RPC server: it performs the handshake,
    sends ``CALL``/``DROP``/``PING`` requests and, in ``_loop``, routes the
    incoming messages to the per-call job queues kept in ``self._jobs``.
    Subclasses choose the concurrency primitives (``_eventClass``,
    ``_lockClass``, ``_socketClass``), supply ``_resolveFn`` and run
    ``_loop`` in a thread or greenlet.
    """

    magicPacker = struct.Struct('!I')
    flagsPacker = struct.Struct('!I')

    # Filled in by subclasses (threading vs gevent flavours).
    _eventClass = None
    _lockClass = None
    _socketClass = None
    _resolveFn = None

    def __init__(self, cfg):
        super(BaseReactor, self).__init__()

        self.cfg = cfg
        # Set by _loop once it runs; cleared on connection loss or stop().
        self.active = self._eventClass()

        self._sock = None
        # Serialises job bookkeeping and socket writes between callers and _loop.
        self._lock = self._lockClass()
        # uid -> job for every call that may still receive server replies.
        self._jobs = {}

    def _makeSocket(self, host, port, family=None):
        """Create a configured, unconnected stream ``Socket``.

        :param host: unused; kept for interface compatibility.
        :param port: TCP port, or None for a UNIX socket.
        :param family: address family; defaults to ``AF_INET`` when *port*
            is given and ``AF_UNIX`` otherwise.
        """
        if family is None:
            family = socket.AF_INET if port is not None else socket.AF_UNIX
        sock = Socket(self._socketClass(family, socket.SOCK_STREAM))
        try:
            if port is not None:
                # nodelay (presumably TCP_NODELAY) only applies to TCP sockets.
                sock.nodelay = self.cfg.client.socket_nodelay
            sock.sendBuffer = self.cfg.client.socket_send_buffer
            sock.receiveBuffer = self.cfg.client.socket_receive_buffer
        except BaseException:
            sock.close()
            raise
        return sock

    def _makeConnectedSocket(self, host, port):
        """Return a ``Socket`` connected to *host*:*port* (or UNIX path *host*).

        With a port, every address resolved for IPv6 and then IPv4 is tried in
        order, each on a fresh socket (a socket whose connect failed cannot be
        reused). Without a port, *host* is the UNIX socket path. Sockets of
        failed attempts are closed.

        :raises RPCSocketError: listing the failure of every attempt.
        """
        err = RPCSocketError()
        timeout = self.cfg.client.connect_timeout

        if port is None:
            families = [socket.AF_UNIX]
        else:
            families = [socket.AF_INET6, socket.AF_INET]

        for family in families:
            if port is None:
                targets = [(host, port)]
            else:
                try:
                    targets = [(addr[4], None) for addr in self._resolveFn(host, port, family)]
                except socket.error as e:
                    err.add(family, e)
                    continue
                if not targets:
                    err.add(family, 'no addresses resolved')
                    continue

            for address, addressPort in targets:
                sock = None
                try:
                    sock = self._makeSocket(host, port, family)
                    sock.connect(address, addressPort, timeout=timeout)
                except socket.error as e:
                    err.add(family, e)
                    if sock is not None:
                        sock.close()
                    continue
                return sock

        raise err

    def connect(self, host, port):
        """Connect to the server and perform the handshake.

        On success the socket becomes ``self._sock``; on any failure the
        socket is closed before the exception propagates.

        :raises RPCSocketError: no address could be connected.
        :raises HandshakeTimeout: a handshake read/write timed out.
        :raises HandshakeError: the server answered with non-zero flags.
        """
        sock = self._makeConnectedSocket(host, port)
        try:
            self._handshake(sock, host)
        except BaseException:
            sock.close()
            raise
        self._sock = sock

    def _handshake(self, sock, host):
        """Send the magic, read the server flags, and echo zero flags back."""
        magic = self.magicPacker.pack(self.cfg.connection_magic)
        if sock.write(magic, timeout=self.cfg.client.connect_timeout) is None:
            raise HandshakeTimeout('Timeout while sending magic')

        data = sock.read(self.flagsPacker.size, timeout=self.cfg.handshake_receive_timeout)
        if data is None:
            raise HandshakeTimeout('Handshake timed out: we were attempting to read from {}'.format(host))

        flags, = self.flagsPacker.unpack(data)
        # Only flags == 0 is understood by this client; was an assert before.
        if flags != 0:
            raise HandshakeError('Unsupported handshake flags {:#x} from {}'.format(flags, host))

        if sock.write(self.flagsPacker.pack(0), timeout=self.cfg.handshake_send_timeout) is None:
            raise HandshakeTimeout('Handshake timed out: we were attempting to send to {}'.format(host))

    def stop(self):
        """Mark the reactor inactive (subclasses also stop the loop)."""
        self.active.clear()

    def start(self):
        """Block until ``_loop`` has set ``active``."""
        self.active.wait()

    def isActive(self):
        return self.active.is_set()

    def _makeJob(self, uid):
        """Create the job object for *uid*; overridden for gevent queues."""
        return RPCJob(uid)

    def _generateUid(self):
        """Pick a random uid not used by a pending job; hold ``self._lock``.

        :raises Exception: no free uid found in ``uid_generation_tries`` tries.
        """
        for _ in range(self.cfg.uid_generation_tries):
            uid = random.getrandbits(self.cfg.uid_bits)
            if uid not in self._jobs:
                return uid
        raise Exception('Was not able to generate unique uid!')

    def call(self, name, args, kwargs):
        """Send a ``CALL`` request and return the job tracking its replies.

        The job is registered before the request is written so a reply can
        never arrive for an unknown uid. If the write fails the job is
        unregistered; on ``EOF`` the loop is first given
        ``reactor_stop_timeout_after_eof`` seconds to shut down.
        """
        with self._lock:
            uid = self._generateUid()
            job = self._makeJob(uid)
            self._jobs[uid] = job

            try:
                # TODO: timeout
                self._sock.write(msgpack.dumps(('CALL', uid, (name, args, kwargs))))
            except EOF:
                self._jobs.pop(uid, None)
                self.join(timeout=self.cfg.client.reactor_stop_timeout_after_eof)
                raise  # pragma: no cover
            except BaseException:
                # abort() may already have removed the job concurrently.
                self._jobs.pop(uid, None)
                raise
            return job

    def changeJobUid(self, job, uid):
        """Re-key a pending *job* under the server-assigned *uid*.

        Called by ``RPCClientBase.call`` when a ``REGISTERED`` reply carries a
        different uid. A job that is no longer pending only has its ``uid``
        attribute updated.
        """
        with self._lock:
            if self._jobs.get(job.uid) is job:
                del self._jobs[job.uid]
                self._jobs[uid] = job
            job.uid = uid

    def drop(self, uid):
        """Forget job *uid* and, best effort, tell the server to drop it.

        A job the server never registered is only forgotten locally; a uid
        that is no longer pending still gets a ``DROP`` sent.
        """
        with self._lock:
            job = self._jobs.pop(uid, None)
            if job is not None and not job.registered:
                return

            try:
                # TODO: timeout
                self._sock.write(msgpack.dumps(('DROP', uid)))
            except Exception:
                # Best effort: the connection may already be gone.
                pass

    def abort(self, message):
        """Remove every pending job and put *message* into its queue."""
        for uid in list(self._jobs):
            # drop() may remove a job concurrently; don't die on it.
            job = self._jobs.pop(uid, None)
            if job is not None:
                job.queue.put(message)

    def _dispatch(self, message):
        """Route one (non-PONG) server message to the job queue(s) concerned.

        :raises Exception: the message type is unknown.
        """
        kind = message[0]
        if kind == 'ERROR':
            for job in list(self._jobs.values()):
                job.queue.put(message)
        elif kind in ('REGISTERED', 'STATE', 'FAILED', 'COMPLETE'):
            # The caller may already have dropped the job (e.g. after a
            # registration timeout); a late reply must not kill the reactor.
            job = self._jobs.get(message[1])
            if job is not None:
                job.queue.put(message)
        else:
            raise Exception('Unknown message %r' % (message, ))

    def _loop(self):
        """Receive loop: dispatch server messages and keep the link alive.

        Reads with a ``ping_tick_time`` timeout. On a quiet tick it sends
        ``PING`` and then waits ``ping_wait_time`` for ``PONG``. It exits,
        aborting all pending jobs and closing the socket, in four cases:
        the message stream ends; a socket error other than a plain read
        timeout occurs; a ``PONG`` is missed; or no job is pending and nothing
        arrived for ``idle_timeout`` seconds.
        """
        pingSent = False
        # Measure idleness from loop start, not from the epoch, so a fresh
        # reactor is not treated as idle on its first quiet tick.
        lastMessage = time.time()

        self.active.set()
        try:
            while True:
                if pingSent:
                    self._sock.timeout = self.cfg.client.ping_wait_time
                else:
                    self._sock.timeout = self.cfg.client.ping_tick_time

                try:
                    for message in self._sock.readMsgpack(self.cfg.client.receive_buffer):
                        if message[0] == 'PONG' and pingSent:
                            pingSent = False
                            self._sock.timeout = self.cfg.client.ping_tick_time
                            continue

                        lastMessage = time.time()
                        self._dispatch(message)

                except socket.error as ex:
                    if str(ex) != 'timed out' or pingSent:
                        # Real socket failure, or the server missed our PING.
                        self._sock.close()
                        self.active.clear()
                        self.abort(Exception('Server has gone away.'))
                        return

                    if not self._jobs and time.time() - lastMessage > self.cfg.client.idle_timeout:
                        self.abort(Exception('Idle timeout.'))
                        break

                    # Same lock as call()/drop() so PING cannot interleave
                    # with their writes on the shared socket.
                    with self._lock:
                        self._sock.write(msgpack.dumps(('PING', )))
                    pingSent = True
                else:
                    # readMsgpack() finished without error: presumably the
                    # peer closed the stream.
                    break
        finally:
            try:
                self.abort(Exception('reactor stopped.'))
            finally:
                self._sock.close()


class ThreadedReactor(BaseReactor):
    """Reactor that runs its receive loop in a daemon OS thread."""

    _eventClass = threading.Event
    _lockClass = threading.Lock
    _socketClass = socket.socket

    def __init__(self, cfg):
        super(ThreadedReactor, self).__init__(cfg)
        self._loopThread = None

    def _resolveFn(self, *args, **kwargs):
        return socket.getaddrinfo(*args, **kwargs)

    def stop(self):
        """Close the socket and wait (up to 10s) for the loop thread to end."""
        if self._loopThread is None:
            return
        # Closing the socket is presumably what breaks _loop out of its read.
        self._sock.close()
        self._loopThread.join(timeout=10)
        assert not self._loopThread.is_alive()

        super(ThreadedReactor, self).stop()

    def start(self):
        """Start the loop thread and block until the loop marks itself active."""
        assert self._loopThread is None
        assert self._sock is not None
        self._loopThread = threading.Thread(target=self._loop)
        # Daemon thread: a lingering reactor never keeps the process alive.
        self._loopThread.daemon = True
        self._loopThread.start()

        super(ThreadedReactor, self).start()

    def join(self, timeout=None):
        """Wait for the loop thread; return True if it has finished."""
        self._loopThread.join(timeout=timeout)
        # ``Thread.isAlive`` was removed in Python 3.9; ``is_alive`` works on 2.6+.
        return not self._loopThread.is_alive()

    def isActive(self):
        return super(ThreadedReactor, self).isActive() and self._loopThread.is_alive()


class GeventReactor(BaseReactor):
    """Reactor that runs its receive loop in a gevent greenlet."""

    _socketClass = gevent.socket.socket
    _lockClass = coros.Semaphore
    _eventClass = gevent.event.Event

    def __init__(self, cfg):
        super(GeventReactor, self).__init__(cfg)
        self._loopGrn = None

    def _resolveFn(self, *args, **kwargs):
        # Resolution happens in a helper OS thread; see _raw_resolve.
        return _raw_resolve(*args, **kwargs)

    def stop(self):
        """Close the socket and wait (up to 10s) for the loop greenlet."""
        grn = self._loopGrn
        if grn is not None:
            self._sock.close()
            grn.join(timeout=10)
            assert grn.ready()
            super(GeventReactor, self).stop()

    def start(self):
        """Spawn the loop greenlet and block until it marks itself active."""
        assert self._loopGrn is None
        assert self._sock is not None
        grn = gevent.Greenlet(run=self._loop)
        self._loopGrn = grn
        grn.start()
        super(GeventReactor, self).start()

    def _makeJob(self, uid):
        # gevent queue: waiting on a job blocks only the current greenlet.
        return RPCJob(uid, queueClass=GQueue)

    def join(self, timeout=None):
        """Wait for the loop greenlet; return True if it has finished."""
        grn = self._loopGrn
        grn.join(timeout=timeout)
        return grn.ready()

    def isActive(self):
        if not super(GeventReactor, self).isActive():
            return False
        return not self._loopGrn.ready()


class RPCJob(object):
    """Client-side handle for one in-flight RPC call.

    The reactor pushes server messages (or an exception) for the call into
    ``queue``; ``registered`` becomes True once the client sees the server's
    ``REGISTERED`` reply.
    """

    def __init__(self, uid, queueClass=Queue):
        self.registered = False
        self.queue = queueClass()
        self.uid = uid


class RPCClientBase(object):
    """Blocking RPC client that multiplexes calls over a cached reactor.

    Subclasses set ``_reactorClass`` (the reactor type to build) and
    ``_reactorsLock`` (guards reactor creation). Reactors are cached in the
    class-level ``reactors`` dict keyed by endpoint and are reused by every
    client instance targeting that endpoint.
    """

    # NOTE(review): the subclasses in this module do not override this dict,
    # so threaded and gevent clients share it under the same endpoint key —
    # confirm they are never mixed for one endpoint in a single process.
    reactors = {}

    def __init__(self, host, port, cfg=None):
        """
        :param host: server host name/address, or UNIX socket path when
            *port* is None.
        :param port: TCP port, or None for a UNIX socket.
        :param cfg: RPC config; ``loadConfig().rpc`` is used when omitted.
        """
        self.log = logging.getLogger('skynet.copierng.rpc.client')
        self.cfg = loadConfig().rpc if cfg is None else cfg
        self.__host = host
        self.__port = port

        if port is not None:
            self._reactorKey = '%s:%d' % (host, port)
        else:
            self._reactorKey = 'unix:%s' % (host, )

    def __makeReactor(self):
        reactor = self._reactorClass(self.cfg)
        reactor.connect(self.__host, self.__port)

        # Publish through the same mapping call() reads from, so a subclass
        # defining its own ``reactors`` dict stays consistent.
        self.reactors[self._reactorKey] = reactor
        reactor.start()

    def _hasActiveReactor(self):
        """Return True if a live reactor is cached for this endpoint."""
        reactor = self.reactors.get(self._reactorKey)
        return reactor is not None and reactor.isActive()

    def _getReactor(self):
        """Return the cached reactor, (re)creating it if missing or dead."""
        if not self._hasActiveReactor():
            with self._reactorsLock:
                # Re-check under the lock: another caller may have rebuilt it.
                if not self._hasActiveReactor():
                    stale = self.reactors.get(self._reactorKey)
                    if stale is not None:
                        stale.stop()
                    self.__makeReactor()
            assert self._reactorKey in self.reactors
        return self.reactors[self._reactorKey]

    def call(self, name, *args, **kwargs):
        """Invoke remote *name* with *args*/*kwargs* and return its result.

        The keyword ``stateCallback`` is not forwarded to the server; when
        given it is called with the payload of every ``STATE`` message.

        :raises CallTimeout: the job was not registered within
            ``job_registration_timeout`` seconds.
        :raises CallError: the server reported an error or a failed job.
        :raises CallProtocolError: the server sent an unexpected message.
        Exceptions queued by the reactor (e.g. connection loss) are re-raised.
        """
        reactor = self._getReactor()
        stateCallback = kwargs.pop('stateCallback', None)

        job = reactor.call(name, args, kwargs)
        try:
            return self._waitForResult(reactor, job, stateCallback)
        finally:
            # Release the job whether it completed, failed or timed out.
            reactor.drop(job.uid)

    def _waitForResult(self, reactor, job, stateCallback):
        """Consume *job*'s queue until the call completes or fails."""
        regTimeout = self.cfg.client.job_registration_timeout
        while True:
            try:
                data = job.queue.get(timeout=regTimeout)
            except QueueEmpty:
                if not job.registered:
                    raise CallTimeout('Job registration timed out (we wait %d secs)' % (regTimeout, ))
                # Registered jobs may run arbitrarily long; keep waiting.
                continue

            if isinstance(data, Exception):
                raise data

            kind = data[0]
            if kind == 'ERROR':
                raise CallError('Server-side error: %s' % (data[1], ))
            if kind == 'REGISTERED':
                job.registered = True
                if job.uid != data[1]:
                    reactor.changeJobUid(job, data[1])
                continue

            if not job.registered:
                raise CallProtocolError('Expected registration rpc message, got %r instead' % (data, ))

            if kind == 'STATE':
                if stateCallback is not None:
                    stateCallback(data[2])
            elif kind == 'COMPLETE':
                return data[2]
            elif kind == 'FAILED':
                raise CallError('Server-side error: %s' % (data[2], ))
            else:
                raise CallProtocolError('Bad data %r' % (data, ))


class RPCClient(RPCClientBase):
    """RPC client for thread-based code; builds ThreadedReactor instances."""

    _reactorClass = ThreadedReactor
    _reactorsLock = threading.Lock()


class RPCClientGevent(RPCClientBase):
    """RPC client for gevent-based code; builds GeventReactor instances."""

    _reactorClass = GeventReactor
    _reactorsLock = coros.Semaphore()


def _raw_resolve(host, port, family):
    """Run the unpatched ``getaddrinfo`` in a helper thread, gevent-friendly.

    The lookup runs via ``_socket`` (bypassing any gevent monkey-patching) in
    a daemon thread. The calling greenlet waits on a pipe, so the hub stays
    responsive meanwhile.

    :returns: the ``getaddrinfo`` result list.
    :raises: whatever ``getaddrinfo`` raised, re-raised with its traceback.
    """
    import os
    import sys
    # Imported here rather than in the worker thread to avoid Python 2
    # import-lock deadlocks while this greenlet waits on the pipe.
    import _socket

    r, w = os.pipe()
    result = []

    def resolve():
        # The worker owns ``w`` and closes it itself: if the waiter gave up
        # and we closed ``w`` on its side, this write could hit a reused fd.
        try:
            try:
                result.append((_socket.getaddrinfo(host, port, family), None))
            except BaseException:
                result.append((None, sys.exc_info()))
            # Must be bytes: a str here raised TypeError on Python 3, the
            # notification never arrived and the waiter hung forever.
            os.write(w, b'\xff')
        except OSError:
            # Reader side already closed (waiter interrupted); nobody listens.
            pass
        finally:
            os.close(w)

    try:
        try:
            daemonthr(resolve)
        except BaseException:
            # The worker never started, so it will not close ``w`` for us.
            os.close(w)
            raise
        gevent.socket.wait_read(r)
    finally:
        os.close(r)

    res, err = result[0]
    if err:
        six.reraise(err[0], err[1], err[2])
    return res
