import logging
from collections import defaultdict, namedtuple
from ipaddress import ip_address

from pygtrie import Trie
from radix import Radix

log = logging.getLogger(__name__)

# Position of the project id inside a 128-bit address; bit 0 is the most significant bit.
PROJECT_ID_BIT_LENGTH = 32  # up to 8 hex digits
PROJECT_ID_BIT_START = 64  # occupies bits 64 through 95 inclusive
BITS_AFTER_PROJECT_ID = 128 - PROJECT_ID_BIT_START - PROJECT_ID_BIT_LENGTH
# PROJECT_ID_BIT_LENGTH one-bits moved left past the bits that follow the project id.
# The shift must be BITS_AFTER_PROJECT_ID (not the id length): the two only coincide for the current layout.
PROJECT_ID_MASK = ((1 << PROJECT_ID_BIT_LENGTH) - 1) << BITS_AFTER_PROJECT_ID


def extract_project_id(ipaddress):
    """Return the project id embedded in an address as 8 lowercase hex digits.

    ``ipaddress`` may be any value accepted by ``int()``; the bits selected by ``PROJECT_ID_MASK`` are
    isolated and shifted down by ``BITS_AFTER_PROJECT_ID``.
    """
    masked_bits = int(ipaddress) & PROJECT_ID_MASK
    numeric_id = masked_bits >> BITS_AFTER_PROJECT_ID
    # width 8 = 32 bits / 4 bits per hex digit, zero-filled on the left
    return format(numeric_id, "08x")


ParsedNetwork = namedtuple("ParsedNetwork", ["network", "project_id", "range_prefixlen"])


def parse_network(network_expr):
    """Split a network expression into network, project id and range prefix length.

    Recognised forms:

    * ``network`` - usual network, both extra fields are ``None``;
    * ``project_id@network`` - trypo network, ``range_prefixlen`` is ``None``;
    * ``project_id/prefixbitlen@network`` - trypo range network.

    Only the first ``@`` and the first ``/`` in front of it act as separators. Every field is returned as
    an unmodified substring of ``network_expr`` (the prefix length stays a string).
    """
    trypo_part, at_sign, network_part = network_expr.partition("@")
    if not at_sign:
        # no "@" at all: the whole expression is the network
        return ParsedNetwork(network_expr, None, None)

    id_part, slash, bits_part = trypo_part.partition("/")
    if not slash:
        return ParsedNetwork(network_part, trypo_part, None)
    return ParsedNetwork(network_part, id_part, bits_part)


