# vim: foldmethod=marker

from __future__ import absolute_import, print_function, division

import socket
try:
    import gevent
    import gevent.coros
except ImportError:
    gevent = None  # NOQA
import msgpack
import errno
import time
import threading


class EOF(Exception):
    """Signals that the peer closed the stream (empty recv or EPIPE on send)."""


class Socket(object):
    """Wrapper around a socket object adding locking, timeouts and msgpack I/O.

    ``geventMode`` selects the lock guarding read/write/close/bind/connect:
    truthy -> ``gevent.coros.Semaphore(1)`` (requires the gevent import at the
    top of the module to have succeeded), ``None`` -> no locking at all, any
    other falsy value (the default ``False``) -> ``threading.Lock``.

    NOTE(review): read(), write() and close() share one lock, so a write() or
    close() issued while another thread is blocked in read() waits for that
    read to finish.
    """

    class ScopedTimeout(object):
        """Temporarily override ``socket.timeout`` for the duration of a block.

        With ``timeout=None`` nothing is changed. Otherwise the current timeout
        is saved and the new one applied immediately (at construction, not in
        ``__enter__``); the saved value is restored on exit.
        """

        def __init__(self, socket, timeout):
            self.__socket = None
            if timeout is not None:
                self.__socket = socket
                self.__timeout = socket.timeout
                socket.timeout = timeout

        def __enter__(self):
            return self

        def __exit__(self, type, value, traceback):
            if self.__socket is not None:
                self.__socket.timeout = self.__timeout

    class _noLock(object):
        """Do-nothing stand-in for a lock, used when ``geventMode is None``."""

        def __enter__(self):
            return self

        def __exit__(self, type, value, traceback):
            pass

    def __init__(self, sock, geventMode=False):
        self.sock = sock
        self.__geventMode = geventMode
        self.__closed = False

        if geventMode:
            self.__lock = gevent.coros.Semaphore(1)
        elif geventMode is not None:
            self.__lock = threading.Lock()
        else:
            self.__lock = self._noLock()

        sockname = sock.getsockname()
        if not isinstance(sockname, tuple):
            # Non-IP families (e.g. AF_UNIX) report a path, or an empty str /
            # bytes for unnamed sockets such as socketpair(); indexing those
            # like a (host, port) tuple is meaningless.
            self.__addr = sockname if sockname else ('', '')
            self.__binded = True
        elif sockname[1] == 0:
            # Port 0: not bound yet; ``addr`` resolves it lazily later.
            self.__addr = None
            self.__binded = False
        else:
            self.__addr = sockname
            self.__binded = True

        try:
            self.__peer = sock.getpeername()
            self.__connected = True
        except socket.error as ex:
            if ex.errno == errno.ENOTCONN:
                self.__peer = None
                self.__connected = False
            else:
                raise

        self.__msgpack_unpacker = msgpack.Unpacker(use_list=True)

    # Properties and shortcuts {{{
    @property
    def addr(self):
        """Local address; fetched lazily from the socket if not yet known."""
        assert self.__binded or self.__connected, 'Bind first'
        if self.__addr is None:
            self.__addr = self.sock.getsockname()
        return self.__addr

    @property
    def peer(self):
        """Remote address; fetched lazily from the socket if not yet known."""
        assert self.__connected, 'Connect first'
        if self.__peer is None:
            self.__peer = self.sock.getpeername()

        return self.__peer

    @property
    def nodelay(self):
        return self.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)

    @nodelay.setter
    def nodelay(self, value):
        self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, value)

    @property
    def sendBuffer(self):
        return self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)

    @sendBuffer.setter
    def sendBuffer(self, value):
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, value)

    @property
    def receiveBuffer(self):
        return self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF)

    @receiveBuffer.setter
    def receiveBuffer(self, value):
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, value)

    @property
    def timeout(self):
        return self.sock.gettimeout()

    @timeout.setter
    def timeout(self, value):
        self.sock.settimeout(value)

    # Properties and shortcuts }}}

    def close(self, shutdown=socket.SHUT_RDWR):
        """Shut down (best effort) and close the socket; later calls are no-ops.

        :param shutdown: ``how`` for ``socket.shutdown`` (SHUT_RD / SHUT_WR /
            SHUT_RDWR). Pass ``None`` or ``False`` to skip the shutdown step.
        """
        with self.__lock:
            if self.__closed:
                return

            # SHUT_RD == 0, so a plain truthiness test would silently drop it.
            if shutdown is not None and shutdown is not False:
                try:
                    self.sock.shutdown(shutdown)
                except socket.error:
                    # Best effort, e.g. ENOTCONN when the peer is already gone.
                    pass

            self.sock.close()
            self.__closed = True

    def recv(self, readby=8192, timeout=None):
        """Receive up to ``readby`` bytes with a single ``recv`` call.

        :param timeout: per-call timeout in seconds; ``None`` keeps the
            socket's current timeout, a negative value returns ``None`` at once.
        :returns: the received data (empty when the peer closed), or ``None``
            when the timeout expired.
        """
        if timeout is not None and timeout < 0:
            return None

        with Socket.ScopedTimeout(self, timeout):
            try:
                return self.sock.recv(readby)
            except socket.timeout as ex:
                # Since Python 3.10 socket.timeout is TimeoutError, which also
                # matches OSError(ETIMEDOUT) -- a dead connection, not an
                # expired timeout. Only the errno-less "timed out" is expiry.
                if getattr(ex, 'errno', None) is not None:
                    raise
                return None

    def send(self, data, timeout=None):
        """Send ``data`` with a single ``send`` call.

        :param timeout: as for :meth:`recv`.
        :returns: number of bytes sent, or ``None`` when the timeout expired.
        :raises EOF: when the peer has gone away (EPIPE).
        """
        if timeout is not None and timeout < 0:
            return None

        with Socket.ScopedTimeout(self, timeout):
            try:
                return self.sock.send(data)
            except socket.timeout as ex:
                # See recv(): an ETIMEDOUT connection error must propagate.
                if getattr(ex, 'errno', None) is not None:
                    raise
                return None
            except socket.error as ex:
                if ex.errno == errno.EPIPE:
                    raise EOF()
                raise

    def read(self, count, readby=0x2000, timeout=None, raiseEof=True):
        """Read exactly ``count`` bytes, in chunks of at most ``readby``.

        :param timeout: overall deadline in seconds for the whole read; a
            falsy value (``None`` or ``0``) means no deadline.
        :param raiseEof: raise :class:`EOF` if the peer closes early; when
            false, return whatever was received before the close.
        :returns: the data read, or ``None`` if the deadline passed (any
            partially received data is discarded in that case).
        """
        with self.__lock:
            chunks = []
            left = count
            deadline = time.time() + timeout if timeout else None

            while left:
                data = self.recv(
                    min(readby, left),
                    timeout=None if deadline is None else deadline - time.time(),
                )

                if data is None:
                    # Timed out
                    return None

                if not data:
                    if raiseEof:
                        raise EOF()
                    break

                left -= len(data)
                chunks.append(data)

            # recv() yields bytes on Python 3 (str on Python 2, where b'' is
            # str too); joining with a text '' would fail on Python 3.
            return b''.join(chunks)

    def write(self, data, timeout=None, writeby=0x2000):
        """Write all of ``data``, sending at most ``writeby`` bytes per call.

        :param timeout: overall deadline in seconds; falsy means no deadline.
        :returns: total bytes sent, or ``None`` if a send timed out or sent
            nothing (the amount already written is then not reported).
        :raises EOF: when the peer has gone away (EPIPE).
        """
        with self.__lock:
            length = len(data)
            total = 0
            deadline = time.time() + timeout if timeout else None

            # memoryview slices without copying the remaining payload.
            view = memoryview(data)

            while total < length:
                sent = self.send(
                    view[total:total + writeby],
                    timeout=None if deadline is None else deadline - time.time(),
                )
                if not sent:
                    return None
                total += sent

            return total

    def writeMsgpack(self, obj, timeout=None):
        """Serialize ``obj`` with msgpack and :meth:`write` it."""
        return self.write(msgpack.packb(obj), timeout)

    def readMsgpack(self, readby=0x2000):
        """Generator yielding msgpack objects decoded from the stream.

        Ends when ``recv`` returns nothing (peer closed or socket timeout
        expired) or the socket reports EBADF. Calling ``send(<truthy>)`` on
        the generator stops it: that ``send`` returns ``0`` and the next
        iteration finishes. This method does not take the instance lock.
        """
        while True:
            try:
                buf = self.recv(readby)
            except socket.error as ex:
                if ex.errno in (errno.EBADF, ):
                    return
                raise

            if not buf:
                # No more data to process here, connection closed or timed out
                return

            self.__msgpack_unpacker.feed(buf)
            for data in self.__msgpack_unpacker:
                stop = yield data
                if stop:
                    yield 0
                    return

    def bind(self, host, port):
        """Bind the socket to ``(host, port)``; must not be bound already."""
        with self.__lock:
            assert not self.__binded
            self.sock.bind((host, port))
            self.__binded = True

    def connect(self, host, port, timeout=None):
        """Connect to ``(host, port)``, or to ``host`` alone when ``port`` is None.

        :param timeout: temporary socket timeout used only for the connect.
        """
        with self.__lock:
            assert not self.__connected

            with Socket.ScopedTimeout(self, timeout):
                if port is None:
                    self.sock.connect(host)
                else:
                    self.sock.connect((host, port))
                self.__connected = True
