import errno
import fcntl
import gevent
import logging
import msgpack
import os
import socket
import traceback


class UDPServer(object):
    """Serve msgpack-encoded datagrams on one UDP port over IPv4 and/or IPv6.

    Each received datagram is decoded with ``msgpack.loads`` and passed to
    ``callback(data, peer_ip, peer_port)``. A non-``None`` return value is
    encoded with a ``msgpack.Packer`` and sent back to the sender. Every
    socket is served by its own gevent greenlet (see ``worker``).

    If ``bind_v4`` is empty, only an IPv6 socket is created and
    ``IPV6_V6ONLY`` is cleared on it, so it can also accept IPv4 traffic where
    the OS supports dual-stack sockets. Otherwise a dedicated IPv4 socket is
    created and the IPv6 socket is made v6-only, so the two sockets do not
    contend for the same port.
    """

    def __init__(self, port, callback, bind_v4=None, bind_v6=None):
        self.log = logging.getLogger('udp')
        self.port = port
        self.bind_v4 = bind_v4 or ''
        self.bind_v6 = bind_v6 or ''
        self.callback = callback
        # [packets in, packets out, KiB in, KiB out]; same order as the
        # entries returned by get_stats(describe=True).
        self._stats = [0, 0, 0, 0]

        self._msgpacker = msgpack.Packer()

        # Created by start(); initialised here so stop() is safe before start().
        self.sock4 = self.sock6 = None
        self.worker_grn4 = self.worker_grn6 = None

    def start(self):
        """Create and bind the sockets, then spawn one worker greenlet per socket.

        Returns ``self``. If socket creation, configuration or binding fails,
        every socket created so far is closed before the error is re-raised.
        """
        self.sock4 = self.sock6 = None
        try:
            if self.bind_v4:
                self.sock4 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            self.sock6 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)

            # Without a dedicated v4 socket the v6 socket must also accept
            # v4-mapped traffic. With one, it must be v6-only, otherwise binding
            # the v6 wildcard collides with the v4 socket on the same port.
            self.sock6.setsockopt(
                socket.IPPROTO_IPV6, socket.IPV6_V6ONLY,
                1 if self.sock4 is not None else 0,
            )

            for sock in (self.sock4, self.sock6):
                if sock is None:
                    continue

                sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 4 * 1024 * 1024)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 16 * 1024 * 1024)

                # Non-blocking, so the worker can wait for readability through
                # gevent instead of blocking the process inside recvfrom().
                fd = sock.fileno()
                fcntl.fcntl(fd, fcntl.F_SETFL, fcntl.fcntl(fd, fcntl.F_GETFL) | os.O_NONBLOCK)

            if self.sock4 is not None:
                self.sock4.bind((self.bind_v4, self.port))
                self.log.info('Binded (v4) on %s:%d (udp)', self.bind_v4, self.port)

            self.sock6.bind((self.bind_v6, self.port))
            self.log.info('Binded (v6) on %s:%d (udp)', self.bind_v6 or '::', self.port)
        except Exception:
            self._close_sockets()
            raise

        if self.sock4 is not None:
            self.worker_grn4 = gevent.spawn(self.worker, self.sock4)
        self.worker_grn6 = gevent.spawn(self.worker, self.sock6)

        return self

    def stop(self):
        """Kill the worker greenlets, then close and forget the sockets.

        Safe to call before ``start()`` and safe to call more than once.
        Returns ``self``.
        """
        for grn in (self.worker_grn4, self.worker_grn6):
            if grn is not None:
                grn.kill()
                grn.join()
        self.worker_grn4 = self.worker_grn6 = None

        self._close_sockets()
        return self

    def join(self):
        """Return ``self`` immediately; this does not wait for the workers."""
        return self

    def _close_sockets(self):
        """Close whichever of ``sock4``/``sock6`` exist and reset them to ``None``."""
        for sock in (self.sock4, self.sock6):
            if sock is not None:
                sock.close()
        self.sock4 = self.sock6 = None

    @staticmethod
    def _recvfrom(sock, bufsize):
        """Return ``(data, address)`` for one datagram read from *sock*.

        For a non-blocking socket, ``EAGAIN``/``EWOULDBLOCK`` parks the calling
        greenlet in ``gevent.socket.wait_read`` and the read is retried, so a
        spurious wake-up is not reported as an error. Any other socket error
        propagates.
        """
        while True:
            try:
                return sock.recvfrom(bufsize)
            except socket.error as ex:
                if ex.errno not in (errno.EAGAIN, errno.EWOULDBLOCK):
                    raise
            # Imported explicitly: ``import gevent`` alone does not load the
            # ``gevent.socket`` submodule.
            from gevent.socket import wait_read
            wait_read(sock.fileno())

    def worker(self, sock):
        """Serve datagrams from *sock* forever; intended to run in its own greenlet.

        Receive/decode failures and callback/send failures are logged and
        followed by a 0.1 s pause, so a persistent fault cannot spin the loop.
        """
        cb = self.callback
        stats = self._stats
        sleep = gevent.sleep

        recvfrom = self._recvfrom
        sendto = sock.sendto
        unpacker = msgpack.loads
        packer = self._msgpacker.pack

        idx = 0

        while True:

            try:
                # Datagrams longer than 2048 bytes are truncated by recvfrom()
                # and will then most likely fail to decode below.
                data, peer = recvfrom(sock, 2048)
                datalen = len(data)
                data = unpacker(data)
            except Exception:
                self.log.warning('Unhandled exception: %s', traceback.format_exc())
                sleep(0.1)  # avoid busy loops
                continue

            try:
                peer_ip, peer_port = peer[:2]
                stats[0] += 1
                stats[2] += datalen / 1024.
                reply = cb(data, peer_ip, peer_port)
                if reply is not None:
                    reply = packer(reply)
                    # Reply to the full source address: for IPv6 it also
                    # carries flowinfo/scope_id, which link-local peers need.
                    sendto(reply, peer)
                    stats[1] += 1
                    stats[3] += len(reply) / 1024.

                # Yield to other greenlets periodically, since a socket that
                # never runs dry would otherwise never wait in recvfrom.
                if idx == 10000:
                    sleep()
                    idx = 0
                else:
                    idx += 1

            except Exception:
                self.log.warning('Unhandled exception: %s', traceback.format_exc())
                sleep(0.1)  # avoid busy loops

    def get_stats(self, describe=False):
        """Return the live counter list, or a description of it when *describe* is true.

        The counters are ``[packets in, packets out, KiB in, KiB out]``. The
        description holds one ``[name, kind, unit, [1, 60, 300, 3600]]`` entry
        per counter, in the same order. Note that the unit label says
        ``Kb/s``, while the byte counters accumulate ``bytes / 1024``.
        """
        if describe:
            return (
                ['udp in', 'counter', 'pkt/s', [1, 60, 300, 3600]],
                ['udp ou', 'counter', 'pkt/s', [1, 60, 300, 3600]],
                ['udp in (speed)', 'counter', 'Kb/s', [1, 60, 300, 3600]],
                ['udp ou (speed)', 'counter', 'Kb/s', [1, 60, 300, 3600]],
            )

        return self._stats
