import errno
import socket
import unittest

import six

from ..streams.streambase import StreamBase


# Names exported by a star-import of this module; the test case below is
# intentionally not listed.
__all__ = [
    "SocketStream"
]


class SocketStream(StreamBase):
    """Stream adapter that performs its I/O on a socket object.

    ``read`` keeps receiving until exactly the requested number of bytes has
    been collected, and ``write`` hands the whole buffer to ``sendall``.
    ``recv`` and ``send`` are class-level aliases of ``read`` and ``write``.
    """

    def __init__(self, sock):
        """Store *sock* and initialise the ``StreamBase`` part of the object.

        The instance passes itself as the argument to ``StreamBase.__init__``
        -- presumably so the base class routes its I/O through the
        ``read``/``write`` methods defined here (confirm against StreamBase).
        """
        self.sock = sock
        StreamBase.__init__(self, self)

    def read(self, sz):
        """Return exactly *sz* bytes received from the socket.

        Data is requested in pieces of at most 8192 bytes until the total
        reaches *sz*.  An empty result from ``recv`` means the peer closed
        the connection before everything arrived, which is reported as
        ``socket.error`` with ``errno.ECONNRESET``.
        """
        max_chunk = 8192
        pieces = []
        remaining = sz
        while remaining > 0:
            piece = self.sock.recv(min(remaining, max_chunk))
            if not piece:
                raise socket.error(errno.ECONNRESET, "Connection reset by peer")
            pieces.append(piece)
            remaining -= len(piece)
        return b"".join(pieces)

    recv = read

    def write(self, data):
        """Send all of *data* over the socket using ``sendall``."""
        self.sock.sendall(data)

    send = write


class TestSocketStream(unittest.TestCase):
    """Round-trip test for ``SocketStream`` over a local socket server."""

    def testGevent(self):
        """Write several objects through one SocketStream and read them back.

        A sender greenlet connects to a ``SocketServer`` listening on a free
        local port and writes every item of ``data`` with ``writeObj``.  The
        server handler reads the same number of objects with ``readObj`` and
        compares them with what was sent.

        Failures raised inside the handler or the sender greenlet are
        propagated to the test itself, so a broken round trip fails the test
        instead of blocking forever on the completion event.
        """
        import sys

        import gevent
        import gevent.socket as socket
        import gevent.event

        from kernel.util.net.misc import getFreePort
        from kernel.util.net.socketserver import SocketServer

        port = getFreePort()

        data = [
            12930123980172381238374902183103401389103191039812031892038279043283904249824,
            "asjdlkajdlkad" * 107,
            "asjjdlkad" * 908,
            (123, 8979)
        ]

        event = gevent.event.Event()
        # exc_info triples captured inside the server handler; they are
        # re-raised in the test's own greenlet once the handler has finished.
        failures = []

        def sendData():
            sendSocket = socket.socket()
            try:
                sendSocket.connect(("127.0.0.1", port))
                stream = SocketStream(sendSocket)
                for d in data:
                    stream.writeObj(d)
            finally:
                # Always release the client socket, even if a write fails.
                sendSocket.close()

        # The parameter names are kept as in the original handler signature;
        # ``socket`` here is the accepted connection, not the gevent module.
        def receiveData(socket, address):
            try:
                stream = SocketStream(socket)
                localData = [stream.readObj() for _ in data]
                self.assertEqual(data, localData)
            except Exception:
                # Remember the failure for the main greenlet instead of
                # letting it disappear inside the server.
                failures.append(sys.exc_info())
            finally:
                # Wake the waiting test no matter how the handler ended.
                event.set()

        ss = SocketServer(port, receiveData)
        try:
            ss.start()
            let = gevent.Greenlet(sendData)
            let.start()
            # get() waits like join() but re-raises an exception from the
            # sender, so a failed connect/write is reported immediately.
            let.get()
            event.wait()
        finally:
            try:
                ss.stop()
            except Exception:
                # Best-effort shutdown: a failure to stop the server must not
                # mask the actual test outcome.
                pass

        if failures:
            six.reraise(*failures[0])


# Run the test case defined in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
