package ru.yandex.webmaster3.storage.user.dao;

import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.UUID;

import org.apache.commons.lang3.StringUtils;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Required;
import org.springframework.stereotype.Service;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisCluster;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.exceptions.JedisConnectionException;
import redis.clients.jedis.exceptions.JedisException;

import ru.yandex.webmaster3.core.WebmasterException;
import ru.yandex.webmaster3.core.http.WebmasterErrorResponse;
import ru.yandex.webmaster3.core.util.TimeUtils;

/**
 * Redis-cluster backed, time-bucketed counters for spam-ban mask rules and IP hits.
 *
 * <p>Two counter families are stored, each under keys of the form {@code prefix|id|time}:
 * <ul>
 *     <li>{@code MASKS_COUNTER_PREFIX} keyed by a rule {@link UUID};</li>
 *     <li>{@code HITS_COUNTER_PREFIX} keyed by an IP string.</li>
 * </ul>
 * The {@code time} component is {@link Instant#now()} formatted with {@code DATE_FORMAT}
 * ({@code TimeUtils.DF_YYYYMMDD_HHMM} in UTC). Every key gets a one-day TTL when it is created.
 *
 * <p>Originally created by ifilippov5 on 14.01.17.
 */
@Service
public class SpamBanCountersRedisService {
    private static final Logger log = LoggerFactory.getLogger(SpamBanCountersRedisService.class);

    static final String MASKS_COUNTER_PREFIX = "masksCounter2";
    static final String HITS_COUNTER_PREFIX = "hitsCounter2";

    static final String KEY_DELIMITER = "|";
    private static final DateTimeFormatter DATE_FORMAT = TimeUtils.DF_YYYYMMDD_HHMM.withZoneUTC();
    // Hour-resolution formatter used to build per-hour KEYS patterns; assumed to be a prefix of
    // DATE_FORMAT's output (NOTE(review): confirm against TimeUtils.DF_YYYYMMDD_HHMM).
    private static final DateTimeFormatter DATE_FORMAT_ONLY_HOUR = DateTimeFormat.forPattern("yyyyMMdd_HH").withZoneUTC();

    /** TTL, in seconds, applied to every counter key when it is created. */
    private static final int EXPIRATION_SECONDS = (int) Duration.standardDays(1).getStandardSeconds();
    /** Number of minutes in the TTL window; {@link #getCounts(UUID)} probes this many buckets. */
    private static final int EXPIRATION_MINUTES = (int) Duration.standardDays(1).getStandardMinutes();

    private JedisCluster cluster;

    /**
     * Increments the counter of the given mask rule for the current time bucket.
     *
     * @param ruleId rule identifier used as the key id
     */
    public void increment(UUID ruleId) {
        increment(MASKS_COUNTER_PREFIX, ruleId.toString());
    }

    /**
     * Increments the hits counter of the given IP for the current time bucket.
     *
     * @param ip IP string used as the key id
     */
    public void increment(String ip) {
        increment(HITS_COUNTER_PREFIX, ip);
    }

    /**
     * Increments {@code keyPrefix|keyId|now} and sets its TTL when this call created the key.
     * Redis errors are logged and swallowed: counting is best-effort.
     */
    private void increment(String keyPrefix, String keyId) {
        String key = buildKey(keyPrefix, keyId, serializeInstant(Instant.now()));
        try {
            // INCR initialises a missing key to 0 before incrementing, so a result of 1 means this
            // call created the key. Deciding from INCR's own result (instead of a prior EXISTS)
            // avoids the race where a key expiring between EXISTS and INCR is recreated without a TTL.
            Long value = cluster.incr(key);
            if (value != null && value == 1L) {
                cluster.expire(key, EXPIRATION_SECONDS);
            }
        } catch (JedisException ex) {
            log.error("Failed to increment value by key {}", key, ex);
        }
    }

