package ru.yandex.chemodan.videostreaming.framework.ignite;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import javax.cache.processor.MutableEntry;

import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.cache.CacheEntryProcessor;
import org.apache.ignite.configuration.CacheConfiguration;
import org.jetbrains.annotations.NotNull;
import org.joda.time.Duration;
import org.joda.time.Instant;

import ru.yandex.chemodan.videostreaming.framework.web.BannedIpRegistry;
import ru.yandex.chemodan.videostreaming.framework.web.BannedIpUpdater;
import ru.yandex.commune.dynproperties.DynamicProperty;
import ru.yandex.misc.ip.IpAddress;
import ru.yandex.misc.ip.IpVersion;
import ru.yandex.misc.ip.NetworkMask;
import ru.yandex.misc.log.mlf.Logger;
import ru.yandex.misc.log.mlf.LoggerFactory;

/**
 * @author Dmitriy Amelin (lemeh)
 */
public class IgniteBannedIpUpdater implements BannedIpUpdater {
    private static final Logger logger = LoggerFactory.getLogger(IgniteBannedIpUpdater.class);

    private static final DynamicProperty<Integer> ipv4MaskBits =
            new DynamicProperty<>("streaming-banned-ip-v4-mask", 8,
                    value -> value >= 0 && value <= IpVersion.V4.getType().bitLength());

    private static final DynamicProperty<Integer> ipv6MaskBits =
            new DynamicProperty<>("streaming-banned-ip-v6-mask", 32,
                    value -> value >= 0 && value <= IpVersion.V6.getType().bitLength());

    private static final DynamicProperty<Integer> ipMismatchBanLimit =
            DynamicProperty.cons("streaming-ip-mismatch-ban-limit", 100);

    private static final DynamicProperty<Long> ipMismatchBanTtl = DynamicProperty.cons(
            "streaming-ip-mismatch-ban-ttl-minutes",
            Duration.standardHours(1).getStandardMinutes()
    );
    private final Ignite ignite;

    private final CacheConfiguration<IpAddress, IpAddresses> ipMismatchCacheConfig;

    private final BannedIpRegistry bannedIpRegistry;

    private final ExecutorService executor = new ThreadPoolExecutor(50, 50,
            0L, TimeUnit.MILLISECONDS,
            new LinkedBlockingQueue<>(10000)
    );

    public IgniteBannedIpUpdater(Ignite ignite,
            CacheConfiguration<IpAddress, IpAddresses> ipMismatchCacheConfig, BannedIpRegistry bannedIpRegistry)
    {
        this.ignite = ignite;
        this.ipMismatchCacheConfig = ipMismatchCacheConfig;
        this.bannedIpRegistry = bannedIpRegistry;
    }

    @Override
    public void update(IpAddress expectedIp, IpAddress requestIp, Object sourceMeta) {
        if (expectedIp.equals(requestIp)) {
            return;
        }

        try {
            executor.submit(() -> addMismatchAndCheckLimit(expectedIp, requestIp, sourceMeta));
        } catch (RejectedExecutionException ex) {
            logger.warn("Rejected execution for mismatched pair: {} != {}", expectedIp, requestIp);
        }
    }

    private void addMismatchAndCheckLimit(IpAddress expectedIp, IpAddress requestIp, Object sourceMeta) {
        boolean mismatchLimitExceeded = getMismatchesCache().invoke(expectedIp, new AddMismatchAndCheckLimit(requestIp));
        if (!mismatchLimitExceeded || bannedIpRegistry.has(expectedIp)) {
            return;
        }

        logger.warn("IP mismatch limit exceeded for IP={} with meta={} - IP banned", expectedIp, sourceMeta);
        bannedIpRegistry.add(expectedIp);
    }

    @SuppressWarnings("WeakerAccess")
    public IgniteCache<IpAddress, IpAddresses> getMismatchesCache() {
        return ignite.getOrCreateCache(ipMismatchCacheConfig);
    }

    private static IpAddress applyMask(IpAddress ip) {
        Integer bitsCount = ip.isIpv4Address() ? ipv4MaskBits.get() : ipv6MaskBits.get();
        return new NetworkMask(ip.getVersion().getType().bitLength() - bitsCount, ip.getVersion())
                .applyTo(ip);
    }

    public static class AddMismatchAndCheckLimit implements CacheEntryProcessor<IpAddress, IpAddresses, Boolean> {
        private final IpAddress requestIp;

        private AddMismatchAndCheckLimit(IpAddress requestIp) {
            this.requestIp = requestIp;
        }

        @Override
        public Boolean process(MutableEntry<IpAddress, IpAddresses> entry, Object... arguments) {
            try {
                return doProcess(entry);
            } catch (RuntimeException ex) {
                logger.error("Error while processing pair: {} - {}", entry.getKey(), requestIp, ex);
                throw ex;
            }
        }

        @NotNull
        private Boolean doProcess(MutableEntry<IpAddress, IpAddresses> entry) {
            if (!entry.exists()) {
                entry.setValue(new IpAddresses());
            }

            IpAddresses ipAddresses = entry.getValue();
            IpAddress requestIpSubnet = applyMask(requestIp);
            if (!entry.getValue().contains(requestIpSubnet)) {
                logger.info("Registering subnet = {} for mismatched IP (request != expected): {} != {}",
                        requestIpSubnet, requestIp, entry.getKey());
                ipAddresses.add(requestIpSubnet);
                entry.setValue(ipAddresses);
            }

            return ipAddresses.size() > ipMismatchBanLimit.get();
        }
    }

    public static class IpAddresses implements Serializable {
        Map<IpAddress, Instant> ips = new HashMap<>();

        void add(IpAddress ip) {
            ips.put(ip, Instant.now());
            evict();
        }

        private void evict() {
            Duration ttl = Duration.standardMinutes(ipMismatchBanTtl.get());
            ips = ips.entrySet()
                    .stream()
                    .filter(ipLastUpdate ->
                            new Duration(ipLastUpdate.getValue(), Instant.now())
                                    .isShorterThan(ttl)
                    )
                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
        }

        public boolean contains(IpAddress ip) {
            return ips.containsKey(ip);
        }

        public int size() {
            return ips.size();
        }

        @Override
        public String toString() {
            return "ip-addresses-size=" + size();
        }
    }
}
