package ru.yandex.direct.core.units.storage;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

import javax.annotation.Nullable;

import io.lettuce.core.KeyValue;
import io.lettuce.core.RedisException;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import one.util.streamex.StreamEx;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.retry.RetryPolicy;
import org.springframework.retry.policy.SimpleRetryPolicy;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.stereotype.Component;

import ru.yandex.direct.common.lettuce.LettuceConnectionProvider;
import ru.yandex.direct.tracing.Trace;
import ru.yandex.direct.tracing.TraceProfile;

import static ru.yandex.direct.common.configuration.RedisConfiguration.LETTUCE;

/**
 * {@link Storage} implementation backed by a Redis cluster accessed through Lettuce.
 * <p>
 * Every Redis interaction goes through {@link #redisProfiledCall}, which wraps it in a trace profile,
 * retries it on {@link RedisException} up to {@link LettuceConnectionProvider#getMaxAttempts()} attempts
 * and wraps any failure into {@link StorageErrorException}.
 * <p>
 * Values are stored in Redis as decimal strings and exposed as {@link Integer}s.
 */
@Component(LettuceStorage.NAME)
public class LettuceStorage implements Storage {
    public static final String NAME = "LETTUCE_STORAGE";
    private static final Logger logger = LoggerFactory.getLogger(LettuceStorage.class);

    private final LettuceConnectionProvider connectionProvider;
    private final RetryTemplate retryTemplate;

    /**
     * @param connectionProvider source of Redis cluster connections; its {@code getMaxAttempts()} value is
     *                           used as the maximum number of attempts for every Redis call
     */
    @Autowired
    public LettuceStorage(
            @Qualifier(LETTUCE) LettuceConnectionProvider connectionProvider) {
        this.connectionProvider = connectionProvider;

        // Only RedisException is registered as retryable; any other exception propagates on the first attempt.
        RetryPolicy policy = new SimpleRetryPolicy(connectionProvider.getMaxAttempts(),
                Collections.singletonMap(RedisException.class, true));

        retryTemplate = new RetryTemplate();
        retryTemplate.setRetryPolicy(policy);
    }

    /**
     * Reads the value stored under {@code key}.
     *
     * @param key key to read
     * @return the stored value, or {@code 0} if the key is absent
     * @throws IllegalArgumentException if the stored value is not a valid integer
     */
    @Override
    public Integer get(String key) {
        String stringValue = redisProfiledCall("redis:get", cmd -> cmd.get(key));
        logger.debug("get({}): {}", key, stringValue);
        return convertValueToInteger(key, stringValue).orElse(0);
    }

    /**
     * Reads the values of several keys with a single MGET.
     * <p>
     * Keys absent in Redis are omitted from the result. Duplicate keys are queried once; an empty
     * collection yields an empty map without contacting Redis.
     *
     * @param keysList keys to read
     * @return map from key to its stored value, containing present keys only
     * @throws IllegalArgumentException if any stored value is not a valid integer
     */
    @Override
    public Map<String, Integer> getMulti(Collection<String> keysList) {
        logger.trace("getMulti(collection of size {})", keysList.size());
        // De-duplicate: MGET returns one entry per requested key, and duplicate entries would make
        // EntryStream.toMap() throw IllegalStateException.
        final String[] keys = StreamEx.of(keysList).distinct().toArray(String[]::new);
        // MGET requires at least one key, so skip the round trip for empty input.
        final List<KeyValue<String, String>> valuesList = keys.length == 0
                ? Collections.emptyList()
                : redisProfiledCall("redis:getMulti", cmd -> cmd.mget(keys));

        // Fill result map with given keys and obtained values
        Map<String, Integer> result = StreamEx.of(valuesList)
                .mapToEntry(KeyValue::getKey, kv -> convertValueToInteger(kv.getKey(), kv.getValueOrElse(null)))
                .filterValues(Optional::isPresent)
                .mapValues(Optional::get)
                .toMap();
        logger.trace("getMulti calculated the result: {}", result);
        return result;
    }

    /**
     * Stores {@code value} under {@code key} with a TTL of {@code ttl} seconds (SETEX).
     *
     * @param key   key to write
     * @param value value to store
     * @param ttl   time to live, in seconds
     * @return {@code true} if Redis replied {@code OK}
     */
    @Override
    public boolean set(String key, Integer value, int ttl) {
        logger.debug("set({}, {}, {})", key, value, ttl);
        return "OK".equals(redisProfiledCall("redis:set", cmd -> cmd.setex(key, ttl, Integer.toString(value))));
    }

