import logging
import socket
import select
import sys
import time
import ssl
# import pprint


class ProtoTransportError(RuntimeError):
    """Transport-level failure raised by ProtoTransport.

    Used, for example, when the socket is reported unavailable before a send
    or when the backend closes the connection while a message header is read.
    """

    def __init__(self, message):
        # Explicit two-argument super() keeps Python 2 compatibility, which the
        # rest of this module still supports via sys.version_info checks.
        super(ProtoTransportError, self).__init__(message)


class ProtoTransport:
    """Client connection that exchanges length-prefixed messages.

    Framing used by ``sendMessage``/``recvMessage``: the payload length as
    lowercase hexadecimal ASCII digits, then CRLF, then the payload bytes.
    ``sendProtobuf``/``recvProtobuf`` apply that framing to objects exposing
    ``SerializeToString()`` / ``ParseFromString()``.

    Usable as a context manager; leaving the ``with`` block closes the socket.
    """

    def __init__(self, host, port, timeout=8, verbose=True, enable_ssl=False, ipv4=False, max_faults=0,
                 connect_tries=5):
        """Connect to ``host:port``, retrying on failure.

        :param host: server host name or address.
        :param port: server port.
        :param timeout: socket/select timeout in seconds; a falsy value
            (0 or None) is replaced by 3600 (one hour).
        :param verbose: emit debug/exception log records.
        :param enable_ssl: wrap the connection in TLS (see
            ``_make_client_ssl_context`` -- the certificate is not verified).
        :param ipv4: only used with ``enable_ssl``: create an AF_INET socket
            instead of AF_INET6.
        :param max_faults: number of exceptions ``send``/``recv`` log and
            retry before re-raising the next one.
        :param connect_tries: connection attempts before giving up (values
            below 1 are treated as 1); attempts are 1 second apart.
        :raises Exception: whatever the last failed connection attempt raised.
        """
        self.timeout = timeout
        if not self.timeout:
            self.timeout = 3600  # 1h
        self.verbose = verbose
        self.max_faults = max_faults
        tries = max(1, connect_tries)
        while True:
            try:
                if self.verbose:
                    logging.debug('ProtoTransport: trying to connect [{}]:{}  (left tries={})'.format(host, port, tries))
                if enable_ssl:
                    self.socket = self._connect_ssl(host, port, ipv4)
                else:
                    # Use the normalized timeout: a raw 0 would put the socket in
                    # non-blocking mode and connect() would always fail.
                    self.socket = socket.create_connection((host, port), self.timeout)
                    self.socket.settimeout(self.timeout)
                return
            except Exception as ex:
                tries -= 1
                if tries <= 0:
                    raise ex
                # Only wait when another attempt will actually follow.
                time.sleep(1)

    @staticmethod
    def _make_client_ssl_context():
        """Return a client TLS context matching the legacy ``ssl.wrap_socket`` defaults.

        ``ssl.wrap_socket`` was deprecated in Python 3.7 and removed in 3.12.
        """
        protocol = getattr(ssl, 'PROTOCOL_TLS_CLIENT', None)
        if protocol is None:
            protocol = getattr(ssl, 'PROTOCOL_TLS', ssl.PROTOCOL_SSLv23)
        context = ssl.SSLContext(protocol)
        # NOTE(security): like the old ssl.wrap_socket() call, the server
        # certificate and host name are NOT verified. check_hostname must be
        # turned off before verify_mode can be set to CERT_NONE.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
        return context

    def _connect_ssl(self, host, port, ipv4):
        """Open a TLS connection to ``(host, port)`` and return the SSL socket.

        The socket is closed if wrapping, connecting or the handshake fails,
        so failed retries do not leak file descriptors.
        """
        sock = socket.socket(socket.AF_INET if ipv4 else socket.AF_INET6, socket.SOCK_STREAM)
        try:
            # Set before wrapping so the SSL socket inherits it (bounds connect
            # and handshake like the plain TCP path).
            sock.settimeout(self.timeout)
            sock = self._make_client_ssl_context().wrap_socket(sock)
            sock.connect((host, port))
        except Exception:
            sock.close()
            raise
        return sock

    def _pending(self):
        """Return the number of bytes already decrypted and buffered by TLS (0 for plain sockets).

        Such bytes do not make the file descriptor readable, so select() alone
        would block until the timeout even though data is available.
        """
        pending = getattr(self.socket, 'pending', None)
        return pending() if pending is not None else 0

    def __enter__(self):
        return self

    def send(self, data):
        """Send all of ``data`` once the socket is writable.

        On Python 3, ``str`` data is UTF-8 encoded first. Exceptions raised
        while waiting for writability or while sending share one fault
        counter: up to ``self.max_faults`` are logged (when verbose) and
        retried, the next one is re-raised.

        :raises ProtoTransportError: select() reported an exceptional condition
            on the socket and the fault budget is exhausted.
        """
        faults = 0
        payload = data.encode("utf-8") if (sys.version_info[0] == 3 and isinstance(data, str)) else data

        # Polls in 0.1 s steps with no overall deadline.
        while True:
            try:
                _, wlist, xlist = select.select([], [self.socket], [self.socket], 0.1)
                if xlist:
                    raise ProtoTransportError("send unavailable!")
                if wlist:
                    break
            except Exception as e:
                if self.verbose:
                    logging.exception("Exception on pre-send select: %s", e)
                faults += 1
                if faults > self.max_faults:
                    raise e
        while True:
            try:
                # sendall(): a single send() may transmit only part of the buffer.
                # NOTE: a retry after a fault resends the whole payload.
                self.socket.sendall(payload)
                break
            except Exception as e:
                if self.verbose:
                    logging.exception("Exception on send: %s", e)
                faults += 1
                if faults > self.max_faults:
                    raise e
        if self.verbose:
            logging.debug("Send " + str(len(data)))

    def recv(self, length, decode=(sys.version_info[0] == 3)):
        """Read exactly ``length`` bytes from the socket.

        :param length: number of bytes to read; ``<= 0`` returns an empty result.
        :param decode: return UTF-8 decoded text instead of bytes (default on
            Python 3).
        :raises Exception: on select timeout, an exceptional socket condition,
            or the peer closing the connection, once more than
            ``self.max_faults`` such errors have occurred. Bytes read before a
            tolerated error are kept.
        """
        if length <= 0:
            return b"".decode("utf-8") if decode else b""
        res = b""
        faults = 0
        while True:
            try:
                if not self._pending():
                    rlist, _, xlist = select.select([self.socket], [], [self.socket], self.timeout)
                    if xlist:
                        raise Exception("Socket has been closed")
                    if not rlist:
                        raise Exception("Timeout on receive data from socket")

                tmp_res = self.socket.recv(length - len(res))
                if not tmp_res:
                    # An empty read after readiness means the peer closed the stream.
                    raise Exception("Timeout on recv (read {} from {})".format(len(res), length))
                res += tmp_res
                if len(res) < length:
                    continue
                return res.decode("utf-8") if decode else res
            except Exception as e:
                if self.verbose:
                    logging.exception("Exception on recv: {}".format(e))
                faults += 1
                if faults > self.max_faults:
                    raise e

    def sendFull(self, message):
        """Write every byte of ``message`` (bytes-like) to the socket."""
        self.socket.sendall(message)

    def sendMessage(self, message):
        """Send ``message`` (bytes) framed as ``<hex length>\\r\\n<payload>``."""
        header = "{:x}\r\n".format(len(message)).encode("ascii")
        self.sendFull(header)
        self.sendFull(message)
        if self.verbose:
            logging.debug("Send message size: {}".format(len(message)))

    def recvMessage(self):
        """Receive one framed message and return its payload as bytes.

        Reads hexadecimal length digits up to CR, discards the byte after CR
        (expected LF), then reads that many payload bytes. A zero-length
        message yields ``b''``.

        :raises ProtoTransportError: if an empty read is seen in the header.
        :raises ValueError: if the header is not a valid hexadecimal number.
        """
        size = b''
        while True:
            # Raw bytes: a decoded str would never compare equal to b'\r' on Python 3.
            symbol = self.recv(1, False)
            if not symbol:
                raise ProtoTransportError('Backend closed connection')
            if symbol == b'\r':
                self.recv(1, False)  # consume the LF that ends the header
                break
            size += symbol
        sizeInt = int(b'0x' + size, 0)
        if self.verbose:
            logging.debug("Got message. Expecting {0} bytes length.".format(sizeInt))
        if sizeInt <= 0:
            return b''
        # recv() returns exactly sizeInt bytes or raises.
        return self.recv(sizeInt, False)

    def sendProtobuf(self, protobuf):
        """Serialize ``protobuf`` and send it as one framed message."""
        self.sendMessage(protobuf.SerializeToString())

    def recvProtobuf(self, protobufType):
        """Receive one framed message and parse it into a new ``protobufType()`` instance."""
        response = protobufType()
        message = self.recvMessage()

        response.ParseFromString(message)

        return response

    def recvProtobufIfAny(self, protobuf):
        """Return a parsed ``protobuf`` message if data is ready right now, else ``None``.

        Once data is available, the full message is read (blocking as needed).
        """
        if self._pending():
            return self.recvProtobuf(protobuf)
        rlist, _, _ = select.select([self.socket], [], [self.socket], 0)
        if rlist:
            return self.recvProtobuf(protobuf)
        return None

    def close(self):
        """Close the underlying socket."""
        if self.verbose:
            logging.debug('Close socket: {}'.format(str(self.socket)))
        self.socket.close()

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
