package ru.yandex.util.ip;

import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.regex.Pattern;

/**
 * {@link IpSet} over IPv4 addresses, stored as a map from a /16 prefix to a
 * bit set of the addresses inside that prefix.
 *
 * <p>A map key is the first two octets packed as {@code (o1 << 8) | o2}; bit
 * {@code (o3 << 8) | o4} of the mapped {@link BitSet} is set when address
 * {@code o1.o2.o3.o4} belongs to the set. Non-IPv4 addresses never belong to
 * the set.
 */
public enum IpRangeSet implements IpSet<Map<Integer, BitSet>> {
    INSTANCE;

    private static final int IPV4_LEN = 4;
    private static final int BYTE_OFFSET = 8;
    private static final int INT_MASK = 0xff;

    @Override
    public Builder<Map<Integer, BitSet>> createBuilder() {
        return new IpRangeSetBuilder();
    }

    /**
     * Packs two octet values into a 16-bit index.
     *
     * @param high high-order octet, must already be in [0, 255]
     * @param low low-order octet, must already be in [0, 255]
     * @return {@code (high << 8) | low}
     */
    private static int index(final int high, final int low) {
        return (high << BYTE_OFFSET) | low;
    }

    /**
     * Packs {@code address[offset]} and {@code address[offset + 1]}, read as
     * unsigned octets, into a 16-bit index.
     *
     * @param address raw address bytes
     * @param offset position of the high-order octet
     * @return packed 16-bit index
     */
    private static int index(final byte[] address, final int offset) {
        return index(
            address[offset] & INT_MASK,
            address[offset + 1] & INT_MASK);
    }

    /**
     * Checks whether {@code ip} belongs to a set produced by this class's
     * builder.
     *
     * @param ips set representation returned by the builder
     * @param ip address to look up
     * @return {@code true} if {@code ip} is a 4-byte (IPv4) address whose bit
     *     is set in {@code ips}; {@code false} for any other address
     */
    @Override
    public boolean contains(
        final Map<Integer, BitSet> ips,
        final InetAddress ip)
    {
        byte[] address = ip.getAddress();
        boolean result = false;
        if (address.length == IPV4_LEN) {
            BitSet bitSet = ips.get(index(address, 0));
            result = bitSet != null && bitSet.get(index(address, 2));
        }
        return result;
    }

