package ru.yandex.intranet.d.services.integration.providers.grpc;

import java.util.Objects;
import java.util.concurrent.TimeUnit;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import io.grpc.ClientInterceptor;
import io.grpc.ManagedChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import reactor.cache.CacheMono;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Signal;
import reactor.core.scheduler.Schedulers;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuple3;
import reactor.util.function.Tuples;

import ru.yandex.intranet.d.loaders.CacheKey;
import ru.yandex.intranet.d.loaders.providers.ProvidersLoader;
import ru.yandex.intranet.d.model.TenantId;
import ru.yandex.intranet.d.model.providers.ProviderModel;

/**
 * Grpc channel factory.
 * <p>
 * Keeps one {@link ManagedChannel} per provider (keyed by provider id and tenant) together with the gRPC API URI
 * and the TLS flag the channel was built for. When a provider's URI or TLS flag changes, a new channel is built and
 * swapped into the cache first; only then is the replaced channel shut down, so callers looking the channel up while
 * the old one drains receive the new one.
 *
 * @author Ruslan Kadriev <aqru@yandex-team.ru>
 * @since 05.11.2020
 */
@Component
public class GrpcChannelFactoryService {
    private static final Logger LOGGER = LoggerFactory.getLogger(GrpcChannelFactoryService.class);
    private static final int MAX_TERMINATION_TIME = 3;
    private final ProvidersLoader providersLoader;
    private final GrpcChannelBuilderSupplier grpcChannelBuilderSupplier;
    /**
     * Cached channel per provider: (channel, gRPC API URI it was built for, TLS flag it was built with).
     */
    private final Cache<CacheKey<String>, Tuple3<ManagedChannel, String, Boolean>> channelByProviderCache;

    public GrpcChannelFactoryService(ProvidersLoader providersLoader,
                                     GrpcChannelBuilderSupplier grpcChannelBuilderSupplier) {
        this.providersLoader = providersLoader;
        this.grpcChannelBuilderSupplier = grpcChannelBuilderSupplier;
        this.channelByProviderCache = CacheBuilder.newBuilder()
                .build();
    }

    /**
     * Returns the channel for the provider with the given id, loading the provider to learn its current gRPC API
     * URI and TLS setting. A cached channel is reused when it was built for the same URI and TLS flag; otherwise
     * a new channel replaces it.
     *
     * @param id       provider id
     * @param tenantId tenant of the provider
     * @return channel for the provider; errors with {@link IllegalArgumentException} when the provider does not
     * exist or has no gRPC API URI
     * @throws NullPointerException if {@code id} or {@code tenantId} is null
     */
    public Mono<ManagedChannel> get(String id, TenantId tenantId) {
        Objects.requireNonNull(id, "Id must be provided.");
        Objects.requireNonNull(tenantId, "TenantId must be provided");

        CacheKey<String> key = new CacheKey<>(id, tenantId);
        // The provider is loaded exactly once; the same URI/TLS pair drives both the miss and the update path.
        return getProviderGrpcApiUriAndGrpcTlsOn(key)
                .flatMap(grpcApiUriAndGrpcTlsOn -> getChannel(key, grpcApiUriAndGrpcTlsOn));
    }

    /**
     * Returns the channel for an already loaded provider, using its gRPC API URI and TLS setting. A cached
     * channel is reused when it was built for the same URI and TLS flag; otherwise a new channel replaces it.
     *
     * @param provider provider to connect to
     * @return channel for the provider; errors with {@link IllegalArgumentException} when the provider has no
     * gRPC API URI
     * @throws NullPointerException if {@code provider} is null
     */
    public Mono<ManagedChannel> get(ProviderModel provider) {
        Objects.requireNonNull(provider, "Provider is required");
        CacheKey<String> key = new CacheKey<>(provider.getId(), provider.getTenantId());
        return Mono.just(provider)
                .map(p -> toGrpcApiUriAndGrpcTlsOn(p, key))
                .flatMap(grpcApiUriAndGrpcTlsOn -> getChannel(key, grpcApiUriAndGrpcTlsOn));
    }

