package ru.yandex.crypta.clients.redis;

import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;

import javax.inject.Inject;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisSentinelPool;
import redis.clients.jedis.ScanParams;
import redis.clients.jedis.ScanResult;
import redis.clients.jedis.Transaction;
import redis.clients.jedis.params.SetParams;

import ru.yandex.crypta.lib.proto.TRedisConfig;

import static redis.clients.jedis.params.SetParams.setParams;

/**
 * {@link RedisClient} backed by Jedis and a Redis Sentinel connection pool.
 *
 * <p>Values are stored under keys of the form {@code SimpleClassName:entity:field}. The sentinel
 * pool is created lazily on first use rather than in the constructor, so that constructing
 * (injecting) the client does not fail on connection problems.
 */
public class JedisRedisClient implements RedisClient {

    private static final Logger LOG = LoggerFactory.getLogger(JedisRedisClient.class);

    /** Separator between the class, entity and field parts of a key. */
    private static final String KEY_SEPARATOR = ":";

    /** Characters that have special meaning in Redis glob patterns (SCAN ... MATCH). */
    private static final String GLOB_SPECIAL_CHARS = "*?[]\\";

    private final TRedisConfig config;
    /** Guarded by {@code this}; only accessed through {@link #pool()}. */
    private JedisSentinelPool pool;

    @Inject
    public JedisRedisClient(TRedisConfig config) {
        this.config = config;
    }

    /** Runs {@code action} on a pooled connection, which is closed (returned) afterwards. */
    private void runWithJedis(Consumer<Jedis> action) {
        try (Jedis jedis = pool().getResource()) {
            action.accept(jedis);
        }
    }

    /** Applies {@code action} to a pooled connection, closes it, and returns the result. */
    private <T> T withJedis(Function<Jedis, T> action) {
        try (Jedis jedis = pool().getResource()) {
            return action.apply(jedis);
        }
    }

    /**
     * Returns the sentinel pool, creating it on the first call.
     *
     * <p>Creation is deferred from the constructor to avoid failed injection. If creation throws,
     * the field stays {@code null} and the next call retries.
     */
    private synchronized JedisSentinelPool pool() {
        if (pool == null) {
            LOG.info("Establish a connection to {} via sentinels {}",
                    config.getMastername(),
                    config.getSentinels()
            );
            pool = new JedisSentinelPool(
                    config.getMastername(),
                    parseSentinels(config.getSentinels()),
                    config.getPassword()
            );
        }
        return pool;
    }

    /**
     * Splits a comma-separated sentinel list into addresses.
     *
     * <p>Entries are trimmed. Empty entries and duplicates are dropped; {@code Set.of} would throw
     * on duplicates.
     *
     * @param commaSeparated sentinel addresses separated by commas
     * @return the distinct addresses in their original order
     * @throws IllegalStateException if no address remains after parsing
     */
    private static Set<String> parseSentinels(String commaSeparated) {
        Set<String> sentinels = Arrays.stream(commaSeparated.split(","))
                .map(String::trim)
                .filter(address -> !address.isEmpty())
                .collect(Collectors.toCollection(LinkedHashSet::new));
        if (sentinels.isEmpty()) {
            throw new IllegalStateException("No redis sentinels configured: '" + commaSeparated + "'");
        }
        return sentinels;
    }

    private String getPrefix(Class<?> clazz, String entity) {
        return getKey(clazz.getSimpleName(), entity);
    }

    private String getPrefix(Class<?> clazz) {
        return getKey(clazz.getSimpleName());
    }

    private String getKey(Class<?> clazz, String entity, String field) {
        return getKey(clazz.getSimpleName(), entity, field);
    }

    /** Joins key parts with {@link #KEY_SEPARATOR}. */
    private String getKey(String... parts) {
        return String.join(KEY_SEPARATOR, parts);
    }

    /** Encodes a key as UTF-8, as required by the byte[]-based Jedis commands. */
    private static byte[] toBytes(String key) {
        return key.getBytes(StandardCharsets.UTF_8);
    }

    /**
     * Escapes Redis glob metacharacters so that {@code literal} is matched verbatim.
     *
     * @param literal text to escape
     * @return {@code literal} with each of {@code * ? [ ] \} preceded by a backslash
     */
    private static String escapeGlob(String literal) {
        StringBuilder escaped = new StringBuilder(literal.length());
        for (char c : literal.toCharArray()) {
            if (GLOB_SPECIAL_CHARS.indexOf(c) >= 0) {
                escaped.append('\\');
            }
            escaped.append(c);
        }
        return escaped.toString();
    }

    /** Stores {@code value} as a string under {@code clazz:entity:field}, without expiry. */
    @Override
    public void set(Class<?> clazz, String entity, String field, String value) {
        String fullKey = getKey(clazz, entity, field);
        // todo: check code reply?
        runWithJedis(jedis -> jedis.set(fullKey, value));
    }

    /** Returns the string stored under {@code clazz:entity:field}, or {@code null} if absent. */
    @Override
    public String get(Class<?> clazz, String entity, String field) {
        String fullKey = getKey(clazz, entity, field);
        return withJedis(jedis -> jedis.get(fullKey));
    }

