package ru.yandex.solomon.core.kikimr;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;

import com.google.common.net.HostAndPort;
import com.yandex.ydb.core.grpc.GrpcDiscoveryRpc;
import com.yandex.ydb.core.grpc.GrpcTransport;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelOption;
import io.opentracing.contrib.grpc.ActiveSpanContextSource;
import io.opentracing.contrib.grpc.ActiveSpanSource;
import io.opentracing.contrib.grpc.TracingClientInterceptor;
import io.opentracing.contrib.grpc.TracingClientInterceptor.ClientRequestAttribute;

import ru.yandex.discovery.DiscoveryService;
import ru.yandex.grpc.utils.client.interceptors.MetricClientInterceptor;
import ru.yandex.kikimr.client.KikimrGrpcTransport;
import ru.yandex.kikimr.client.KikimrTransport;
import ru.yandex.kikimr.client.discovery.Discovery;
import ru.yandex.kikimr.client.discovery.NodeDiscovery;
import ru.yandex.kikimr.client.discovery.YdbDiscovery;
import ru.yandex.kikimr.client.kv.KikimrV2NodeRouter;
import ru.yandex.kikimr.client.kv.noderesolver.KikmirV2NodeResolverImpl;
import ru.yandex.kikimr.client.kv.transport.GrpcNodeFactoryImpl;
import ru.yandex.kikimr.client.kv.transport.GrpcTransportSettings;
import ru.yandex.kikimr.client.kv.transport.NodeFactory;
import ru.yandex.kikimr.grpc.GrpcOptions;
import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.config.DataSizeConverter;
import ru.yandex.solomon.config.TimeConverter;
import ru.yandex.solomon.config.protobuf.TConnectionConfig;
import ru.yandex.solomon.config.protobuf.TKikimrClientConfig;
import ru.yandex.solomon.config.protobuf.Time;
import ru.yandex.solomon.config.protobuf.rpc.TGrpcClientConfig;
import ru.yandex.solomon.config.thread.ThreadPoolProvider;
import ru.yandex.solomon.util.NettyUtils;
import ru.yandex.solomon.ydb.YdbAuthProviders;
import ru.yandex.solomon.ydb.YdbEndpoint;

import static ru.yandex.solomon.config.OptionalSet.setTime;
import static ru.yandex.solomon.config.TimeUnitConverter.protoToUnit;


/**
 * Builds Kikimr/YDB client components ({@link KikimrTransport}, {@link KikimrV2NodeRouter})
 * from Solomon protobuf client configs.
 *
 * <p>Call executors are resolved by name through the supplied {@link ThreadPoolProvider}; IO event
 * loops and timers also come from it. Created channels get a {@link MetricClientInterceptor}
 * reporting into the supplied {@link MetricRegistry}.
 *
 * @author Sergey Polovko
 */
public class KikimrTransportFactory {

    /** 10 MiB: used for socket send/receive buffers and the gRPC flow control window. */
    private static final int BUFFER_SIZE_BYTES = 10 << 20;

    private final String clientName;
    private final ThreadPoolProvider threadPoolProvider;
    private final MetricRegistry registry;

    /**
     * @param clientName         client name, used as a prefix of config paths and in error messages
     * @param threadPoolProvider source of call executors, IO event loops and scheduler
     * @param registry           registry the gRPC metric interceptors report into
     */
    public KikimrTransportFactory(String clientName, ThreadPoolProvider threadPoolProvider, MetricRegistry registry) {
        this.clientName = clientName;
        this.threadPoolProvider = threadPoolProvider;
        this.registry = registry;
    }

