import bisect
import collections
import operator
import socket
import struct


class IPNetwork(object):
    """An IPv4 or IPv6 network in CIDR notation, held as plain integers.

    ``first`` and ``last`` are the inclusive integer bounds of the addresses
    the network covers; host bits set in the given address are masked off
    by ``first``.
    """

    __slots__ = (
        "_network_int",
        "_prefix_int",
        "_version",
        "_hostmask_int"
    )

    IPV4_WIDTH = 32
    IPV4_MAX_INT = 2 ** IPV4_WIDTH - 1

    IPV6_WIDTH = 128
    IPV6_MAX_INT = 2 ** IPV6_WIDTH - 1

    def __init__(self, network):
        """Parse *network*, e.g. "10.0.0.0/8", "2001:db8::/32" or a bare host address.

        Raises:
            socket.error: the address is neither valid IPv6 nor valid IPv4.
            ValueError: the prefix is not an integer in ``[0, width]``.
        """
        self._network_int, self._prefix_int, self._version = self._parse_network(network)
        # All-ones over the host bits, e.g. 0xFF for an IPv4 /24.
        self._hostmask_int = (1 << (self._width - self._prefix_int)) - 1

    @property
    def _width(self):
        """Address width in bits: 128 for IPv6, 32 for IPv4."""
        return self.IPV6_WIDTH if self._version == 6 else self.IPV4_WIDTH

    @property
    def _max_int(self):
        """Largest integer representable in this network's address width."""
        return self.IPV6_MAX_INT if self._version == 6 else self.IPV4_MAX_INT

    @staticmethod
    def _ipv4_to_int(addr):
        """Convert an IPv4 string to an int; raises socket.error if invalid."""
        return struct.unpack("!I", socket.inet_pton(socket.AF_INET, addr))[0]

    @staticmethod
    def _ipv6_to_int(addr):
        """Convert an IPv6 string to a 128-bit int; raises socket.error if invalid."""
        hi, lo = struct.unpack('!QQ', socket.inet_pton(socket.AF_INET6, addr))
        return (hi << 64) | lo

    def _parse_network(self, network):
        """Return ``(network_int, prefix_len, version)`` for *network*.

        The address is parsed as IPv6 first and as IPv4 on failure. A missing
        "/prefix" means a single-host network (/32 or /128).

        Raises:
            socket.error: the address is valid in neither family.
            ValueError: the prefix is not an integer in ``[0, width]``.
        """
        addr, sep, prefix = network.partition("/")
        try:
            addr_int, version, width = self._ipv6_to_int(addr), 6, self.IPV6_WIDTH
        except socket.error:
            addr_int, version, width = self._ipv4_to_int(addr), 4, self.IPV4_WIDTH
        prefix_int = int(prefix) if sep else width
        # Out-of-range prefixes would otherwise produce a negative shift
        # (prefix > width) or a host mask wider than the address (prefix < 0).
        if not 0 <= prefix_int <= width:
            raise ValueError(
                "invalid prefix length %r for IPv%d network %r" % (prefix, version, network))
        return addr_int, prefix_int, version

    @classmethod
    def addr_to_int(cls, addr):
        """Convert an IPv6 or IPv4 address string to an int (IPv6 is tried first).

        Raises socket.error if *addr* is valid in neither family.
        """
        try:
            return cls._ipv6_to_int(addr)
        except socket.error:
            return cls._ipv4_to_int(addr)

    @classmethod
    def extract_mtn_project(cls, addr):
        """Return bits 32..47 of *addr*'s integer form as lowercase hex without "0x".

        For IPv4 addresses those bits are always zero, so the result is "0".
        """
        # int() narrows a Python 2 long so hex() emits no trailing "L".
        return hex(int((cls.addr_to_int(addr) & 0xFFFF00000000) >> 32))[2:]

    @classmethod
    def get_addr_family(cls, addr):
        """Return socket.AF_INET6 if *addr* parses as IPv6, else socket.AF_INET.

        Note: the IPv4 branch does not validate *addr*.
        """
        try:
            cls._ipv6_to_int(addr)
        except socket.error:
            return socket.AF_INET
        else:
            return socket.AF_INET6

    @property
    def first(self):
        """Lowest address (as int) in the network, i.e. the address with host bits cleared."""
        return self._network_int & (self._max_int ^ self._hostmask_int)

    @property
    def last(self):
        """Highest address (as int) in the network, i.e. the address with host bits set."""
        return self._network_int | self._hostmask_int

    def to_range(self):
        """Return the inclusive ``(first, last)`` integer bounds of the network."""
        return (self.first, self.last)

    @property
    def version(self):
        """IP version of the network: 4 or 6."""
        return self._version