    /**
     * Increments the value of an existing key by {@code delta}; absent keys are left untouched.
     *
     * @param key   key to increment
     * @param delta increment
     * @return {@code true} if the key existed and was incremented, {@code false} otherwise
     */
    @Override
    public boolean incr(String key, Integer delta) {
        logger.debug("incr({}, {})", key, delta);
        return redisProfiledCall("redis:incr", cmd -> {
            // NOTE(review): EXISTS + INCRBY is not atomic. If the key expires between the two commands,
            // INCRBY recreates it without a TTL. Also, the whole lambda is retried on RedisException, so a
            // failure reported after INCRBY was applied may increment twice.
            if (cmd.exists(key) > 0) {
                cmd.incrby(key, delta);
                return true;
            }
            logger.warn("Can't increment value as key '{}' is absent", key);
            return false;
        });
    }

    /**
     * Increments {@code key} by {@code delta}, creating it if absent.
     * <p>
     * A TTL of {@code ttl} seconds is set only when the resulting value equals {@code delta}, i.e. when the
     * key was just created (or previously held 0).
     *
     * @param key   key to increment
     * @param delta increment
     * @param ttl   time to live applied to a freshly created key, in seconds
     * @return always {@code true}; failures surface as {@link StorageErrorException}
     */
    @Override
    public boolean incrOrSet(String key, Integer delta, int ttl) {
        logger.debug("incrOrSet({}, {}, {})", key, delta, ttl);
        return redisProfiledCall("redis:incrOrSet", cmd -> {
            // Compare as long: narrowing the INCRBY result to int could make an out-of-range value
            // spuriously equal to delta.
            if (cmd.incrby(key, delta).longValue() == delta.longValue()) {
                cmd.expire(key, ttl);
            }
            return true;
        });
    }

    /**
     * Deletes the given keys and returns the values they held before deletion.
     *
     * @param keysList keys to delete
     * @return map from key to its value before deletion, containing previously present keys only
     * @throws IllegalArgumentException if any stored value is not a valid integer
     */
    @Override
    public Map<String, Integer> deleteMulti(Collection<String> keysList) {
        // Naive implementation. deleteMulti is used only during testing in Perl code
        logger.trace("deleteMulti(collection of size {})", keysList.size());
        Map<String, Integer> res = getMulti(keysList);
        // DEL counts each distinct key at most once, so compare against the distinct key count.
        final String[] keys = StreamEx.of(keysList).distinct().toArray(String[]::new);
        if (keys.length == 0) {
            // DEL requires at least one key; there is nothing to delete.
            return res;
        }
        Long keysRemoved = redisProfiledCall("redis:deleteMulti", cmd -> cmd.del(keys));
        if (keysRemoved.longValue() != keys.length) {
            logger.info("Amount of deleted keys ({}) differs from size of given key set '{}'", keysRemoved, keys);
        }
        return res;
    }

    /**
     * Converts a raw Redis string value into an {@link Integer}.
     *
     * @param key         the key; used only for the error message
     * @param stringValue {@link String} value to convert to {@link Integer}
     * @return {@link Optional}&lt;{@link Integer}&gt;; {@link Optional#empty()} if the value is absent
     * @throws IllegalArgumentException if parsing {@code stringValue} raised {@link NumberFormatException}
     */
    private Optional<Integer> convertValueToInteger(String key, @Nullable String stringValue)
            throws IllegalArgumentException {
        try {
            return Optional.ofNullable(stringValue).map(Integer::valueOf);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(
                    String.format("value stored by key '%s' is not numeric: '%s'", key, stringValue), e);
        }
    }

    /**
     * Connect to redis and execute function with profiling.
     * <p>
     * The function is executed inside {@link #retryTemplate}, so it is re-run from the start on
     * {@link RedisException}. Any exception that escapes is wrapped into {@link StorageErrorException}.
     *
     * @param name     Profiled method name
     * @param function Function to execute on call
     * @param <T>      Return type
     * @return profiled function result
     */
    <T> T redisProfiledCall(String name, Function<RedisAdvancedClusterCommands<String, String>, T> function) {
        try (TraceProfile ignored = Trace.current().profile(name)) {
            return retryTemplate.execute(context -> function.apply(connectionProvider.getConnection().sync()));
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Restore the interrupt flag so callers up the stack can observe the interruption.
                Thread.currentThread().interrupt();
            }
            throw new StorageErrorException(e);
        }
    }
}