class TRYPOCompatibleRadix:
    """
    Radix tree of networks addresses, supporting matching against TRYPO networks (with embedded project id) and their
    ranges.

    Expressions accepted by ``add``:

    * ``network`` - plain network, kept in one shared radix tree;
    * ``project_id@network`` - kept in a radix tree dedicated to that (zero-padded) project id;
    * ``project_id/prefixbitlen@network`` - project id range, kept in a radix tree stored in a trie under the
      hex prefix of the range.

    Version without trypo network range support taken from here:
    https://a.yandex-team.ru/arc/trunk/arcadia/passport/python/core/grants/trypo_compatible_radix.py
    """

    def __init__(self):
        # plain networks (v4 and v6)
        self._radix = Radix()
        # zero-padded project id -> radix tree of that project's networks
        self._project_id_to_radix = defaultdict(Radix)
        # hex prefix of a project id range -> radix tree of the networks of that range
        self._project_id_range_to_radix = Trie()

    def add(self, network):
        """Register a network expression and return the radix node created for it.

        The expression is classified with ``parse_network``; every branch returns the node produced by the
        underlying radix tree.
        """
        parsed_network = parse_network(network)
        if parsed_network.range_prefixlen is not None:
            # The node used to be dropped here, so range networks returned None unlike the other two kinds.
            return self._add_trypo_range_network(parsed_network)
        if parsed_network.project_id is not None:
            return self._add_trypo_network(parsed_network)
        return self._radix.add(network)

    def search_best(self, ip):
        """Return the best matching node for ``ip``; the result is falsy (``None``) when nothing matches.

        A trypo node is preferred over a plain one when its prefix length plus ``PROJECT_ID_BIT_LENGTH``
        is at least the plain node's prefix length, i.e. the project id match is counted as extra matched bits.
        """
        # search in three places:
        # * project id range @ network (v6 only)
        # * single project id @ network (v6 only)
        # * plain network (both v4 and v6)
        address = ip_address(ip)
        classic_node = self._radix.search_best(ip)
        trypo_node = None
        if address.version == 6:
            expected_project_id = extract_project_id(address)

            trypo_node = self._search_trypo_range_networks(expected_project_id, ip)
            # Membership is tested first: indexing the defaultdict would create a Radix for every unknown id.
            if not trypo_node and expected_project_id in self._project_id_to_radix:
                trypo_node = self._project_id_to_radix[expected_project_id].search_best(ip)
        if not classic_node or (trypo_node and trypo_node.prefixlen + PROJECT_ID_BIT_LENGTH >= classic_node.prefixlen):
            return trypo_node
        return classic_node

    @staticmethod
    def node_repr(node):
        """Format a node as ``prefix``, or as ``project_id@prefix`` when the node data carries a project id."""
        as_str = node.prefix
        project_id = node.data.get("project_id")
        if project_id is not None:
            as_str = "{}@{}".format(project_id, as_str)
        return as_str

    def _add_trypo_network(self, parsed_network):
        """Add a single-project network and tag the created node with the padded project id."""
        project_id, netmask = parsed_network.project_id, parsed_network.network
        # pad with zeroes up to full (8 chars) form
        project_id = _pad_project_id(project_id)
        node = self._project_id_to_radix[project_id].add(netmask)
        node.data["project_id"] = project_id
        return node

    def _add_trypo_range_network(self, parsed_network):
        """Add a project-id-range network and tag the created node with the original ``id/bits`` text."""
        range_prefix = self._get_project_range_prefix(parsed_network)
        if range_prefix not in self._project_id_range_to_radix:
            self._project_id_range_to_radix[range_prefix] = Radix()
        node = self._project_id_range_to_radix[range_prefix].add(parsed_network.network)
        node.data["project_id"] = "{}/{}".format(parsed_network.project_id, parsed_network.range_prefixlen)
        return node

    def _search_trypo_range_networks(self, project_id, address):
        """Return the first node matching ``address`` among the ranges whose prefix starts ``project_id``."""
        # there can be multiple matching prefixes, search in each prefix radixes
        # NOTE(review): the first hit wins; pygtrie appears to yield shorter prefixes first, which would make
        # the broadest matching range take precedence - confirm this is intended.
        for prefix_step in self._project_id_range_to_radix.prefixes(project_id):
            node = prefix_step.value.search_best(address)
            if node is not None:
                return node
        return None

    @staticmethod
    def _get_project_range_prefix(parsed_network):
        """Return the hex-digit prefix that identifies the project id range of ``parsed_network``."""
        range_prefix, prefix_bit_len = parsed_network.project_id, parsed_network.range_prefixlen
        # 1000000/12 means project ids from 1000000 to 10fffff
        # here we have only 7 symbols out of 8 (32 bits / 4 per one hex symbol), so we must pad with 0 up to 8 symbols,
        # then take only prefix_bit_len / 4 symbols as prefix (010 in this example)
        range_prefix = _pad_project_id(range_prefix)
        # Integer division rounds down: a bit length that is not a multiple of 4 is cut to the previous whole
        # hex digit, which makes the stored range wider than the one written in the expression.
        range_prefix = range_prefix[: int(prefix_bit_len) // 4]
        return range_prefix


def _pad_project_id(project_id):
    return "0" * (8 - len(project_id)) + project_id