    /**
     * Creates a transport for the given client config.
     *
     * <p>{@code GrpcConfig} takes precedence over {@code ConnectionConfig} when both are set.
     *
     * @param config client config containing either {@code GrpcConfig} or {@code ConnectionConfig}
     * @return a new transport
     * @throws IllegalStateException if neither {@code GrpcConfig} nor {@code ConnectionConfig} is set
     */
    public KikimrTransport newTransport(TKikimrClientConfig config) {
        if (config.hasGrpcConfig()) {
            return newTransport(config.getGrpcConfig());
        }

        if (config.hasConnectionConfig()) {
            return newTransport(config.getConnectionConfig());
        }

        throw new IllegalStateException("trying to create Kikimr client (" + clientName
                + ") without proper GrpcConfig or ConnectionConfig");
    }

    /**
     * Creates a KV node router backed by YDB discovery and per-node gRPC transports.
     *
     * @param tKikimrClientConfig client config; must contain {@code ConnectionConfig}
     * @param authProviders       factories for the auth provider selected by the connection config
     * @return a new node router
     * @throws IllegalStateException if {@code ConnectionConfig} is not set
     */
    public KikimrV2NodeRouter newNodeRouter(TKikimrClientConfig tKikimrClientConfig, YdbAuthProviders authProviders) {
        if (!tKikimrClientConfig.hasConnectionConfig()) {
            throw new IllegalStateException("trying to create Kikimr node router (" + clientName
                    + ") without ConnectionConfig");
        }
        var config = tKikimrClientConfig.getConnectionConfig();
        var opts = newOptions(config);
        var discovery = newDiscovery(config);
        var nodeFactory = newNodeFactory(config, opts, authProviders);

        var nodeResolver = new KikmirV2NodeResolverImpl(discovery, opts.callExecutor, nodeFactory);
        return new KikimrV2NodeRouter(nodeResolver);
    }