    /**
     * Stores raw bytes under {@code clazz:entity:field} with the given time-to-live.
     *
     * @param ttl expiry; it is sent with millisecond precision (SET ... PX)
     * @throws IllegalArgumentException if {@code ttl} is shorter than one millisecond; Redis
     *     rejects non-positive expire times
     */
    @Override
    public void setBytes(Class<?> clazz, String entity, String field, byte[] value, Duration ttl) {
        long ttlMillis = ttl.toMillis();
        if (ttlMillis <= 0) {
            throw new IllegalArgumentException("ttl must be at least 1 ms, got " + ttl);
        }
        // jedis uses byte key to set byte values
        byte[] fullKey = toBytes(getKey(clazz, entity, field));
        // PX keeps millisecond precision; the former EX with an int cast truncated sub-second TTLs to 0.
        SetParams ttlParam = setParams().px(ttlMillis);
        // todo: check code reply?
        runWithJedis(jedis -> jedis.set(fullKey, value, ttlParam));
    }

    /** Returns the raw bytes stored under {@code clazz:entity:field}, or {@code null} if absent. */
    @Override
    public byte[] getBytes(Class<?> clazz, String entity, String field) {
        // jedis uses byte key to get byte values
        byte[] fullKey = toBytes(getKey(clazz, entity, field));
        return withJedis(jedis -> jedis.get(fullKey));
    }

    /** Lists the full keys of every field stored for {@code entity} of {@code clazz}. */
    @Override
    public List<String> listKeys(Class<?> clazz, String entity) {
        return listKeys(getPrefix(clazz, entity));
    }

    /** Lists the full keys of every entity/field stored for {@code clazz}. */
    @Override
    public List<String> listKeys(Class<?> clazz) {
        return listKeys(getPrefix(clazz));
    }

    /**
     * Scans the keyspace for keys of the form {@code prefix:...}.
     *
     * <p>The prefix is glob-escaped and followed by the separator, so that prefix {@code Foo:e1}
     * does not also match {@code Foo:e10:...}. SCAN may report a key more than once, so results
     * are de-duplicated while keeping first-seen order.
     */
    private List<String> listKeys(String prefix) {
        ScanParams scanParams = new ScanParams().match(escapeGlob(prefix) + KEY_SEPARATOR + "*");
        return withJedis(jedis -> {
            Set<String> keys = new LinkedHashSet<>();
            String cursor = ScanParams.SCAN_POINTER_START;
            ScanResult<String> page;
            do {
                page = jedis.scan(cursor, scanParams);
                cursor = page.getCursor();
                keys.addAll(page.getResult());
            } while (!page.isCompleteIteration());
            return new ArrayList<>(keys);
        });
    }

    /**
     * Returns the value under {@code clazz:entity:field} parsed as a long, or 0 if the key is absent.
     *
     * @throws NumberFormatException if the stored value is not a decimal long
     */
    @Override
    public long getLong(Class<?> clazz, String entity, String field) {
        String fullKey = getKey(clazz, entity, field);
        String value = withJedis(jedis -> jedis.get(fullKey));
        return value == null ? 0 : Long.parseLong(value);
    }

    /** Increments the counter under {@code clazz:entity:field} and returns the new value. */
    @Override
    public Long incr(Class<?> clazz, String entity, String field) {
        String fullKey = getKey(clazz, entity, field);
        return withJedis(jedis -> jedis.incr(fullKey));
    }

    /** Decrements the counter under {@code clazz:entity:field} and returns the new value. */
    @Override
    public Long decr(Class<?> clazz, String entity, String field) {
        String fullKey = getKey(clazz, entity, field);
        return withJedis(jedis -> jedis.decr(fullKey));
    }

    /**
     * Pushes {@code item} to the head of the list under {@code clazz:entity:field}, then trims the
     * list to its first {@code keepNItems} elements. Both commands run inside MULTI/EXEC.
     *
     * @throws IllegalArgumentException if {@code keepNItems} is not positive
     */
    @Override
    public void addItem(Class<?> clazz, String entity, String field, byte[] item, int keepNItems) {
        if (keepNItems <= 0) {
            throw new IllegalArgumentException("keepNItems must be positive, got " + keepNItems);
        }
        byte[] key = toBytes(getKey(clazz, entity, field));
        runWithJedis(jedis -> {
            Transaction multi = jedis.multi();
            multi.lpush(key, item);
            // LTRIM's stop index is inclusive, so keep indices 0..keepNItems-1.
            multi.ltrim(key, 0, keepNItems - 1);
            multi.exec();
        });
    }

    /**
     * Returns up to {@code itemsCount} elements from the head of the list under
     * {@code clazz:entity:field}, i.e. the most recently added items when written via
     * {@link #addItem}.
     *
     * @return an empty list when {@code itemsCount} is not positive
     */
    @Override
    public List<byte[]> getItems(Class<?> clazz, String entity, String field, int itemsCount) {
        if (itemsCount <= 0) {
            return new ArrayList<>();
        }
        byte[] key = toBytes(getKey(clazz, entity, field));
        // LRANGE's stop index is inclusive, so fetch indices 0..itemsCount-1.
        return withJedis(jedis -> jedis.lrange(key, 0, itemsCount - 1));
    }

}
