package ru.yandex.solomon.ydb;

import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;

import com.yandex.ydb.core.grpc.GrpcTransport;
import com.yandex.ydb.core.rpc.RpcTransport;
import com.yandex.ydb.table.SchemeClient;
import com.yandex.ydb.table.TableClient;
import com.yandex.ydb.table.rpc.grpc.GrpcSchemeRpc;
import com.yandex.ydb.table.rpc.grpc.GrpcTableRpc;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.opentracing.contrib.grpc.ActiveSpanContextSource;
import io.opentracing.contrib.grpc.ActiveSpanSource;
import io.opentracing.contrib.grpc.TracingClientInterceptor;
import io.opentracing.contrib.grpc.TracingClientInterceptor.ClientRequestAttribute;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.cloud.token.IamTokenClient;
import ru.yandex.discovery.DiscoveryServices;
import ru.yandex.grpc.utils.client.interceptors.MetricClientInterceptor;
import ru.yandex.monlib.metrics.MetricSupplier;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.config.DataSizeConverter;
import ru.yandex.solomon.config.TimeConverter;
import ru.yandex.solomon.config.protobuf.TKikimrClientConfig;
import ru.yandex.solomon.config.thread.ThreadPoolProvider;
import ru.yandex.solomon.secrets.SecretProvider;
import ru.yandex.solomon.util.NettyUtils;

/**
 * Owns a YDB RPC transport together with the {@link SchemeClient} and {@link TableClient}
 * built on top of it.
 *
 * <p>The transport is configured from a {@link TKikimrClientConfig}, using either its
 * {@code GrpcConfig} (a list of host addresses resolved through {@link DiscoveryServices}) or
 * its {@code ConnectionConfig} (a single endpoint plus database, optionally with TLS and
 * TVM/IAM authentication). Closing this object closes the table client, the scheme client and
 * the transport.
 *
 * @author Sergey Polovko
 */