class NetworkSet(object):
    """Map IP addresses to values attached to the networks containing them.

    Built from an iterable of ``(network, value)`` pairs, where *network* is a
    CIDR string accepted by ``IPNetwork``. A lookup returns the value of the
    most specific (narrowest) network containing the address, or ``None``.

    The table has two levels. The root holds disjoint ranges sorted by start,
    each the union of a run of overlapping networks, and is searched with
    bisect. Each root range owns a child table of the original networks it
    covers, which is scanned linearly from the end.

    NOTE(review): IPv4 and IPv6 addresses are both reduced to plain ints and
    share one key space, so mixing families in one set can yield
    cross-family matches (e.g. "::/0" covers every IPv4 integer).
    """

    Entry = collections.namedtuple("Entry", ("range_table", "value_table"))
    Range = collections.namedtuple("Range", ("start", "end"))

    def __init__(self, network_list):
        self._root = self._build_root(network_list)

    @staticmethod
    def _sort_key(item):
        """Sort key for ``(Range, value)`` pairs: start ascending, then end descending.

        Putting the wider of two same-start networks first means the reverse
        scan in ``_scan_position`` reaches the narrower one first.
        """
        addr_range = item[0]
        return (addr_range.start, -addr_range.end)

    @classmethod
    def _build_root(cls, network_list):
        """Build the two-level lookup table from ``(network, value)`` pairs."""
        root = cls.Entry(range_table=[], value_table=[])
        # sorted() is stable, so duplicate networks keep their input order and
        # the last-listed duplicate wins on lookup.
        networks = sorted(
            ((cls.Range(*IPNetwork(network).to_range()), value)
             for network, value in network_list),
            key=cls._sort_key,
        )
        for next_range, next_value in networks:
            if not root.range_table or root.range_table[-1].end < next_range.start:
                # Disjoint from every range so far (or first one): open a new top-level range.
                root.range_table.append(next_range)
                root.value_table.append(cls.Entry(range_table=[], value_table=[]))
            else:
                last_range = root.range_table[-1]
                assert last_range.start <= next_range.start, "list not sorted"
                # Overlaps the current top-level range: widen it to cover both.
                root.range_table[-1] = cls.Range(
                    last_range.start, max(last_range.end, next_range.end))
            # Second level: remember the original network under its top-level range.
            child = root.value_table[-1]
            child.range_table.append(next_range)
            child.value_table.append(next_value)
        return root

    @staticmethod
    def _belong_to_range(addr_int, addr_range):
        """Return True if *addr_int* lies within the inclusive *addr_range*."""
        return addr_range.start <= addr_int <= addr_range.end

    @classmethod
    def _find_position(cls, addr_int, range_table):
        """Binary-search disjoint, start-sorted *range_table* for addr_int; index or None."""
        # A Range compares greater than the 1-tuple (addr_int,) when its start
        # is >= addr_int, so pos is the first range starting at or after addr_int.
        pos = bisect.bisect_left(range_table, (addr_int, ))
        if pos > 0:
            if cls._belong_to_range(addr_int, range_table[pos - 1]):
                return pos - 1
        if pos < len(range_table):
            if cls._belong_to_range(addr_int, range_table[pos]):
                return pos
        return None

    @classmethod
    def _scan_position(cls, addr_int, range_table):
        """Return the index of the LAST range in *range_table* containing addr_int, or None."""
        # range (not xrange) keeps this working on both Python 2 and 3.
        for pos in range(len(range_table) - 1, -1, -1):
            if cls._belong_to_range(addr_int, range_table[pos]):
                return pos
        return None

    @classmethod
    def _get_value(cls, addr_int, entry, scan=False):
        """Return the value paired with the range in *entry* containing addr_int, or None.

        With ``scan=True`` the table is scanned linearly from the end (child
        tables, whose ranges may overlap); otherwise it is binary-searched
        (the root table, whose ranges are disjoint).
        """
        if scan:
            position = cls._scan_position(addr_int, entry.range_table)
        else:
            position = cls._find_position(addr_int, entry.range_table)
        return entry.value_table[position] if position is not None else None

    def get(self, addr):
        """Return the value for the most specific network containing *addr*, or None.

        Raises socket.error if *addr* is not a valid IPv6 or IPv4 address.
        """
        addr_int = IPNetwork.addr_to_int(addr)
        return self.get_by_int(addr_int)

    def get_by_int(self, addr_int):
        """Like ``get`` but takes the address already converted to an int."""
        child_entry = self._get_value(addr_int, self._root)
        return self._get_value(addr_int, child_entry, scan=True) if child_entry is not None else None