    /**
     * Accumulates IPv4 addresses and ranges into the map representation
     * described on {@link IpRangeSet}. Uses an unsynchronized {@link HashMap}.
     */
    private static final class IpRangeSetBuilder
        implements Builder<Map<Integer, BitSet>>
    {
        private static final Pattern SPLIT_PATTERN = Pattern.compile("\\.");
        private static final int IPV4_MAX = 255;
        private static final int IPV4_MAX_DOTS = IPV4_LEN - 1;
        private static final int LAST_OCTET = IPV4_LEN - 1;
        private static final int SUBNET16_SIZE = 65536;

        private final Map<Integer, BitSet> ips = new HashMap<>();

        /**
         * Adds (or, with a leading {@code '!'}, removes) an address pattern.
         *
         * <p>The pattern is one to four dot-separated octets. Each octet is
         * either a number or an inclusive range {@code a-b}, with values in
         * [0, 255]. Omitted trailing octets match 0-255, so {@code 10.1}
         * covers {@code 10.1.0.0-10.1.255.255}. Patterns are parsed
         * numerically; host names are not resolved.
         *
         * @param ip pattern, optionally prefixed with {@code '!'}; surrounding
         *     whitespace is ignored
         * @throws UnknownHostException if the pattern has more than three dots
         * @throws IOException if the pattern is empty, has a malformed or
         *     out-of-range octet, or has a reversed range
         */
        @Override
        public void add(final String ip) throws IOException {
            String trimmed = ip.trim();
            boolean negate = !trimmed.isEmpty() && trimmed.charAt(0) == '!';
            String addressString;
            if (negate) {
                addressString = trimmed.substring(1).trim();
            } else {
                addressString = trimmed;
            }
            if (addressString.isEmpty()) {
                throw new IOException("Can't parse IP: empty address");
            }
            // Count every dot: the old loop stopped at the first '-', which
            // let patterns like "1-2.3.4.5.6" slip through.
            if (countDots(addressString) > IPV4_MAX_DOTS) {
                throw new UnknownHostException("Wrong number of dots");
            }
            int[] starts = new int[IPV4_LEN];
            int[] ends = new int[IPV4_LEN];
            try {
                parseRanges(SPLIT_PATTERN.split(addressString), starts, ends);
            } catch (RuntimeException e) {
                throw new IOException("Can't parse IP", e);
            }
            applyRanges(negate, starts, ends);
        }

        /**
         * Returns the accumulated map. This is the builder's live map, not a
         * copy, so later {@link #add} calls modify it as well.
         */
        @Override
        public Map<Integer, BitSet> build() {
            return ips;
        }

        /** Counts the {@code '.'} characters in {@code text}. */
        private static int countDots(final String text) {
            int dots = 0;
            for (int i = 0; i < text.length(); ++i) {
                if (text.charAt(i) == '.') {
                    ++dots;
                }
            }
            return dots;
        }

        /**
         * Fills {@code starts[i]}/{@code ends[i]} with the inclusive range of
         * octet {@code i}. Missing trailing octets get 0-255.
         *
         * @throws IOException on zero or too many parts, an out-of-range
         *     octet, or a reversed range
         * @throws NumberFormatException on a non-numeric octet
         */
        private static void parseRanges(
            final String[] parts,
            final int[] starts,
            final int[] ends)
            throws IOException
        {
            // Pattern.split drops trailing empty strings, so input such as
            // ".." yields no parts; without this check it would cover the
            // whole address space.
            if (parts.length == 0 || parts.length > IPV4_LEN) {
                throw new IOException("Can't parse IP: bad octet count");
            }
            for (int i = 0; i < IPV4_LEN; ++i) {
                if (i >= parts.length) {
                    starts[i] = 0;
                    ends[i] = IPV4_MAX;
                    continue;
                }
                String part = parts[i];
                int dash = part.indexOf('-');
                if (dash >= 0) {
                    starts[i] = parseOctet(part.substring(0, dash));
                    ends[i] = parseOctet(part.substring(dash + 1));
                } else {
                    starts[i] = parseOctet(part);
                    ends[i] = starts[i];
                }
                if (starts[i] > ends[i]) {
                    throw new IOException("Reversed octet range: " + part);
                }
            }
        }

        /**
         * Parses one octet and checks that it is in [0, 255]. Without the
         * check, larger values would wrap or spill into neighbouring octets.
         */
        private static int parseOctet(final String text) throws IOException {
            int value = Integer.parseInt(text);
            if (value < 0 || value > IPV4_MAX) {
                throw new IOException("Octet out of range: " + text);
            }
            return value;
        }

        /** Sets or clears every address described by the parsed ranges. */
        private void applyRanges(
            final boolean negate,
            final int[] starts,
            final int[] ends)
        {
            boolean wholeSubnet =
                starts[2] == 0 && ends[2] == IPV4_MAX
                && starts[LAST_OCTET] == 0 && ends[LAST_OCTET] == IPV4_MAX;
            for (int o1 = starts[0]; o1 <= ends[0]; ++o1) {
                for (int o2 = starts[1]; o2 <= ends[1]; ++o2) {
                    Integer key = index(o1, o2);
                    if (wholeSubnet) {
                        applyWholeSubnet(negate, key);
                    } else {
                        applyPartialSubnet(negate, key, starts, ends);
                    }
                }
            }
        }

        /** Drops, or replaces with a fully-set bit set, one /16 prefix. */
        private void applyWholeSubnet(final boolean negate, final Integer key) {
            if (negate) {
                ips.remove(key);
            } else {
                BitSet bitSet = new BitSet(SUBNET16_SIZE);
                bitSet.set(0, SUBNET16_SIZE);
                ips.put(key, bitSet);
            }
        }

        /**
         * Sets or clears, within one /16 prefix, the bits selected by the
         * third and fourth octet ranges. Removal from an absent prefix is a
         * no-op.
         */
        private void applyPartialSubnet(
            final boolean negate,
            final Integer key,
            final int[] starts,
            final int[] ends)
        {
            BitSet bitSet;
            if (negate) {
                bitSet = ips.get(key);
            } else {
                bitSet = ips.computeIfAbsent(key, BitSetFactory.INSTANCE);
            }
            if (bitSet == null) {
                return;
            }
            for (int o3 = starts[2]; o3 <= ends[2]; ++o3) {
                int from = index(o3, starts[LAST_OCTET]);
                // BitSet ranges are half-open, hence the + 1.
                int to = index(o3, ends[LAST_OCTET]) + 1;
                if (negate) {
                    bitSet.clear(from, to);
                } else {
                    bitSet.set(from, to);
                }
            }
        }
    }

    /** Creates an empty {@link BitSet} for {@code Map.computeIfAbsent}. */
    private enum BitSetFactory implements Function<Integer, BitSet> {
        INSTANCE;

        @Override
        public BitSet apply(final Integer index) {
            return new BitSet();
        }
    }
}