public class YdbClients implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(YdbClients.class);

    /** Fallback for {@code QueryCacheSize} when the configured value is not positive. */
    private static final int DEFAULT_QUERY_CACHE_SIZE = 1000;
    /** Fallback for {@code SessionPoolMinSize} when the configured value is not positive. */
    private static final int DEFAULT_SESSION_POOL_MIN_SIZE = 10;
    /** Fallback for {@code SessionPoolMaxSize} when the configured value is not positive. */
    private static final int DEFAULT_SESSION_POOL_MAX_SIZE = 1000;

    /** 10 MiB: used for socket send/receive buffers and the channel flow-control window. */
    private static final int NETWORK_BUFFER_SIZE = 10 << 20;

    private final String configPath;
    private final String schemaRoot;
    private final RpcTransport transport;
    private final SchemeClient schemeClient;
    private final TableClient tableClient;

    /**
     * Creates the clients, building {@link YdbAuthProviders} from the given IAM token client
     * and secret provider.
     *
     * @param configPath     path of {@code config} inside the application config; used in
     *                       error messages, thread-pool lookups and for table client metrics
     * @param config         YDB client configuration
     * @param threads        source of the call executor and the Netty I/O event loop group
     * @param registry       registry that receives gRPC client metrics
     * @param iamTokenClient optional IAM token client; passed as {@code null} when absent
     * @param secrets        secret provider used by the auth providers
     */
    public YdbClients(
        String configPath,
        TKikimrClientConfig config,
        ThreadPoolProvider threads,
        MetricRegistry registry,
        Optional<IamTokenClient> iamTokenClient,
        SecretProvider secrets)
    {
        this(configPath, config, threads, registry,
                new YdbAuthProviders(iamTokenClient.orElse(null), threads, secrets));
    }

    /**
     * Creates the transport, the scheme client and the table client.
     *
     * <p>If building either client fails, everything created so far is closed before the
     * exception propagates. Without this, the transport would leak because the caller never
     * receives an object to close.
     *
     * @param configPath         path of {@code config} inside the application config
     * @param config             YDB client configuration
     * @param threadPoolProvider source of the call executor and the Netty I/O event loop group
     * @param registry           registry that receives gRPC client metrics
     * @param authProviders      factory for TVM/IAM auth providers (ConnectionConfig only)
     * @throws IllegalStateException if neither GrpcConfig nor ConnectionConfig is set
     */
    public YdbClients(
            String configPath,
            TKikimrClientConfig config,
            ThreadPoolProvider threadPoolProvider,
            MetricRegistry registry,
            YdbAuthProviders authProviders)
    {
        this.configPath = configPath;
        this.schemaRoot = config.getSchemaRoot();
        this.transport = makeTransport(configPath, config, threadPoolProvider, registry, authProviders);

        try {
            this.schemeClient = SchemeClient.newClient(GrpcSchemeRpc.useTransport(transport)).build();
        } catch (RuntimeException e) {
            throw closeAndCollect(transport::close, e);
        }

        try {
            this.tableClient = makeTableClient(transport, config);
        } catch (RuntimeException e) {
            RuntimeException error = closeAndCollect(schemeClient::close, e);
            throw closeAndCollect(transport::close, error);
        }
    }

    /** Returns the scheme client bound to this object's transport. */
    public SchemeClient getSchemeClient() {
        return schemeClient;
    }

    /** Returns the table client bound to this object's transport. */
    public TableClient getTableClient() {
        return tableClient;
    }

    /** Returns the underlying RPC transport shared by the scheme and table clients. */
    public RpcTransport getTransport() {
        return transport;
    }

    /** Returns the {@code SchemaRoot} value from the config this object was built with. */
    public String getSchemaRoot() {
        return schemaRoot;
    }

    /** Creates a new metrics supplier for the table client, labelled with the config path. */
    public MetricSupplier getTableClientMetrics() {
        return new YdbTableClientMetrics(tableClient, configPath);
    }

    /**
     * Builds the gRPC transport from either {@code GrpcConfig} or {@code ConnectionConfig}.
     *
     * <p>{@code GrpcConfig} is checked first. In that branch no auth provider is installed.
     * In the {@code ConnectionConfig} branch, TLS is enabled when {@link YdbEndpoint#isSecure}
     * accepts the configured endpoint. At most one of TVM / IAM key / IAM key JSON auth is
     * installed, checked in that order.
     *
     * @throws IllegalStateException if neither config section is present
     * @throws ArithmeticException   if the connect timeout does not fit into an {@code int}
     *                               number of milliseconds
     */
    private static RpcTransport makeTransport(
        String configPath,
        TKikimrClientConfig kikimrClientConfig,
        ThreadPoolProvider threadPoolProvider,
        MetricRegistry registry,
        YdbAuthProviders authProviders)
    {
        final GrpcTransport.Builder transportBuilder;
        final ExecutorService callExecutor;
        final int connectTimeoutMillis;
        final int maxMessageSize;
        final Duration readTimeout;

        if (kikimrClientConfig.hasGrpcConfig()) {
            var grpcConfig = kikimrClientConfig.getGrpcConfig();

            connectTimeoutMillis = Math.toIntExact(TimeConverter.protoToDuration(grpcConfig.getConnectTimeout()).toMillis());
            maxMessageSize = DataSizeConverter.toBytesInt(grpcConfig.getMaxInboundMessageSize());
            readTimeout = TimeConverter.protoToDuration(grpcConfig.getReadTimeout());

            callExecutor = threadPoolProvider.getExecutorService(
                grpcConfig.getThreadPoolName(),
                configPath + ".GrpcConfig.ThreadPoolName");

            transportBuilder = GrpcTransport.forHosts(DiscoveryServices.resolve(grpcConfig.getAddressesList()));
        } else if (kikimrClientConfig.hasConnectionConfig()) {
            var connConfig = kikimrClientConfig.getConnectionConfig();

            connectTimeoutMillis = Math.toIntExact(TimeConverter.protoToDuration(connConfig.getConnectTimeout()).toMillis());
            maxMessageSize = DataSizeConverter.toBytesInt(connConfig.getMaxInboundMessageSize());
            readTimeout = TimeConverter.protoToDuration(connConfig.getReadTimeout());

            callExecutor = threadPoolProvider.getExecutorService(
                connConfig.getThreadPoolName(),
                configPath + ".ConnectionConfig.ThreadPoolName");

            // Decide on TLS from the raw endpoint (scheme included), then strip the scheme,
            // since the builder expects a bare host:port.
            String rawEndpoint = connConfig.getEndpoint();
            boolean secure = YdbEndpoint.isSecure(rawEndpoint);
            GrpcTransport.Builder endpointBuilder = GrpcTransport.forEndpoint(
                    YdbEndpoint.removeScheme(rawEndpoint), connConfig.getDatabase());
            transportBuilder = secure ? endpointBuilder.withSecureConnection() : endpointBuilder;

            // configure authentication (TVM/IAM/None)
            if (connConfig.hasTvmAuth()) {
                transportBuilder.withAuthProvider(authProviders.tvm(connConfig.getTvmAuth()));
            } else if (connConfig.hasIamKeyAuth()) {
                transportBuilder.withAuthProvider(authProviders.iam(connConfig.getIamKeyAuth()));
            } else if (connConfig.hasIamKeyJson()) {
                transportBuilder.withAuthProvider(authProviders.iamKeyJson(connConfig.getIamKeyJson()));
            }
        } else {
            throw new IllegalStateException("empty GrpcConfig and ConnectionConfig in " + configPath);
        }

        final EventLoopGroup ioExecutor = threadPoolProvider.getIOExecutor();

        return transportBuilder
            .withReadTimeout(readTimeout)
            .withCallExecutor(callExecutor)
            .withChannelInitializer(channel -> {
                channel.offloadExecutor(callExecutor);
                channel.channelType(NettyUtils.clientChannelTypeForEventLoop(ioExecutor));
                channel.eventLoopGroup(ioExecutor);
                channel.maxInboundMessageSize(maxMessageSize);
                channel.withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT);
                channel.withOption(ChannelOption.TCP_NODELAY, Boolean.TRUE);
                channel.withOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutMillis);
                channel.withOption(ChannelOption.SO_SNDBUF, NETWORK_BUFFER_SIZE);
                channel.withOption(ChannelOption.SO_RCVBUF, NETWORK_BUFFER_SIZE);
                channel.flowControlWindow(NETWORK_BUFFER_SIZE);
                // keep idle connections alive: ping every minute, even with no active calls
                channel.keepAliveTime(1, TimeUnit.MINUTES);
                channel.keepAliveTimeout(30, TimeUnit.SECONDS);
                channel.keepAliveWithoutCalls(true);
                channel.intercept(new MetricClientInterceptor("solomon", registry));
                channel.intercept(TracingClientInterceptor.newBuilder()
                        .withActiveSpanSource(ActiveSpanSource.GRPC_CONTEXT)
                        .withActiveSpanContextSource(ActiveSpanContextSource.GRPC_CONTEXT)
                        .withTracedAttributes(ClientRequestAttribute.HEADERS)
                        .withStreaming()
                        .build());
            })
            .build();
    }

    /**
     * Builds the table client. A non-positive config value for any size is replaced with its
     * default.
     */
    private static TableClient makeTableClient(RpcTransport transport, TKikimrClientConfig config) {
        int queryCacheSize = positiveOr(config.getQueryCacheSize(), DEFAULT_QUERY_CACHE_SIZE);
        int poolMinSize = positiveOr(config.getSessionPoolMinSize(), DEFAULT_SESSION_POOL_MIN_SIZE);
        int poolMaxSize = positiveOr(config.getSessionPoolMaxSize(), DEFAULT_SESSION_POOL_MAX_SIZE);
        return TableClient.newClient(GrpcTableRpc.useTransport(transport))
            .queryCacheSize(queryCacheSize)
            .sessionPoolSize(poolMinSize, poolMaxSize)
            .build();
    }

    /** Returns {@code value} when it is positive, otherwise {@code fallback}. */
    private static int positiveOr(int value, int fallback) {
        return value > 0 ? value : fallback;
    }

    /**
     * Runs {@code closer} and merges any {@link RuntimeException} it throws into {@code error}.
     *
     * @param closer the close action to run
     * @param error  exception collected so far, or {@code null} if none
     * @return {@code error} (with the new failure added as suppressed) if it was non-null,
     *         otherwise the new failure, or {@code null} if {@code closer} succeeded
     */
    private static RuntimeException closeAndCollect(Runnable closer, RuntimeException error) {
        try {
            closer.run();
            return error;
        } catch (RuntimeException e) {
            if (error == null) {
                return e;
            }
            error.addSuppressed(e);
            return error;
        }
    }

    /**
     * Closes the table client, the scheme client and the transport, in that order.
     *
     * <p>Every resource is closed even if an earlier one fails. The first failure is rethrown,
     * and any later failures are attached to it as suppressed exceptions.
     */
    @Override
    public void close() {
        RuntimeException error = closeAndCollect(tableClient::close, null);
        error = closeAndCollect(schemeClient::close, error);
        error = closeAndCollect(transport::close, error);
        if (error != null) {
            throw error;
        }
    }
}
