# vim: foldmethod=marker

from __future__ import absolute_import, print_function, division

import socket
try:
    import gevent
except ImportError:
    gevent = None  # NOQA
else:
    try:
        import gevent.coros as coros
    except ImportError:
        import gevent.lock as coros

import errno
import time
import six
import threading

from ya.skynet.util.sys.gettime import monoTime
from ya.skynet.util.sys.getpeerid import getpeerid

from .. import msgpackutils as msgpack


class EOF(Exception):
    """Signals that the other side of the connection is gone (end of stream)."""


class Socket(object):
    """Wrapper around a socket object adding timeouts, locking and msgpack I/O.

    ``gevent_mode`` selects the lock that serializes ``read``, ``write``,
    ``close``, ``bind`` and ``connect``, and the sleep used between connect
    retries:

    * truthy -- a gevent ``Semaphore(1)`` and ``gevent.sleep``;
    * ``None`` -- no locking at all and ``time.sleep``;
    * any other falsy value (default ``False``) -- ``threading.Lock`` and
      ``time.sleep``.
    """

    class ScopedTimeout(object):
        """Context manager that temporarily overrides the wrapped socket's timeout.

        With ``timeout=None`` the manager is a no-op and the current timeout is
        left untouched. Otherwise the previous timeout is restored on exit,
        including when the body raises.
        """

        def __init__(self, socket, timeout):
            # ``socket`` is the owning ``Socket`` wrapper (not the module).
            self.__socket = socket
            self.__timeout = timeout
            self.__saved = None
            self.__active = False

        def __enter__(self):
            if self.__timeout is not None:
                self.__saved = self.__socket.sock.gettimeout()
                self.__socket.sock.settimeout(self.__timeout)
                self.__active = True
            return self

        def __exit__(self, type, value, traceback):
            if self.__active:
                self.__active = False
                self.__socket.sock.settimeout(self.__saved)

    class _noLock(object):
        """Do-nothing stand-in for a lock, used when ``gevent_mode is None``."""

        def __enter__(self):
            return self

        def __exit__(self, type, value, traceback):
            pass

    def __init__(self, sock, gevent_mode=False):
        """Wrap ``sock`` and record its current bind/connect state.

        :param sock: socket object to wrap; it must already exist.
        :param gevent_mode: selects the locking/sleep strategy (see class doc).
        :raises RuntimeError: if ``gevent_mode`` is truthy but gevent could not
            be imported.
        """
        self.sock = sock
        self.__gevent_mode = gevent_mode
        self.__closed = False

        if gevent_mode:
            if gevent is None:
                # Without this check the failure is an obscure NameError on ``coros``.
                raise RuntimeError('gevent_mode requested but gevent is not available')
            self.__lock = coros.Semaphore(1)
            self.__sleep = gevent.sleep
        elif gevent_mode is not None:
            self.__lock = threading.Lock()
            self.__sleep = time.sleep
        else:
            self.__lock = self._noLock()
            self.__sleep = time.sleep

        sockname = sock.getsockname()
        if not sockname:
            # Unnamed AF_UNIX socket (e.g. one end of a socketpair()). Python 3
            # reports it as '' (str) and Python 2 as b'', so test for emptiness
            # instead of comparing with b'' -- otherwise sockname[1] below would
            # raise IndexError on Python 3.
            self.__addr = (b'', b'')
            self.__binded = True
        elif isinstance(sockname, tuple) and sockname[1] == 0:
            # Address tuple with port 0: not bound yet.
            self.__addr = None
            self.__binded = False
        else:
            self.__addr = sockname
            self.__binded = True

        try:
            self.__peer = sock.getpeername()
            self.__connected = True
        except socket.error as ex:
            if ex.errno == errno.ENOTCONN:
                self.__peer = None
                self.__connected = False
            else:
                raise

        self.__peer_id = None
        self.__msgpack_unpacker = msgpack.Unpacker(use_list=True)

    # Properties and shortcuts {{{
    @property
    def addr(self):
        """Local address of the socket (cached after the first lookup)."""
        assert self.__binded or self.__connected, 'Bind first'
        if self.__addr is None:
            self.__addr = self.sock.getsockname()
        return self.__addr

    @property
    def peer(self):
        """Remote address of the socket (cached after the first lookup)."""
        assert self.__connected, 'Connect first'
        if self.__peer is None:
            self.__peer = self.sock.getpeername()

        return self.__peer

    @property
    def peer_id(self):
        """Result of ``getpeerid(self.sock)``, cached after the first call."""
        assert self.__connected, 'Connect first'
        if self.__peer_id is None:
            self.__peer_id = getpeerid(self.sock)

        return self.__peer_id

    @property
    def nodelay(self):
        """Current ``TCP_NODELAY`` socket option value."""
        return self.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)

    @nodelay.setter
    def nodelay(self, value):
        self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, value)

    @property
    def send_buffer(self):
        """Current ``SO_SNDBUF`` socket option value."""
        return self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)

    @send_buffer.setter
    def send_buffer(self, value):
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, value)

    @property
    def receive_buffer(self):
        """Current ``SO_RCVBUF`` socket option value."""
        return self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF)

    @receive_buffer.setter
    def receive_buffer(self, value):
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, value)

    @property
    def timeout(self):
        """Timeout of the underlying socket (``sock.gettimeout()``)."""
        return self.sock.gettimeout()

    @timeout.setter
    def timeout(self, value):
        self.sock.settimeout(value)

    # Properties and shortcuts }}}

    def close(self, shutdown=socket.SHUT_RDWR):
        """Shut down (best effort) and close the socket; later calls are no-ops.

        :param shutdown: ``how`` argument passed to ``sock.shutdown``. A falsy
            value skips the shutdown step. NOTE: the check is on truthiness, so
            ``socket.SHUT_RD`` (0 on common platforms) also skips it.
        """
        with self.__lock:
            if self.__closed:
                return

            if shutdown:
                try:
                    self.sock.shutdown(shutdown)
                except socket.error:
                    # Best effort, e.g. ENOTCONN when the socket never connected.
                    # Only socket errors are ignored so that KeyboardInterrupt,
                    # SystemExit or greenlet kills still propagate.
                    pass

            self.sock.close()
            self.__closed = True

    def recv(self, readby=8192, timeout=None):
        """Receive up to ``readby`` bytes with a single ``sock.recv`` call.

        :param timeout: temporary timeout in seconds; ``None`` keeps the current
            socket timeout. A negative value returns ``None`` immediately.
        :returns: the received bytes (empty on EOF), or ``None`` on timeout.
        """
        if timeout is not None and timeout < 0:
            return None

        with Socket.ScopedTimeout(self, timeout):
            try:
                return self.sock.recv(readby)
            except socket.error as ex:
                if str(ex) == 'timed out':
                    return None
                raise

    def send(self, data, timeout=None):
        """Send ``data`` with a single ``sock.send`` call.

        :param timeout: temporary timeout in seconds; ``None`` keeps the current
            socket timeout. A negative value returns ``None`` immediately.
        :returns: number of bytes sent, or ``None`` on timeout.
        :raises EOF: when the send fails with ``EPIPE``.
        """
        if timeout is not None and timeout < 0:
            return None

        with Socket.ScopedTimeout(self, timeout):
            try:
                return self.sock.send(data)
            except socket.error as ex:
                if ex.errno == errno.EPIPE:
                    raise EOF()
                if str(ex) == 'timed out':
                    return None
                raise

    def read(self, count, readby=0x2000, timeout=None, raise_eof=True):
        """Read exactly ``count`` bytes while holding the socket lock.

        :param readby: maximum size of a single ``recv`` call.
        :param timeout: overall time budget in seconds; a falsy value means no
            overall deadline.
        :param raise_eof: raise ``EOF`` if the peer closes early; otherwise
            return the bytes collected so far.
        :returns: the bytes read, or ``None`` if the deadline expires. NOTE:
            any data already received is discarded in that case.
        """
        with self.__lock:
            chunks = []
            left = count
            deadline = monoTime() + timeout if timeout else None

            while left:
                remaining = None
                if deadline is not None:
                    remaining = deadline - monoTime()
                    if remaining <= 0:
                        # Budget exhausted. Passing 0 on to recv() would make the
                        # socket non-blocking and raise EAGAIN rather than time out.
                        return None

                data = self.recv(min(readby, left), timeout=remaining)

                if data is None:
                    # Timed out
                    return None

                if not len(data):
                    if raise_eof:
                        raise EOF()
                    break

                left -= len(data)
                chunks.append(data)

            return b''.join(chunks)

    def write(self, data, timeout=None):
        """Send all of ``data`` while holding the socket lock.

        :param timeout: overall time budget in seconds; a falsy value means no
            overall deadline.
        :returns: total number of bytes sent, or ``None`` if a send timed out,
            sent nothing, or the deadline expired.
        """
        with self.__lock:
            length = len(data)
            total_sent = 0
            deadline = monoTime() + timeout if timeout else None

            while total_sent < length:
                remaining = None
                if deadline is not None:
                    remaining = deadline - monoTime()
                    if remaining <= 0:
                        # Budget exhausted; see read() for why 0 is not passed on.
                        return None

                sent = self.send(data[total_sent:], timeout=remaining)
                if not sent:
                    return None
                total_sent += sent

            return total_sent

    def write_msgpack(self, obj, timeout=None):
        """Serialize ``obj`` with ``msgpack.packb`` and ``write`` it."""
        return self.write(msgpack.packb(obj), timeout)

    def read_msgpack(self, readby=0x2000):
        """Generator yielding objects decoded from the incoming byte stream.

        It stops on EOF, on a timeout of the socket's own timeout, or on
        ``EBADF`` (e.g. the socket was closed). Sending a truthy value into the
        generator stops it; that ``send()`` call returns ``0``. NOTE: this
        method does not take the socket lock.
        """
        while True:
            try:
                buf = self.recv(readby)
            except socket.error as ex:
                if ex.errno in (errno.EBADF, ):
                    return
                raise

            if not buf:
                # No more data to process here, connection closed or timed out
                return

            self.__msgpack_unpacker.feed(buf)
            for data in self.__msgpack_unpacker:
                stop = yield data
                if stop:
                    yield 0
                    return

    def bind(self, host, port):
        """Bind the socket to ``(host, port)``; it must not be bound already."""
        with self.__lock:
            assert not self.__binded
            self.sock.bind((host, port))
            self.__binded = True

    def connect(self, host, port, timeout=None):
        """Connect the socket, making up to 3 attempts 0.1 s apart.

        :param host: host name, or the full address when ``port`` is ``None``
            (e.g. a filesystem path for AF_UNIX sockets).
        :param port: port number, or ``None`` to pass ``host`` as the address.
        :param timeout: temporary socket timeout applied for all attempts.
        :raises socket.error: the error of the last attempt if all attempts fail.
        """
        with self.__lock:
            assert not self.__connected

            with Socket.ScopedTimeout(self, timeout):
                for i in six.moves.xrange(3):
                    try:
                        if port is None:
                            self.sock.connect(host)
                        else:
                            self.sock.connect((host, port))
                    except socket.error:
                        if i == 2:
                            raise
                        self.__sleep(0.1)
                    else:
                        break

                self.__connected = True