package ru.yandex.wmtools.common.data.info;

import ru.yandex.wmtools.common.error.UserException;
import ru.yandex.wmtools.common.error.UserProblem;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * An IPv6 address or subnet: a 128-bit address together with a mask (prefix) length.
 * <p>
 * The textual form accepted by {@link #IPv6Info(String, Integer)} follows RFC 4291 section 2.2: hexadecimal
 * groups separated by {@code :}, at most one {@code ::} standing for one or more zero groups, and an optional
 * trailing dotted-quad IPv4 part (e.g. {@code ::ffff:192.0.2.1}). {@link #getAddressAsString(boolean)} prints the
 * RFC 5952 canonical form.
 *
 * @author avhaliullin
 */
public class IPv6Info extends IPInfo {
    //package-private:
    static final int FULL_MASK_LENGTH = IPvEnum.IPv6.getMaxMaskLength();
    /** Largest 128-bit address, {@code 2^128 - 1}. */
    static final BigInteger MAX_IP_ADDRESS = BigInteger.ONE
            .shiftLeft(128)
            .subtract(BigInteger.ONE);
    private static final int GROUP_MAX_VALUE = 0xffff;
    private static final int GROUPS_COUNT = 8;
    private static final int GROUP_SIZE_BITS = FULL_MASK_LENGTH / GROUPS_COUNT;
    /** Number of distinct values of one 16-bit group; used with {@code mod} to extract a group. */
    private static final BigInteger GROUP_MODULUS = BigInteger.valueOf(GROUP_MAX_VALUE + 1);
    /** Selects the low half of the address, i.e. the bits returned by {@link #getLower64()}. */
    private static final BigInteger LOWER_64_MASK = MAX_IP_ADDRESS.shiftRight(FULL_MASK_LENGTH / 2);

    /** Upper 96 bits of an IPv4-mapped address ({@code ::ffff:0:0}), already in place: {@code 0xffff << 32}. */
    static final BigInteger IPV4_BASE = BigInteger.valueOf(0xffff).shiftLeft(32);

    /**
     * {@code -2^32}. In two's complement every bit above the low 32 is set, so AND-ing a non-negative address with
     * it clears the embedded IPv4 part and keeps all of the upper bits for comparison with {@link #IPV4_BASE}.
     */
    private static final BigInteger IPV4_TEST_MASK = BigInteger.ONE.shiftLeft(32).negate();

    /** Upper 64 bits of the {@code 2001:db8::/32} prefix. */
    public static final BigInteger RESERVED_64_BITS;

    // NOTE: every static field used while constructing an instance must be declared above this block,
    // otherwise it is still null/0 when the constructor runs here.
    static {
        try {
            RESERVED_64_BITS = new IPv6Info("2001:DB8::", 32).getHigher64();
        } catch (UserException e) {
            throw new AssertionError(e);
        }
    }

    private final BigInteger address;

    /** Lazily computed text of the address only, without the {@code /mask} suffix. */
    private String addrStringCache = null;

    /**
     * Parses an IPv6 address or subnet from text. Host bits past the mask are cleared.
     *
     * @param stringForm textual address, e.g. {@code 2001:db8::1} or {@code ::ffff:192.0.2.1}
     * @param maskLength prefix length, or {@code null} for the full 128-bit mask (a single address)
     * @throws UserException if the text is not a valid IPv6 address
     */
    public IPv6Info(String stringForm, Integer maskLength) throws UserException {
        super(maskLength == null ? FULL_MASK_LENGTH : maskLength);
        try {
            int[] parts = asIntArray(stringForm);
            // Clearing host bits makes equal subnets compare equal however they were written.
            int hostBits = FULL_MASK_LENGTH - this.maskLength;
            address = asBigInt(parts).shiftRight(hostBits).shiftLeft(hostBits);
            checkAddress();
        } catch (NumberFormatException e) {
            throw new UserException(UserProblem.ILLEGAL_PARAM_VALUE, "Wrong ip address or ip subnet: " + stringForm, e);
        }
    }

    /**
     * Creates an address or subnet from its numeric form. Unlike the string constructor, host bits past the
     * mask are kept exactly as given.
     *
     * @throws UserException if {@code address} is wider than 128 bits
     */
    public IPv6Info(BigInteger address, int maskLength) throws UserException {
        super(maskLength);
        this.address = address;
        checkAddress();
    }

    /**
     * Creates an address or subnet from its two 64-bit halves, combined as {@code higher64 << 64 + lower64}.
     *
     * @throws UserException if the combined address is wider than 128 bits
     */
    public IPv6Info(BigInteger higher64, BigInteger lower64, int maskLength) throws UserException {
        this(higher64.shiftLeft(64).add(lower64), maskLength);
    }

    /**
     * Rejects addresses wider than 128 bits.
     * NOTE(review): negative values pass this check; confirm callers do not rely on that before tightening it.
     */
    private void checkAddress() throws UserException {
        if (address.compareTo(MAX_IP_ADDRESS) > 0) {
            throw new UserException(UserProblem.ILLEGAL_PARAM_VALUE, "IPv6 address has more than 128 bit length");
        }
    }

    /** Validates that the mask length lies within {@code [0, 128]}. */
    @Override
    protected void checkMaskLength() throws UserException {
        if (maskLength < 0) {
            throw new UserException(UserProblem.ILLEGAL_PARAM_VALUE, "Mask length is negative: " + maskLength);
        }
        if (maskLength > FULL_MASK_LENGTH) {
            throw new UserException(UserProblem.ILLEGAL_PARAM_VALUE, "Mask length is more than 128: " + maskLength);
        }
    }

    /** Splits a 128-bit address into its eight 16-bit groups, most significant first. */
    private static int[] asIntArray(BigInteger address) {
        int[] res = new int[GROUPS_COUNT];
        int index = 0;
        for (int shift = FULL_MASK_LENGTH - GROUP_SIZE_BITS; shift >= 0; shift -= GROUP_SIZE_BITS) {
            res[index++] = address.shiftRight(shift).mod(GROUP_MODULUS).intValue();
        }
        return res;
    }

    /**
     * Parses the textual form of an IPv6 address into eight 16-bit groups (RFC 4291, section 2.2).
     *
     * @param address the text to parse
     * @return the eight groups, most significant first
     * @throws UserException         if the text is not a valid IPv6 address
     * @throws NumberFormatException if parsing an embedded IPv4 part throws it (the constructor converts it)
     */
    private static int[] asIntArray(String address) throws UserException, NumberFormatException {
        if (address == null || address.isEmpty()) {
            throw wrongAddress(address);
        }
        int gapIndex = address.indexOf("::");
        if (gapIndex >= 0 && address.indexOf("::", gapIndex + 1) >= 0) {
            // "::" may be used only once; this also rejects ":::"
            throw wrongAddress(address);
        }

        List<Integer> head;
        List<Integer> tail;
        if (gapIndex < 0) {
            head = parseGroups(address, address, true);
            tail = new ArrayList<Integer>();
        } else {
            // Only the part after "::" may end with an embedded IPv4 address
            head = parseGroups(address.substring(0, gapIndex), address, false);
            tail = parseGroups(address.substring(gapIndex + 2), address, true);
        }

        int explicitGroups = head.size() + tail.size();
        if (explicitGroups > GROUPS_COUNT) {
            throw new UserException(UserProblem.ILLEGAL_PARAM_VALUE, "IPv6 address has no more than " + GROUPS_COUNT + " numeric groups");
        }
        if (gapIndex < 0 && explicitGroups < GROUPS_COUNT) {
            // Without "::" every group must be written out
            throw wrongAddress(address);
        }
        if (gapIndex >= 0 && explicitGroups == GROUPS_COUNT) {
            // "::" must stand for at least one zero group
            throw wrongAddress(address);
        }

        // Groups covered by "::" are left at their default value, zero
        int[] result = new int[GROUPS_COUNT];
        for (int i = 0; i < head.size(); i++) {
            result[i] = head.get(i);
        }
        int tailOffset = GROUPS_COUNT - tail.size();
        for (int i = 0; i < tail.size(); i++) {
            result[tailOffset + i] = tail.get(i);
        }
        return result;
    }

    /**
     * Parses a colon-separated run of groups; empty text yields no groups.
     *
     * @param text              the run to parse (one side of "::", or the whole address)
     * @param source            the full user input, used in error messages
     * @param allowTrailingIPv4 whether the last token may be a dotted-quad IPv4 address (two groups)
     */
    private static List<Integer> parseGroups(String text, String source, boolean allowTrailingIPv4)
            throws UserException {
        List<Integer> groups = new ArrayList<Integer>();
        if (text.isEmpty()) {
            return groups;
        }
        // Limit -1 keeps empty tokens, so stray single colons are reported instead of silently dropped
        String[] tokens = text.split(":", -1);
        for (int i = 0; i < tokens.length; i++) {
            String token = tokens[i];
            if (token.indexOf('.') >= 0) {
                if (!allowTrailingIPv4 || i != tokens.length - 1) {
                    throw wrongAddress(source);
                }
                long ipv4 = new IPv4Info(token, 32).getAddress().longValue();
                groups.add((int) (ipv4 >> GROUP_SIZE_BITS));
                groups.add((int) (ipv4 & GROUP_MAX_VALUE));
            } else {
                groups.add(parseGroup(token, source));
            }
        }
        return groups;
    }

    /**
     * Parses one group of ASCII hexadecimal digits with value at most {@code 0xffff}. Unlike
     * {@link Integer#parseInt(String, int)}, this rejects signs and non-ASCII digits.
     */
    private static int parseGroup(String token, String source) throws UserException {
        if (token.isEmpty()) {
            throw wrongAddress(source);
        }
        int value = 0;
        for (int i = 0; i < token.length(); i++) {
            int digit = hexDigitValue(token.charAt(i));
            if (digit < 0) {
                throw wrongAddress(source);
            }
            value = value * 16 + digit;
            // Checked per digit so the accumulator can never overflow
            if (value > GROUP_MAX_VALUE) {
                throw wrongAddress(source);
            }
        }
        return value;
    }

    /** Returns the value of an ASCII hexadecimal digit, or {@code -1} if {@code c} is not one. */
    private static int hexDigitValue(char c) {
        if (c >= '0' && c <= '9') {
            return c - '0';
        }
        if (c >= 'a' && c <= 'f') {
            return c - 'a' + 10;
        }
        if (c >= 'A' && c <= 'F') {
            return c - 'A' + 10;
        }
        return -1;
    }

    /** Builds the standard "malformed address" error, quoting the user's original input. */
    private static UserException wrongAddress(String address) {
        return new UserException(UserProblem.ILLEGAL_PARAM_VALUE, "Wrong ip address or ip subnet: " + address);
    }

    /** Concatenates 16-bit groups, most significant first, into one number. */
    private static BigInteger asBigInt(int[] address) {
        BigInteger addressAccum = BigInteger.ZERO;
        for (int part : address) {
            addressAccum = addressAccum.shiftLeft(16);
            addressAccum = addressAccum.add(BigInteger.valueOf(part));
        }
        return addressAccum;
    }

    @Override
    public boolean isSingleAddress() {
        return maskLength == FULL_MASK_LENGTH;
    }

    @Override
    public boolean isAllInclusive() {
        return maskLength == 0;
    }

    @Override
    public IPvEnum getProtocolVersion() {
        return IPvEnum.IPv6;
    }

    /**
     * Returns the address in RFC 5952 canonical text form. That form uses lowercase hex without leading zeros
     * and collapses the first longest run of two or more zero groups to {@code ::}.
     *
     * @param withDimension whether to append {@code /maskLength}; it is appended only for subnets
     */
    @Override
    public String getAddressAsString(boolean withDimension) {
        // Only the address part is cached, so the result never depends on a previous call's withDimension
        String addressText = addrStringCache;
        if (addressText == null) {
            addressText = formatGroups(asIntArray(address));
            addrStringCache = addressText;
        }
        if (withDimension && !isSingleAddress()) {
            return addressText + "/" + maskLength;
        }
        return addressText;
    }

    /** Formats eight groups in RFC 5952 canonical form (see {@link #getAddressAsString(boolean)}). */
    private static String formatGroups(int[] groups) {
        int bestStart = -1;
        int bestLength = 0;
        int runStart = -1;
        // i == groups.length acts as a sentinel that closes a trailing zero run
        for (int i = 0; i <= groups.length; i++) {
            if (i < groups.length && groups[i] == 0) {
                if (runStart < 0) {
                    runStart = i;
                }
            } else if (runStart >= 0) {
                int runLength = i - runStart;
                // Strictly greater: on ties the first run wins (RFC 5952, section 4.2.3)
                if (runLength > bestLength) {
                    bestStart = runStart;
                    bestLength = runLength;
                }
                runStart = -1;
            }
        }
        // A single zero group is never shortened to "::" (RFC 5952, section 4.2.2)
        if (bestLength < 2) {
            return joinHex(groups, 0, groups.length);
        }
        return joinHex(groups, 0, bestStart) + "::" + joinHex(groups, bestStart + bestLength, groups.length);
    }

    /** Joins {@code groups[from, to)} as lowercase hex separated by {@code :}; empty range gives "". */
    private static String joinHex(int[] groups, int from, int to) {
        StringBuilder text = new StringBuilder();
        for (int i = from; i < to; i++) {
            if (i > from) {
                text.append(':');
            }
            text.append(Integer.toHexString(groups[i]));
        }
        return text.toString();
    }

    /**
     * Tells whether {@code ip} lies within this subnet. IPv4 arguments are compared through their
     * {@code castToIPv6()} form; any other type never matches.
     */
    @Override
    public boolean matches(IPInfo ip) {
        if (!(ip instanceof IPv6Info)) {
            return ip instanceof IPv4Info && matches(((IPv4Info) ip).castToIPv6());
        }
        IPv6Info info = (IPv6Info) ip;
        int hostBits = FULL_MASK_LENGTH - maskLength;
        return info.getMaskLength() >= maskLength
                && info.address.shiftRight(hostBits).equals(address.shiftRight(hostBits));
    }

    /** Returns the low 64 bits of the address. */
    @Override
    public BigInteger getLower64() {
        return address.and(LOWER_64_MASK);
    }

    /** Returns the address shifted right by 64 bits (its upper half). */
    @Override
    public BigInteger getHigher64() {
        return address.shiftRight(FULL_MASK_LENGTH / 2);
    }

    /** Two instances are equal when both the numeric address and the mask length match. */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof IPv6Info)) {
            return false;
        }
        IPv6Info info = (IPv6Info) o;
        return info.address.equals(address) && info.maskLength == maskLength;
    }

    /** Hash consistent with {@link #equals(Object)}: combines the address and the mask length. */
    @Override
    public int hashCode() {
        return 31 * address.hashCode() + maskLength;
    }

    /** Tells whether this is an IPv4-mapped address, i.e. its upper 96 bits equal {@code ::ffff:0:0}. */
    public boolean isCorrectIPv4() {
        return address.and(IPV4_TEST_MASK).equals(IPV4_BASE);
    }

    /**
     * Converts an IPv4-mapped address to IPv4, shortening the mask by the 96 IPv6-only bits (not below 0).
     *
     * @return the IPv4 equivalent, or {@code null} if this is not an IPv4-mapped address
     */
    public IPv4Info castToIPv4() {
        if (!isCorrectIPv4()) {
            return null;
        }
        long ipv4Address = address.mod(BigInteger.valueOf(IPv4Info.MAX_IP_ADDRESS + 1)).longValue();
        int ipv4MaskLength = Math.max(0, maskLength - FULL_MASK_LENGTH + IPv4Info.FULL_MASK_LENGTH);
        try {
            return new IPv4Info(ipv4Address, ipv4MaskLength);
        } catch (UserException e) {
            AssertionError error = new AssertionError("Unbelievable - failed to cast IPv6 to IPv4 " + getAddressAsString(true));
            error.initCause(e);
            throw error;
        }
    }
}
