package ru.yandex.passport.util;

import java.math.BigInteger;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.util.Collections;
import java.util.*;


/**
 * A single node of a binary (radix-2) trie.
 *
 * <p>{@link #zeroNode} and {@link #oneNode} lead to the child reached by a 0 bit and a
 * 1 bit respectively; either is {@code null} while nothing has been inserted below it in
 * that direction. A node marked via {@link #makeEnd()} terminates at least one stored
 * prefix and carries the payloads attached to that prefix.
 *
 * @param <T> payload type stored at end nodes
 */
class TrieNode<T> {
    /** Child for a 0 bit, or {@code null} if absent. */
    public TrieNode<T> zeroNode;
    /** Child for a 1 bit, or {@code null} if absent. */
    public TrieNode<T> oneNode;

    /** Whether a stored prefix ends exactly at this node. */
    private boolean terminal;
    /** Payloads attached to this node; never {@code null}. */
    private final Set<T> payload = new HashSet<>();

    /** Marks this node as the end of a stored prefix. */
    protected void makeEnd() {
        this.terminal = true;
    }

    /**
     * Reports whether this node ends a stored prefix.
     *
     * @return {@code true} once {@link #makeEnd()} has been called on this node
     */
    protected boolean isEnd() {
        return this.terminal;
    }

    /**
     * Returns the payloads attached to this node.
     *
     * @return the node's own mutable set (not a copy); empty if nothing was added
     */
    public Set<T> getData() {
        return payload;
    }

    /**
     * Attaches a payload to this node; adding an equal payload twice keeps one copy.
     *
     * @param data payload to attach
     */
    protected void addData(T data) {
        payload.add(data);
    }
}


/**
 * Binary trie over IP address bits that maps network prefixes (an address plus a prefix
 * length, as in CIDR notation) to sets of consumers of type {@code T}, and answers which
 * consumers are registered for prefixes containing a given address.
 *
 * <p>IPv4 and IPv6 prefixes are kept in separate sub-tries, one per address family. Both
 * families start at bit 0, so with a single shared trie a prefix such as
 * {@code 10.0.0.0/8} would also match IPv6 addresses whose first byte is {@code 0x0a}.
 *
 * <p>This class performs no synchronization.
 *
 * @param <T> consumer type
 */
public class Trie<T> {
    /** Root of the sub-trie holding IPv4 prefixes. */
    private final TrieNode<T> ipv4Root;
    /** Root of the sub-trie holding IPv6 prefixes. */
    private final TrieNode<T> ipv6Root;

    /** Creates an empty trie. */
    public Trie() {
        ipv4Root = new TrieNode<>();
        ipv6Root = new TrieNode<>();
    }

    /**
     * Associates {@code data} with the network prefix formed by the first {@code size} bits
     * of {@code baseAddress}. Bits after the prefix are ignored, so any address inside the
     * network may be passed. Adding several values to the same prefix accumulates them.
     *
     * @param baseAddress an address inside the network
     * @param size        prefix length in bits: 0..32 for IPv4, 0..128 for IPv6
     * @param data        consumer to associate with the prefix
     * @throws IllegalArgumentException if {@code size} is outside the valid range for the
     *                                  address family of {@code baseAddress}
     */
    public void add(InetAddress baseAddress, int size, T data) {
        byte[] octets = baseAddress.getAddress();
        int bitCount = octets.length * Byte.SIZE;
        // A negative length would silently register at the root and match every address.
        if (size < 0 || size > bitCount) {
            throw new IllegalArgumentException("Prefix length " + size
                    + " is out of range [0, " + bitCount + "] for " + baseAddress);
        }

        TrieNode<T> node = rootFor(baseAddress);
        for (int level = 0; level < size; level++) {
            if (bitAt(octets, level) == 0) {
                if (node.zeroNode == null) {
                    node.zeroNode = new TrieNode<>();
                }
                node = node.zeroNode;
            } else {
                if (node.oneNode == null) {
                    node.oneNode = new TrieNode<>();
                }
                node = node.oneNode;
            }
        }
        node.makeEnd();
        node.addData(data);
    }

    /**
     * Returns every consumer registered for a prefix that contains {@code address}. The
     * result is the union over all matching prefixes along the address's path, not only
     * the most specific (longest) match.
     *
     * @param address address to look up
     * @return a new mutable set of matching consumers; empty if nothing matches
     */
    public Set<T> getConsumers(InetAddress address) {
        Set<T> consumers = new HashSet<>();
        for (TrieNode<T> node : matchingNodes(address)) {
            consumers.addAll(node.getData());
        }
        return consumers;
    }

    /**
     * Checks whether {@code consumer} is registered for some prefix containing
     * {@code address}; equivalent to {@code getConsumers(address).contains(consumer)}.
     *
     * @param consumer consumer to look for
     * @param address  address to look up
     * @return {@code true} if a matching prefix carries {@code consumer}
     */
    public boolean checkConsumer(T consumer, InetAddress address) {
        for (TrieNode<T> node : matchingNodes(address)) {
            if (node.getData().contains(consumer)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Walks the sub-trie for {@code address}'s family along the address's bits and collects
     * every end node met on the way (root included), ordered from shortest prefix to longest.
     */
    private List<TrieNode<T>> matchingNodes(InetAddress address) {
        byte[] octets = address.getAddress();
        int bitCount = octets.length * Byte.SIZE;
        List<TrieNode<T>> matches = new ArrayList<>();

        TrieNode<T> node = rootFor(address);
        int level = 0;
        while (node != null) {
            if (node.isEnd()) {
                matches.add(node);
            }
            if (level == bitCount) {
                break; // full-length prefix reached; add() never creates deeper nodes
            }
            node = bitAt(octets, level++) == 0 ? node.zeroNode : node.oneNode;
        }
        return matches;
    }

    /** Selects the sub-trie root matching the address family of {@code address}. */
    private TrieNode<T> rootFor(InetAddress address) {
        return address instanceof Inet6Address ? ipv6Root : ipv4Root;
    }

    /**
     * Returns bit {@code index} of {@code octets} as 0 or 1, counting from the most
     * significant bit of the first byte (the order {@link InetAddress#getAddress()} uses).
     */
    private static int bitAt(byte[] octets, int index) {
        // Only the low 8 bits of the promoted int are inspected, so sign extension is harmless.
        return (octets[index >>> 3] >>> (7 - (index & 7))) & 1;
    }
}