"""
RPC-specific utility classes and functions.
"""
import struct

from .. import enum
from .. import patterns

# Size of session ID in bits. The connection magic must also fit into this size
# (ConnectionHandler validates it on construction); both values travel on the wire
# packed as a 32-bit unsigned big-endian integer (struct format "!I").
SID_BITS = 32


class Message(enum.Enum):
    """
    Protocol control messages exchanged between the RPC client and server.

    Members are declared with placeholder ``None`` values inside ``with Group.<NAME>:``
    blocks; the project ``enum`` module presumably assigns the real values and the
    group membership -- confirm against ``enum.Enum`` / ``enum.GroupEnum``.
    """

    # Message groups. CALL_RESPONSES is only a placeholder here: it is reassigned to a
    # list of Message members at module level once the Message class exists.
    class Group(enum.GroupEnum):
        REQUESTS = None         # Client request messages
        RESPONSES = None        # Server response messages
        CALL_RESPONSES = None   # Valid responses to a remote method call

    with Group.REQUESTS:
        CALL = None         # Ask to start remote method execution
        FEEDBACK = None     # Request with intermediate client's data (`send` statement on the client side)
        PING = None         # Ask to check the connection is still alive during remote call execution
        DROP = None         # Requests remote method execution termination (should be done for every registered call)

    with Group.RESPONSES:
        REGISTERED = None   # Response to `CALL` with job ID message - acknowledges the remote method execution started
        STATE = None        # Response with intermediate remote call result (`yield` statement on the server side)
        PONG = None         # Response to connection check
        ERROR = None        # Message from the server with some connection-level problems. Requires connection close
        FAILED = None       # Response indicates the remote method finished with some exception (`raise` statement)
        COMPLETE = None     # Response indicates the remote method finished with some result (`return` statement)
        RECONNECT = None    # Message from the server that requests client reconnect (server graceful restart)


# Fill in the CALL_RESPONSES placeholder declared inside Message.Group. This has to
# happen after the class body, presumably because the members only become real
# Message values once the class has been created. PONG is included because PING may
# be sent while a remote call is in progress.
Message.Group.CALL_RESPONSES = [Message.REGISTERED, Message.STATE, Message.FAILED, Message.COMPLETE, Message.PONG]


class ConnectionHandler(object):
    """
    Builds and parses the binary handshake packets exchanged on RPC connections.

    Every packet is a one-byte unsigned ``age`` followed by a 32-bit unsigned
    big-endian value: the connection magic during greetings, or the session ID
    once the connection is established.
    """
    __age_packer = struct.Struct("B")
    __sid_packer = struct.Struct("!I")

    def __init__(self, cfg, age):
        """
        :param cfg: configuration object; ``connection_magic``, ``handshake_send_timeout``
                    and ``handshake_receive_timeout`` are read from it
        :param age: value written into the one-byte age field of outgoing packets
                    (must be within 0..255 to be packable)
        :raises ValueError: if ``cfg.connection_magic`` does not fit into ``SID_BITS`` bits
        """
        self.cfg = cfg
        self.age = age
        self.__magic = cfg.connection_magic
        # An explicit check instead of `assert`, which is stripped under `python -O`.
        if not 0 <= self.__magic < 2 ** SID_BITS:
            raise ValueError(
                "connection_magic must fit into %d bits, got %r" % (SID_BITS, self.__magic)
            )

        super(ConnectionHandler, self).__init__()

    @patterns.classproperty
    def legacy_age(self):
        """
        Packed one-byte age marker ``b'\\xff'``.

        The original code packed ``-1``, but the unsigned "B" format rejects it with
        ``struct.error``; 0xFF is the byte -1 maps to in two's complement.
        NOTE(review): assumed to be the intended legacy marker -- confirm with callers.
        """
        return self.__age_packer.pack(0xFF)

    @property
    def magic(self):
        """Connection magic taken from ``cfg.connection_magic``."""
        return self.__magic

    @classmethod
    def _pack_header(cls, age, value):
        """Pack ``age`` (one unsigned byte) followed by ``value`` (32-bit big-endian unsigned)."""
        return cls.__age_packer.pack(age) + cls.__sid_packer.pack(value)

    @classmethod
    def _unpack_header(cls, data):
        """Split a received packet into ``(age, value)``; ``(None, None)`` when nothing was read."""
        if not data:
            return None, None
        split = cls.__age_packer.size
        # `unpack` (not `unpack_from`) keeps the original strictness: each slice must
        # have exactly the expected size, otherwise struct.error is raised.
        return cls.__age_packer.unpack(data[:split])[0], cls.__sid_packer.unpack(data[split:])[0]

    @classmethod
    def handle_greetings(cls, sock, timeout):
        """
        Receives and parses a greetings packet from the remote side.

        :param sock:    socket-like object providing ``read(size, timeout=...)``
        :param timeout: timeout for the socket read operation
        :return:        tuple ``(age, magic)`` sent by the remote side, or ``(None, None)``
                        if the read returned no data
        """
        data = sock.read(cls.__age_packer.size + cls.__sid_packer.size, timeout=timeout)
        return cls._unpack_header(data)

    def send_greetings(self, sock):
        """
        Sends this side's age and connection magic on a newly connected socket.

        :param sock:    socket-like object providing ``write(data, timeout=...)``
        :return:        whatever ``sock.write`` returns (presumably the number of bytes sent)
        """
        return sock.write(
            self._pack_header(self.age, self.__magic),
            timeout=self.cfg.handshake_send_timeout
        )

    def handle_session(self, sock, sid=None):
        """
        Sends or receives a session ID on an already established connection.

        :param sock:    socket-like object providing ``read``/``write`` with a ``timeout`` keyword
        :param sid:     session ID to send; when ``None``, a session packet is read instead
        :return:        the result of ``sock.write`` when ``sid`` is given, otherwise a tuple
                        ``(age, sid)`` from the remote side, or ``(None, None)`` if nothing was read
        """
        # `is not None` rather than truthiness: 0 is a valid unsigned session ID and
        # must be sent, not silently turned into a blocking read.
        if sid is not None:
            return sock.write(
                self._pack_header(self.age, sid),
                timeout=self.cfg.handshake_send_timeout
            )
        data = sock.read(self.__age_packer.size + self.__sid_packer.size, timeout=self.cfg.handshake_receive_timeout)
        return self._unpack_header(data)


def sid2str(sid, bits=SID_BITS):
    """
    Format a session ID as a zero-padded lowercase hex string.

    :param sid:  non-negative integer session ID
    :param bits: bit width the ID is padded to; the width is rounded up to whole hex
                 digits, so e.g. ``bits=10`` pads to 3 digits. Values wider than
                 ``bits`` are not truncated.
    :return:     lowercase hex string without a ``0x`` prefix
    :raises ValueError: if ``sid`` is negative (session IDs are packed as unsigned "!I")
    """
    if sid < 0:
        raise ValueError("session ID must be non-negative, got %r" % (sid,))
    # Ceiling division: a partial nibble still needs a full hex digit. `%x` never emits
    # the Python 2 long suffix 'L', so no stripping is needed.
    return "%0*x" % ((bits + 3) // 4, sid)