    /**
     * Returns the cached channel for {@code key} when it was built for {@code grpcApiUriAndGrpcTlsOn}, builds and
     * caches one on a miss, and replaces a cached channel built for a different URI or TLS flag.
     */
    private Mono<ManagedChannel> getChannel(CacheKey<String> key,
                                            Tuple2<String, Boolean> grpcApiUriAndGrpcTlsOn) {
        return CacheMono.lookup(this::getFromCacheById, key)
                // computeIfAbsent: concurrent misses for the same key share one channel instead of each building
                // its own and leaking every channel except the one that ends up in the cache.
                .onCacheMissResume(() -> Mono.fromCallable(() -> channelByProviderCache.asMap()
                        .computeIfAbsent(key, k -> newChannelTuple(key, grpcApiUriAndGrpcTlsOn))))
                .andWriteWith(this::putByIdToCache)
                .flatMap(cached -> isBuiltFor(cached, grpcApiUriAndGrpcTlsOn)
                        ? Mono.just(cached.getT1())
                        : replaceChannel(key, grpcApiUriAndGrpcTlsOn));
    }

    /**
     * Swaps the cached channel for {@code key} with one built for {@code grpcApiUriAndGrpcTlsOn}, unless a
     * concurrent caller has already installed a suitable channel. After the swap, it shuts down the channel that
     * this call evicted.
     * <p>
     * NOTE(review): relies on Guava's {@code Cache.asMap()} view running {@code compute} atomically per key
     * (Guava 21+), so the remapping function, and thus channel construction, runs once per call.
     */
    private Mono<ManagedChannel> replaceChannel(CacheKey<String> key,
                                                Tuple2<String, Boolean> grpcApiUriAndGrpcTlsOn) {
        return Mono.defer(() -> {
            // Single-slot holder for the channel removed from the cache by this call (lambdas cannot assign locals).
            ManagedChannel[] evicted = new ManagedChannel[1];
            Tuple3<ManagedChannel, String, Boolean> current = channelByProviderCache.asMap().compute(key,
                    (k, existing) -> {
                        if (existing != null && isBuiltFor(existing, grpcApiUriAndGrpcTlsOn)) {
                            // A concurrent caller already replaced the stale channel; its caller shuts that one down.
                            evicted[0] = null;
                            return existing;
                        }
                        evicted[0] = existing != null ? existing.getT1() : null;
                        return newChannelTuple(key, grpcApiUriAndGrpcTlsOn);
                    });
            ManagedChannel replaced = evicted[0];
            if (replaced == null) {
                return Mono.just(current.getT1());
            }
            // The new channel is already visible in the cache, so shutting the old one down cannot hand a
            // shutting-down channel to callers that look it up from now on.
            return shutdownChannel(replaced).thenReturn(current.getT1());
        });
    }

    /**
     * Checks whether a cached entry was built for the given gRPC API URI and TLS flag.
     */
    private static boolean isBuiltFor(Tuple3<ManagedChannel, String, Boolean> cached,
                                      Tuple2<String, Boolean> grpcApiUriAndGrpcTlsOn) {
        return cached.getT2().equals(grpcApiUriAndGrpcTlsOn.getT1())
                && cached.getT3().equals(grpcApiUriAndGrpcTlsOn.getT2());
    }

    /**
     * Builds a new channel for the provider identified by {@code key}, with monitoring and request-id
     * interceptors, and pairs it with the URI and TLS flag it was built for.
     */
    private Tuple3<ManagedChannel, String, Boolean> newChannelTuple(CacheKey<String> key,
                                                                  Tuple2<String, Boolean> grpcApiUriAndGrpcTlsOn) {
        ManagedChannel managedChannel = buildManagedChannel(grpcApiUriAndGrpcTlsOn,
                getMonitoringGrpcProviderStubInterceptor(key.getId(), key.getTenantId()), new RequestIdInterceptor());
        return Tuples.of(managedChannel, grpcApiUriAndGrpcTlsOn.getT1(), grpcApiUriAndGrpcTlsOn.getT2());
    }