    /**
     * Collects all keys with the given prefix (optionally restricted to a time-bucket prefix) from
     * every cluster node and reads their values.
     *
     * @param keyPrefix  counter family prefix
     * @param datePrefix leading part of the serialized time bucket, or empty/{@code null} for all
     * @return key id -> (bucket instant -> count)
     */
    private Map<String, NavigableMap<Instant, Long>> getCounters(String keyPrefix, String datePrefix) {
        String pattern = StringUtils.isEmpty(datePrefix)
                ? keyPrefix + KEY_DELIMITER + "*"
                : keyPrefix + KEY_DELIMITER + "*" + KEY_DELIMITER + datePrefix + "*";

        // A Set de-duplicates keys in case the same key is reported by more than one node.
        Set<String> keys = new HashSet<>();
        // NOTE(review): KEYS is O(N) and blocks each node while it runs; SCAN would be gentler on large datasets.
        for (Map.Entry<String, JedisPool> node : cluster.getClusterNodes().entrySet()) {
            String nodeKey = node.getKey();
            try (Jedis connection = node.getValue().getResource()) {
                log.info("Getting keys from {}", nodeKey);
                Set<String> keysOnNode = connection.keys(pattern);

                Optional<String> minKey = keysOnNode.stream().min(String::compareTo);
                Optional<String> maxKey = keysOnNode.stream().max(String::compareTo);
                log.info("Keys found: {} node={} min={}, max={}", keysOnNode.size(), nodeKey,
                        minKey.orElse(""),
                        maxKey.orElse(""));
                keys.addAll(keysOnNode);
            } catch (Exception e) {
                // Best-effort: an unreachable node must not prevent reading the others.
                log.warn("Unable to get keys from {}", nodeKey, e);
            }
        }

        return buildUsableCounter(keys, keyPrefix);
    }

    /**
     * Same as {@link #buildUsableCounter(Set, String)} with {@code HITS_COUNTER_PREFIX}; kept for
     * backward compatibility.
     */
    Map<String, NavigableMap<Instant, Long>> buildUsableCounter(Set<String> keys) {
        return buildUsableCounter(keys, HITS_COUNTER_PREFIX);
    }

    /**
     * Parses {@code prefix|id|time} keys, reads their values, and groups them by id.
     * Keys with the wrong shape, an unexpected prefix, an unparseable time or a non-numeric value
     * are skipped and counted in a single warning. Keys whose value is gone are silently skipped.
     *
     * @param keys           raw Redis keys
     * @param expectedPrefix counter family prefix the keys must carry
     * @return key id -> (bucket instant -> count), ordered by instant
     * @throws WebmasterException if the connection to Redis fails while reading a value
     */
    Map<String, NavigableMap<Instant, Long>> buildUsableCounter(Set<String> keys, String expectedPrefix) {
        Map<String, NavigableMap<Instant, Long>> counter = new HashMap<>();

        int invalidKeysCount = 0;

        for (String key : keys) {
            String[] items = StringUtils.split(key, KEY_DELIMITER);
            if (items.length != 3 || !Objects.equals(expectedPrefix, items[0])) {
                invalidKeysCount++;
                continue;
            }
            String keyId = items[1];

            Instant time;
            try {
                time = deserializeInstant(items[2]);
            } catch (IllegalArgumentException e) {
                invalidKeysCount++;
                continue;
            }

            String value;
            try {
                value = cluster.get(key);
            } catch (JedisConnectionException ex) {
                throw new WebmasterException("Problems in connection with redis cluster",
                        new WebmasterErrorResponse.RedisErrorResponse(getClass(), "get " + key, ex), ex);
            }
            if (value == null) {
                // The key disappeared (e.g. expired) between the KEYS scan and this GET.
                continue;
            }

            long count;
            try {
                count = Long.parseLong(value);
            } catch (NumberFormatException e) {
                invalidKeysCount++;
                continue;
            }
            counter.computeIfAbsent(keyId, k -> new TreeMap<>()).put(time, count);
        }
        if (invalidKeysCount > 0) {
            log.warn("Invalid keys found: {}", invalidKeysCount);
        }

        return counter;
    }

