package ru.yandex.direct.common.lettuce;

import java.util.function.Function;

import io.lettuce.core.cluster.RedisClusterClient;
import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.lettuce.core.codec.ByteArrayCodec;

import ru.yandex.direct.tracing.Trace;
import ru.yandex.direct.tracing.TraceProfile;

/**
 * Hands out shared Redis Cluster connections obtained from a {@link RedisClusterClient}
 * and runs synchronous commands against them inside a trace profile.
 *
 * <p>Two connections are kept: one using the client's default {@code String} codec and one
 * using {@link ByteArrayCodec}. Each is opened on first request and then reused for every
 * subsequent request. Creation is guarded by the instance monitor and the cached references
 * are {@code volatile}, so each connection is opened at most once per provider.
 *
 * <p>This class never closes the connections it opens.
 */
public class LettuceConnectionProvider {
    private final String dbname;
    private final RedisClusterClient client;
    private final int maxAttempts;
    private volatile StatefulRedisClusterConnection<String, String> connectionInstance;
    private volatile StatefulRedisClusterConnection<byte[], byte[]> binaryConnectionInstance;

    /**
     * @param dbname      value passed as the second argument to {@code Trace.current().profile(...)}
     *                    for every call made through this provider
     * @param client      cluster client used to open the connections
     * @param maxAttempts value exposed through {@link #getMaxAttempts()}; not used by this class
     *                    itself (presumably a retry limit for callers — confirm against usages)
     */
    public LettuceConnectionProvider(String dbname, RedisClusterClient client, int maxAttempts) {
        this.dbname = dbname;
        this.client = client;
        this.maxAttempts = maxAttempts;
    }

    /**
     * Returns the shared {@code String}-keyed connection, opening it on first use.
     *
     * @return the cached connection; the same instance on every call once created
     */
    public StatefulRedisClusterConnection<String, String> getConnection() {
        StatefulRedisClusterConnection<String, String> cached = connectionInstance;
        if (cached == null) {
            synchronized (this) {
                // Re-read under the lock: another thread may have connected in the meantime.
                cached = connectionInstance;
                if (cached == null) {
                    cached = client.connect();
                    connectionInstance = cached;
                }
            }
        }
        return cached;
    }

    /**
     * Returns the shared {@code byte[]}-keyed connection, opening it on first use
     * with {@link ByteArrayCodec#INSTANCE}.
     *
     * @return the cached binary connection; the same instance on every call once created
     */
    public StatefulRedisClusterConnection<byte[], byte[]> getBinaryConnection() {
        StatefulRedisClusterConnection<byte[], byte[]> cached = binaryConnectionInstance;
        if (cached == null) {
            synchronized (this) {
                // Re-read under the lock: another thread may have connected in the meantime.
                cached = binaryConnectionInstance;
                if (cached == null) {
                    cached = client.connect(ByteArrayCodec.INSTANCE);
                    binaryConnectionInstance = cached;
                }
            }
        }
        return cached;
    }

    /**
     * @return the {@code maxAttempts} value supplied to the constructor
     */
    public int getMaxAttempts() {
        return maxAttempts;
    }

    /**
     * Applies {@code function} to the synchronous command API of the {@code String} connection.
     *
     * <p>A {@link TraceProfile} obtained from {@code Trace.current().profile(name, dbname)} is
     * open for the duration of the call and closed afterwards.
     *
     * @param name     first argument passed to {@code Trace.current().profile(...)}
     * @param function operation to run against the synchronous commands
     * @param <T>      result type of the operation
     * @return whatever {@code function} returns
     * @throws LettuceExecuteException wrapping any {@link RuntimeException} raised while
     *                                 opening the profile, obtaining the connection or
     *                                 running {@code function}
     */
    public <T> T call(String name, Function<RedisAdvancedClusterCommands<String, String>, T> function) {
        try (TraceProfile ignored = Trace.current().profile(name, dbname)) {
            RedisAdvancedClusterCommands<String, String> commands = getConnection().sync();
            return function.apply(commands);
        } catch (RuntimeException cause) {
            throw new LettuceExecuteException(cause);
        }
    }

    /**
     * Applies {@code function} to the synchronous command API of the {@code byte[]} connection.
     *
     * <p>A {@link TraceProfile} obtained from {@code Trace.current().profile(name, dbname)} is
     * open for the duration of the call and closed afterwards.
     *
     * @param name     first argument passed to {@code Trace.current().profile(...)}
     * @param function operation to run against the synchronous binary commands
     * @param <T>      result type of the operation
     * @return whatever {@code function} returns
     * @throws LettuceExecuteException wrapping any {@link RuntimeException} raised while
     *                                 opening the profile, obtaining the connection or
     *                                 running {@code function}
     */
    public <T> T callBinary(String name, Function<RedisAdvancedClusterCommands<byte[], byte[]>, T> function) {
        try (TraceProfile ignored = Trace.current().profile(name, dbname)) {
            RedisAdvancedClusterCommands<byte[], byte[]> commands = getBinaryConnection().sync();
            return function.apply(commands);
        } catch (RuntimeException cause) {
            throw new LettuceExecuteException(cause);
        }
    }
}
