import gevent.socket
import gevent.server
import gevent.event

import os
from collections import defaultdict

from socket import AF_INET, AF_INET6, SO_REUSEADDR, SOL_SOCKET, IPPROTO_IPV6, IPV6_V6ONLY, SHUT_RDWR, SOCK_STREAM, error as SocketError, has_ipv6

try:
    from socket import AF_UNIX
except ImportError:
    AF_UNIX = object()

from .. import logging

from .getifaddrs import getIfAddrs, Flags
from .misc import packIP, getFreePort
from ..functional import cached

import six


# Public API of this module (the names exported by ``from ... import *``).
# NOTE(review): splitAddrs is public (no leading underscore) but not listed here — confirm intended.
__all__ = [
    'getValidAddrs',
    'getInvalidAddrs',
    'SocketServer',
]


# Module-wide logger; SocketServer also exposes it as its ``log`` class attribute.
log = logging.getLogger('ya.skynet.util.net.socket_server')


def _isValid(address, families):
    """
    Check whether an interface address is a candidate for listening.

    An address qualifies when its ``flags`` have ``Flags.IFF_UP`` set, its
    ``family`` is one of ``families`` and both ``addr`` and ``netmask`` are
    non-empty.

    :return: ``False`` (or the empty ``addr``/``netmask``) when the address
             does not qualify, otherwise its ``netmask``; the caller in this
             module only tests the result for truthiness
    """
    interfaceUp = (address.flags & Flags.IFF_UP) > 0
    if not interfaceUp or address.family not in families:
        return False
    return address.addr and address.netmask


def splitAddrs(families, ifAddrs):
    """
    Partition interface addresses into those to listen on and those to ignore.

    Addresses are examined per interface. An address is rejected when it
    fails :func:`_isValid` (interface down, family not requested, missing
    ``addr``/``netmask``) or when it lies inside a strictly narrower subnet
    of another address of the same interface and family
    (e.g. 192.168.0.142/25 next to 192.168.0.1/24 is rejected).

    :param families: address families to consider
    :param ifAddrs: mapping of interface name -> sequence of address records
                    exposing ``family``, ``addr``, ``netmask`` and ``flags``
    :return: ``(validAddrs, invalidAddrs)``; addresses whose family is not in
             ``families`` appear in neither list
    """
    validAddrs = []
    invalidAddrs = []
    # six.iteritems works on Python 2 and 3; dict.iteritems() does not exist on Python 3.
    for _ifName, addresses in six.iteritems(ifAddrs):
        notUsed = set()  # indices (into ``addresses``) of rejected addresses

        for i, addressI in enumerate(addresses):
            if i in notUsed:
                continue

            if not _isValid(addressI, families):
                notUsed.add(i)
                continue

            family = addressI.family
            netmaskI = packIP(addressI.netmask, family)
            # 192.168.0.1 & 255.255.255.0 -> 192.168.0.0
            maskedIpI = packIP(addressI.addr, family) & netmaskI

            for j, addressJ in enumerate(addresses):
                if (
                    i == j
                    or j in notUsed
                    or addressJ.family != family
                    or not addressJ.addr
                    or not addressJ.netmask
                ):
                    continue

                ipJ = packIP(addressJ.addr, family)
                netmaskJ = packIP(addressJ.netmask, family)
                # 192.168.0.142 & 255.255.255.128 -> 192.168.0.128
                maskedIpJ = ipJ & netmaskJ
                # 192.168.0.142 & 255.255.255.0 -> 192.168.0.0
                maskedIpJI = ipJ & netmaskI
                # Subnet J is part of subnet I and strictly narrower. Netmasks are
                # compared numerically: their text form does not sort by prefix
                # length for IPv6 ('ffff:ffff:ffff:8000::' (/49) < 'ffff:ffff:ffff::' (/48)).
                if maskedIpJ >= maskedIpI == maskedIpJI and netmaskJ > netmaskI:
                    notUsed.add(j)

        for i, addressI in enumerate(addresses):
            if addressI.family not in families:
                continue

            if i in notUsed:
                invalidAddrs.append(addressI)
            else:
                validAddrs.append(addressI)
    return validAddrs, invalidAddrs