    /** Parses a time-bucket string produced by {@link #serializeInstant(Instant)}. */
    static Instant deserializeInstant(String date) {
        return Instant.parse(date, DATE_FORMAT);
    }

    /** Formats an instant as the time-bucket component of a key (UTC). */
    static String serializeInstant(Instant time) {
        return DATE_FORMAT.print(time);
    }

    /** Builds a {@code keyPrefix|keyId|time} key. */
    static String buildKey(String keyPrefix, String keyId, String time) {
        return keyPrefix + KEY_DELIMITER + keyId + KEY_DELIMITER + time;
    }

    /** Returns all hits counters currently stored: ip -> (bucket instant -> count). */
    public Map<String, NavigableMap<Instant, Long>> getHitsCounters() {
        return getCounters(HITS_COUNTER_PREFIX, null);
    }

    /**
     * Returns hits counters whose time bucket falls in the current or the previous UTC hour
     * (matched by hour prefix, so the current hour may be partial).
     */
    public Map<String, NavigableMap<Instant, Long>> getHitsCountersLast2Hours() {
        Instant now = Instant.now();
        String thisHourPrefix = now.toString(DATE_FORMAT_ONLY_HOUR);
        Map<String, NavigableMap<Instant, Long>> thisHour = getCounters(HITS_COUNTER_PREFIX, thisHourPrefix);

        String previousHourPrefix = now.minus(Duration.standardHours(1)).toString(DATE_FORMAT_ONLY_HOUR);
        Map<String, NavigableMap<Instant, Long>> result = getCounters(HITS_COUNTER_PREFIX, previousHourPrefix);

        for (Map.Entry<String, NavigableMap<Instant, Long>> entry : thisHour.entrySet()) {
            result.computeIfAbsent(entry.getKey(), k -> new TreeMap<>()).putAll(entry.getValue());
        }
        return result;
    }

    /** Returns all mask-rule counters currently stored: rule id -> (bucket instant -> count). */
    public Map<String, NavigableMap<Instant, Long>> getMasksCounters() {
        return getCounters(MASKS_COUNTER_PREFIX, null);
    }

    /**
     * Returns the counter of a mask rule for one time bucket.
     *
     * @param ruleId rule identifier
     * @param time   serialized time bucket (see {@link #serializeInstant(Instant)})
     * @return the stored count, or 0 if the key does not exist
     */
    public long getCount(UUID ruleId, String time) {
        String key = buildKey(MASKS_COUNTER_PREFIX, ruleId.toString(), time);
        // A single GET: a missing key yields null. The former EXISTS+GET pair could throw
        // NumberFormatException when the key expired between the two calls.
        String value = cluster.get(key);
        return value == null ? 0L : Long.parseLong(value);
    }

    /** Returns the sum of all non-zero per-bucket counts of a mask rule over the TTL window. */
    public long getAmount(UUID ruleId) {
        return getCounts(ruleId).values().stream()
                .mapToLong(Long::longValue)
                .sum();
    }

    /**
     * Probes one bucket per minute, going back {@code EXPIRATION_MINUTES} minutes from now, and
     * returns the non-zero counts of a mask rule.
     *
     * @return serialized bucket -> count, in reverse (newest-first) key order
     */
    public Map<String, Long> getCounts(UUID ruleId) {
        Map<String, Long> counts = new TreeMap<>(Comparator.reverseOrder());

        DateTime time = DateTime.now();
        for (int minutes = 0; minutes < EXPIRATION_MINUTES; minutes++) {
            String minute = serializeInstant(time.toInstant());
            long count = getCount(ruleId, minute);
            if (count > 0) {
                counts.put(minute, count);
            }
            time = time.minusMinutes(1);
        }

        return counts;
    }

    @Required
    public void setJedisCluster(JedisCluster cluster) {
        this.cluster = cluster;
    }
}