    /**
     * Calls {@code shutdown()} on {@code channel} and then waits up to {@link #MAX_TERMINATION_TIME} seconds for
     * it to terminate. Both steps run on the bounded-elastic scheduler, and only once the returned Mono is
     * subscribed. Failures are logged and swallowed: closing the old channel is best-effort and must not fail
     * the caller.
     */
    private Mono<Void> shutdownChannel(ManagedChannel channel) {
        return Mono.fromCallable(() -> {
                    channel.shutdown();
                    return awaitTermination(channel);
                })
                .subscribeOn(Schedulers.boundedElastic())
                .doOnError(e -> LOGGER.warn("Failed to terminate channel", e))
                .then()
                .onErrorResume(e -> Mono.empty());
    }

    /**
     * Blocks for up to {@link #MAX_TERMINATION_TIME} seconds waiting for {@code channel} to terminate.
     *
     * @return whether the channel reports itself terminated afterwards
     */
    private boolean awaitTermination(ManagedChannel channel) {
        try {
            channel.awaitTermination(MAX_TERMINATION_TIME, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so the owner of this worker thread can still observe the interruption.
            Thread.currentThread().interrupt();
            LOGGER.debug("GRPC channel factory termination was interrupted");
        }

        return channel.isTerminated();
    }

    private MonitoringGrpcProviderStubInterceptor getMonitoringGrpcProviderStubInterceptor(String id,
                                                                                           TenantId tenantId) {
        return new MonitoringGrpcProviderStubInterceptor(id, tenantId.getId());
    }

    /**
     * Loads the provider for {@code key} and extracts its gRPC API URI and TLS flag.
     * Errors with {@link IllegalArgumentException} when the provider is missing or has no gRPC API URI.
     */
    private Mono<Tuple2<String, Boolean>> getProviderGrpcApiUriAndGrpcTlsOn(CacheKey<String> key) {
        return providersLoader.getProviderByIdImmediate(key.getId(), key.getTenantId())
                .map(providerModel -> toGrpcApiUriAndGrpcTlsOn(providerModel.orElseThrow(() ->
                        new IllegalArgumentException("No such provider with key = [" + key + "]!")), key));
    }

    /**
     * Extracts the gRPC API URI and TLS flag of {@code provider}.
     *
     * @throws IllegalArgumentException if the provider has no gRPC API URI
     */
    private static Tuple2<String, Boolean> toGrpcApiUriAndGrpcTlsOn(ProviderModel provider, CacheKey<String> key) {
        String grpcApiUri = provider.getGrpcApiUri().orElseThrow(() ->
                new IllegalArgumentException("Provider with key = [" + key + "] haven't Api Url!"));
        return Tuples.of(grpcApiUri, provider.isGrpcTlsOn());
    }

    private ManagedChannel buildManagedChannel(Tuple2<String, Boolean> grpcApiUriAndGrpcTlsOn,
                                               ClientInterceptor... interceptors) {
        var grpcChannelBuilder = grpcChannelBuilderSupplier.get(grpcApiUriAndGrpcTlsOn.getT1())
                .intercept(interceptors);

        if (grpcApiUriAndGrpcTlsOn.getT2()) {
            grpcChannelBuilder.useTransportSecurity();
        } else {
            grpcChannelBuilder.usePlaintext();
        }

        return grpcChannelBuilder
                .build();
    }

    /**
     * {@link CacheMono} reader: emits the cached entry for {@code key} as a value signal, or completes empty.
     */
    private Mono<Signal<? extends Tuple3<ManagedChannel, String, Boolean>>> getFromCacheById(CacheKey<String> key) {
        var value = channelByProviderCache.getIfPresent(key);
        if (value != null) {
            return Mono.just(Signal.next(value));
        }
        return Mono.empty();
    }

    /**
     * {@link CacheMono} writer. The loader has normally installed the entry already; {@code putIfAbsent}
     * guarantees this write-back never overwrites a channel that {@link #replaceChannel} installed concurrently.
     */
    private Mono<Void> putByIdToCache(CacheKey<String> key,
                                      Signal<? extends Tuple3<ManagedChannel, String, Boolean>> value) {
        return Mono.fromRunnable(() -> {
            if (!value.hasValue()) {
                return;
            }
            Tuple3<ManagedChannel, String, Boolean> managedChannelTuple = value.get();
            if (managedChannelTuple != null) {
                channelByProviderCache.asMap().putIfAbsent(key, managedChannelTuple);
            }
        });
    }
}