def getValidAddrs(families, ifAddrs):
    """
    :return: the addresses of ``ifAddrs`` that :func:`splitAddrs` accepts
             for the given ``families``
    """
    valid, _invalid = splitAddrs(families, ifAddrs)
    return valid


def getInvalidAddrs(families, ifAddrs):
    """
    :return: the addresses of ``ifAddrs`` that :func:`splitAddrs` rejects
             for the given ``families``
    """
    _valid, invalid = splitAddrs(families, ifAddrs)
    return invalid


class SocketServer(object):
    """
    Simple gevent driven socket server.

    Listens either on a TCP port (on the IPv4 wildcard address and, when
    Python reports IPv6 support, the IPv6 wildcard address) or on an AF_UNIX
    socket path, and passes every accepted connection to
    ``handler(socket, address)``. TCP connections accepted on a local address
    listed by :meth:`invalidAddresses` are closed without calling the handler.
    """

    log = log

    def __init__(self, addressSpec, handler, backlog=1024):
        """
        :param addressSpec: TCP port (``0`` means "pick a free port on start")
                            or path of an AF_UNIX socket
        :param handler: callback invoked as ``handler(socket, address)`` for
                        every accepted connection
        :param backlog: backlog passed to ``listen()``
        """

        self.__addressSpec = addressSpec
        self.__handler = handler
        self.__backlog = backlog
        self.__clear()

    def __str__(self):
        return '{0} [{1}]'.format(self.__class__.__name__, self.__addressSpec)

    def _unmangle(self, address):
        """
        Convert an IPv4-mapped IPv6 address tuple (``('::ffff:1.2.3.4', port, ...)``)
        into a plain ``(ipv4, port)`` tuple; any other address is returned unchanged.
        """
        if isinstance(address, tuple) and address[0].startswith('::ffff:'):  # IPv4 over IPv6
            address = (address[0][7:], address[1])
        return address

    def _handler(self, sock, address):
        """
        Connection callback registered with every StreamServer.

        Closes TCP connections accepted on a denied local address and forwards
        everything else to the user supplied handler.
        """
        address = self._unmangle(address)

        # Interface filtering only applies to TCP: an AF_UNIX socket name is a
        # path, not an (address, port) tuple.
        if not self.isUnixDomain:
            ownAddress = self._unmangle(sock.getsockname())
            if ownAddress[0] in self.invalidAddresses():
                self.log.debug("Incoming connection to denied interface: {0}:{1} -> {2}:{3}".format(address[0], address[1], ownAddress[0], ownAddress[1]))
                sock.close()
                return

        self.log.debug('Incoming connection from: {0}'.format(address))
        self.__handler(sock, address)

    def __clear(self):
        """Reset runtime state to "not started" (does not close anything)."""
        self.__sockets = []  # (socket, bind address) pairs
        self.__streamServers = []
        # NOTE(review): not referenced anywhere in this class; kept for compatibility.
        self.__invalidAddrs = []
        self.__running = False

    def __release(self):
        """
        Best-effort teardown: stop every stream server, then shut down and
        close every socket. A failure on one resource is logged or ignored so
        that the remaining ones are still released.
        """
        for streamServer in self.__streamServers:
            try:
                streamServer.stop()
            except Exception as err:
                self.log.error('Failed to stop {0}: {1}'.format(streamServer, err))

        for sock, _address in self.__sockets:
            # shutdown() may fail (e.g. the socket is not connected or was
            # already closed); that must not prevent close() from running.
            try:
                sock.shutdown(SHUT_RDWR)
            except EnvironmentError:
                pass
            try:
                sock.close()
            except EnvironmentError:
                pass

    @property
    def running(self):
        """
        :return: Whether socket server running or not
        """
        return self.__running

    @property
    def addressSpec(self):
        """
        :return: Address specification which could be port or path to AF_UNIX socket
        :rtype: int or string
        """
        return self.__addressSpec

    @cached(60)
    def bindedAddresses(self):
        """
        :return: mapping of address family -> set of local addresses the
                 server listens on: the socket path for AF_UNIX, otherwise the
                 addresses accepted by :func:`getValidAddrs`
        """
        if self.isUnixDomain:
            return {AF_UNIX: set([self.__addressSpec])}

        addrs = getValidAddrs(self.families, getIfAddrs())
        result = defaultdict(set)
        for addr in addrs:
            result[addr.family].add(addr.addr)
        return result

    @cached(60)
    def invalidAddresses(self):
        """
        :return: set of local addresses whose incoming connections are denied;
                 an address also accepted on another interface is not denied.
                 Always empty for AF_UNIX.
        """
        if self.isUnixDomain:
            return set()

        valid, invalid = splitAddrs(self.families, getIfAddrs())
        validAddrs = set(addr.addr for addr in valid)
        return set(addr.addr for addr in invalid if addr.addr not in validAddrs)

    @property
    def isUnixDomain(self):
        """:return: True when addressSpec is a path (string), i.e. an AF_UNIX socket"""
        return isinstance(self.__addressSpec, six.string_types)

    @property
    def families(self):
        """:return: TCP address families to serve: AF_INET, plus AF_INET6 when socket.has_ipv6"""
        return [AF_INET] + ([AF_INET6] if has_ipv6 else [])

    def start(self):
        """
        Start listening.

        Does nothing when already running. If any step fails, every server and
        socket created so far is released before the exception is re-raised.
        """
        if self.__running:
            return
        self.__running = True
        try:
            for addr in self.invalidAddresses():
                self.log.debug('Address `[{0}]:{1}` is ignored'.format(addr, self.__addressSpec))

            self._createSockets()
            self._createServers()
            for streamServer in self.__streamServers:
                streamServer.start()
        except Exception:
            # Without this, servers that already started would keep accepting
            # connections and the created sockets would stay open.
            self.__release()
            self.__clear()
            raise
        self.log.info('Started {0}'.format(self))

    def stop(self):
        """Stop listening and release all sockets; does nothing when not running."""
        if not self.__running:
            return
        self.__release()

        self.log.info('Stopped {0}'.format(self))

        self.__clear()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *_):
        self.stop()

    def sockets(self):
        """:return: iterator over the sockets created by :meth:`start`"""
        return (socketSpec[0] for socketSpec in self.__sockets)

    def __addSocket(self, family, address):
        """
        Create a socket of ``family``, record it with the address it will be
        bound to, enable SO_REUSEADDR and return it.
        """
        sock = gevent.socket.socket(family)
        # Record the socket before configuring it so that start() can close it
        # if anything afterwards fails.
        self.__sockets.append((sock, address))
        sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
        return sock

    def _createSockets(self):
        """Create (but do not bind) the listening sockets for addressSpec."""
        families = self.families

        if self.__addressSpec == 0:
            self.__addressSpec = getFreePort(families)

        if os.name != 'posix':
            # windows
            for af in families:
                self.__addSocket(af, ('', self.__addressSpec))
            return

        # posix UDS
        if self.isUnixDomain:
            self.__addSocket(AF_UNIX, self.__addressSpec)
            return

        # posix TCP: addresses are resolved before the socket is created so a
        # resolution failure cannot leave an unrecorded socket behind.
        if has_ipv6:
            address = gevent.socket.getaddrinfo('::', self.__addressSpec, AF_INET6, SOCK_STREAM)[0][4]
            sock = self.__addSocket(AF_INET6, address)
            # IPv6 socket serves IPv6 only; IPv4 gets its own socket below.
            sock.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1)

        address = gevent.socket.getaddrinfo('0.0.0.0', self.__addressSpec, AF_INET, SOCK_STREAM)[0][4]
        self.__addSocket(AF_INET, address)

    def _createServers(self):
        """Bind and listen on every created socket and wrap each in a StreamServer."""
        for sock, address in self.__sockets:
            try:
                self.log.debug('Binding to address: {0}'.format(address))
                if isinstance(address, six.string_types):
                    # AF_UNIX: remove whatever exists at the path (presumably a
                    # stale socket from a previous run) so bind() can succeed.
                    # NOTE(review): this unlinks any file type, not only sockets.
                    try:
                        os.unlink(address)
                    except EnvironmentError:
                        pass
                sock.bind(address)
            except SocketError as err:
                self.log.error(str(err))
                raise
            sock.listen(self.__backlog)
            self.__streamServers.append(gevent.server.StreamServer(sock, self._handler))