    /**
     * Builds the factory of per-node gRPC transports.
     *
     * <p>The read timeout is applied only when non-zero. The auth provider is chosen in the order
     * TVM, IAM key, IAM key JSON; when none is configured, no auth provider is set.
     */
    private NodeFactory newNodeFactory(TConnectionConfig config, GrpcOptions opts, YdbAuthProviders authProviders) {
        var settings = GrpcTransportSettings.newBuilder()
                .endpoint(config.getEndpoint())
                .database(config.getDatabase());

        var readTimeout = TimeConverter.protoToDuration(config.getReadTimeout());
        if (!readTimeout.isZero()) {
            settings.readTimeout(readTimeout);
        }
        settings.callExecutor(opts.callExecutor);

        if (config.hasTvmAuth()) {
            settings.authProvider(authProviders.tvm(config.getTvmAuth()));
        } else if (config.hasIamKeyAuth()) {
            settings.authProvider(authProviders.iam(config.getIamKeyAuth()));
        } else if (config.hasIamKeyJson()) {
            settings.authProvider(authProviders.iamKeyJson(config.getIamKeyJson()));
        }

        // NOTE(review): node TLS is driven by the UseTLS flag, while discovery derives TLS from the
        // endpoint scheme (YdbEndpoint.isSecure) — confirm the two are meant to differ.
        if (config.hasUseTLS() && config.getUseTLS()) {
            settings.secureConnection();
        }

        settings.channelInitializer(channel -> {
            channel.offloadExecutor(opts.callExecutor);
            channel.channelType(NettyUtils.clientChannelTypeForEventLoop(opts.ioExecutor));
            channel.eventLoopGroup(opts.ioExecutor);
            channel.maxInboundMessageSize(DataSizeConverter.toBytesInt(config.getMaxInboundMessageSize()));
            channel.withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT);
            channel.withOption(ChannelOption.TCP_NODELAY, Boolean.TRUE);
            channel.withOption(ChannelOption.CONNECT_TIMEOUT_MILLIS,
                    Math.toIntExact(TimeConverter.protoToDuration(config.getConnectTimeout()).toMillis()));
            channel.withOption(ChannelOption.SO_SNDBUF, BUFFER_SIZE_BYTES);
            channel.withOption(ChannelOption.SO_RCVBUF, BUFFER_SIZE_BYTES);
            channel.flowControlWindow(BUFFER_SIZE_BYTES);
            channel.keepAliveTime(1, TimeUnit.MINUTES);
            channel.keepAliveTimeout(30, TimeUnit.SECONDS);
            channel.keepAliveWithoutCalls(true);
            channel.intercept(new MetricClientInterceptor("solomon", registry));
        });
        return new GrpcNodeFactoryImpl(settings.build());
    }

    /** Creates a transport using static-address discovery from {@code GrpcConfig}. */
    private KikimrTransport newTransport(TGrpcClientConfig config) {
        var discovery = newDiscovery(config);
        var opts = newOptions(config);
        return new KikimrGrpcTransport(discovery, opts);
    }

    /** Creates a transport using YDB endpoint discovery from {@code ConnectionConfig}. */
    private KikimrTransport newTransport(TConnectionConfig config) {
        var discovery = newDiscovery(config);
        var opts = newOptions(config);
        return new KikimrGrpcTransport(discovery, opts);
    }

    /**
     * Resolves the executor named by {@code GrpcConfig.ThreadPoolName}.
     *
     * <p>The second argument passed to the provider is the config path of the option
     * (presumably used by the provider for diagnostics — TODO confirm).
     */
    private ExecutorService callExecutor(TGrpcClientConfig config) {
        return threadPoolProvider.getExecutorService(
                config.getThreadPoolName(),
                clientName + ".GrpcConfig.ThreadPoolName");
    }

    /**
     * Resolves the executor named by {@code ConnectionConfig.ThreadPoolName}.
     *
     * <p>The second argument passed to the provider is the config path of the option
     * (presumably used by the provider for diagnostics — TODO confirm).
     */
    private ExecutorService callExecutor(TConnectionConfig config) {
        return threadPoolProvider.getExecutorService(
                config.getThreadPoolName(),
                clientName + ".ConnectionConfig.ThreadPoolName");
    }

    /** Discovery over the statically configured address list of {@code GrpcConfig}. */
    private Discovery newDiscovery(TGrpcClientConfig config) {
        var executor = callExecutor(config);
        var timer = threadPoolProvider.getSchedulerExecutorService();
        return new SolomonDiscovery(DiscoveryService.async(), config.getAddressesList(), executor, timer);
    }

    /**
     * Discovery via the YDB discovery RPC on {@code ConnectionConfig.Endpoint}.
     *
     * <p>TLS is enabled when the endpoint is secure according to {@link YdbEndpoint#isSecure}.
     * The discovery channel carries both metric and tracing interceptors.
     */
    private NodeDiscovery newDiscovery(TConnectionConfig config) {
        var endpoint = config.getEndpoint();
        var address = HostAndPort.fromString(YdbEndpoint.removeScheme(endpoint));
        var builder = GrpcTransport.forHost(address.getHost(), address.getPort());

        if (YdbEndpoint.isSecure(endpoint)) {
            builder.withSecureConnection();
        }

        var executor = callExecutor(config);

        builder.withCallExecutor(executor);
        setTime(builder::withReadTimeout, config.getReadTimeout());
        builder.withChannelInitializer(channel -> {
            channel.offloadExecutor(executor);
            channel.channelType(NettyUtils.clientChannelTypeForEventLoop(threadPoolProvider.getIOExecutor()));
            channel.eventLoopGroup(threadPoolProvider.getIOExecutor());
            channel.withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT);
            channel.withOption(ChannelOption.TCP_NODELAY, Boolean.TRUE);
            // toIntExact: fail loudly instead of silently truncating an oversized timeout
            channel.withOption(ChannelOption.CONNECT_TIMEOUT_MILLIS,
                    Math.toIntExact(TimeConverter.protoToDuration(config.getConnectTimeout()).toMillis()));
            channel.keepAliveTime(1, TimeUnit.MINUTES);
            channel.keepAliveTimeout(30, TimeUnit.SECONDS);
            channel.keepAliveWithoutCalls(true);
            channel.intercept(new MetricClientInterceptor("solomon", registry));
            channel.intercept(TracingClientInterceptor.newBuilder()
                    .withActiveSpanSource(ActiveSpanSource.GRPC_CONTEXT)
                    .withActiveSpanContextSource(ActiveSpanContextSource.GRPC_CONTEXT)
                    .withTracedAttributes(ClientRequestAttribute.HEADERS)
                    .withStreaming()
                    .build());
        });

        var transport = builder.build();
        var rpc = new GrpcDiscoveryRpc(transport);
        return new YdbDiscovery(rpc, config.getDatabase(), executor, threadPoolProvider.getIOExecutor());
    }

    /** gRPC channel options built from {@code GrpcConfig} on top of {@link #fillDefaultOpts}. */
    private GrpcOptions newOptions(TGrpcClientConfig config) {
        ExecutorService callExecutor = callExecutor(config);

        Time keepAliveTime = config.getKeepAliveTime();
        Time keepAliveTimeout = config.getKeepAliveTimeout();

        return fillDefaultOpts(GrpcOptions.newBuilder())
                .maxInboundMessageSize(DataSizeConverter.toBytesInt(config.getMaxInboundMessageSize()))
                .withOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, Math.toIntExact(TimeConverter.protoToDuration(config.getConnectTimeout()).toMillis()))
                .readTimeout(TimeConverter.protoToDuration(config.getReadTimeout()).toMillis(), TimeUnit.MILLISECONDS)
                .keepAliveTime(keepAliveTime.getValue(), protoToUnit(keepAliveTime.getUnit()))
                .keepAliveTimeout(keepAliveTimeout.getValue(), protoToUnit(keepAliveTimeout.getUnit()))
                .callExecutor(callExecutor)
                .build();
    }

    /** gRPC channel options built from {@code ConnectionConfig} on top of {@link #fillDefaultOpts}. */
    private GrpcOptions newOptions(TConnectionConfig config) {
        ExecutorService callExecutor = callExecutor(config);

        Time keepAliveTime = config.getKeepAliveTime();
        Time keepAliveTimeout = config.getKeepAliveTimeout();

        return fillDefaultOpts(GrpcOptions.newBuilder())
                .maxInboundMessageSize(DataSizeConverter.toBytesInt(config.getMaxInboundMessageSize()))
                .withOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, Math.toIntExact(TimeConverter.protoToDuration(config.getConnectTimeout()).toMillis()))
                .readTimeout(TimeConverter.protoToDuration(config.getReadTimeout()).toMillis(), TimeUnit.MILLISECONDS)
                .keepAliveTime(keepAliveTime.getValue(), protoToUnit(keepAliveTime.getUnit()))
                .keepAliveTimeout(keepAliveTimeout.getValue(), protoToUnit(keepAliveTimeout.getUnit()))
                .callExecutor(callExecutor)
                .build();
    }

    /**
     * Applies the options shared by all channels of this factory.
     *
     * <p>These are: the default Netty allocator, TCP_NODELAY, 10 MiB socket buffers and flow
     * control window, and the IO event loop and timer from the thread pool provider. A channel
     * initializer attaching a {@link MetricClientInterceptor} is also set.
     *
     * @param builder builder to configure
     * @return the same builder, for chaining
     */
    public GrpcOptions.Builder fillDefaultOpts(GrpcOptions.Builder builder) {
        return builder
                .withOption(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT)
                .withOption(ChannelOption.TCP_NODELAY, Boolean.TRUE)
                .withOption(ChannelOption.SO_SNDBUF, BUFFER_SIZE_BYTES)
                .withOption(ChannelOption.SO_RCVBUF, BUFFER_SIZE_BYTES)
                .flowControlWindow(BUFFER_SIZE_BYTES)
                .eventLoopGroup(threadPoolProvider.getIOExecutor())
                .timer(threadPoolProvider.getSchedulerExecutorService())
                .channelInitializer(b -> b.intercept(new MetricClientInterceptor("solomon", registry)));
    }
}
