package ru.yandex.intranet.d.services.integration.providers.grpc;

import java.util.Objects;
import java.util.function.Function;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import io.grpc.ManagedChannel;
import io.grpc.stub.AbstractAsyncStub;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Component;
import reactor.cache.CacheMono;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Signal;

import ru.yandex.intranet.d.loaders.CacheKey;
import ru.yandex.intranet.d.model.TenantId;
import ru.yandex.intranet.d.model.providers.ProviderModel;

/**
 * Grpc stub factory.
 *
 * <p>Stubs are cached per (stub class, provider id, tenant id) together with the {@link ManagedChannel}
 * they were created on. If {@link GrpcChannelFactoryService} later returns a different channel for the
 * same provider, the cached stub is rebuilt on that channel and the cache entry is overwritten.</p>
 *
 * <p>There is no locking around cache misses: concurrent callers missing the cache for the same key may
 * each build a stub, and the last write wins.</p>
 *
 * @author Ruslan Kadriev <aqru@yandex-team.ru>
 * @since 05.11.2020
 */
@Component
public final class GrpcStubFactoryService {

    private final GrpcChannelFactoryService grpcChannelFactoryService;
    /**
     * Key: (stub class, provider/tenant key). Value: (stub, channel the stub was created on).
     * Built without size or expiry limits; entries are only ever overwritten here.
     */
    private final Cache<Pair<Object, CacheKey<String>>, Pair<Object, ManagedChannel>> stubByProviderCache;

    public GrpcStubFactoryService(GrpcChannelFactoryService grpcChannelFactoryService) {
        this.grpcChannelFactoryService = grpcChannelFactoryService;
        this.stubByProviderCache = CacheBuilder.newBuilder()
                .build();
    }

    /**
     * Returns a stub for the given provider, bound to the channel currently returned by
     * {@link GrpcChannelFactoryService} for that provider and tenant.
     *
     * @param stubFactory        creates a new stub on a given channel
     * @param providerId         provider id
     * @param tenantId           tenant id
     * @param typeParameterClass stub class; part of the cache key and used to cast the cached stub
     * @return a {@link Mono} emitting the stub
     * @throws NullPointerException if any argument is {@code null}
     */
    public <T extends AbstractAsyncStub<T>> Mono<T> getClient(Function<ManagedChannel, T> stubFactory,
                                                              String providerId, TenantId tenantId,
                                                              Class<T> typeParameterClass) {
        Objects.requireNonNull(stubFactory, "StubFactory must be provided.");
        Objects.requireNonNull(providerId, "Id must be provided.");
        Objects.requireNonNull(tenantId, "TenantId must be provided");
        Objects.requireNonNull(typeParameterClass, "TypeParameterClass must be provided.");

        CacheKey<String> cacheKey = new CacheKey<>(providerId, tenantId);
        Pair<Object, CacheKey<String>> key = Pair.of(typeParameterClass, cacheKey);
        return grpcChannelFactoryService.get(cacheKey.getId(), cacheKey.getTenantId())
                .flatMap(managedChannel -> getOrCreateStub(key, managedChannel, stubFactory, typeParameterClass));
    }

    /**
     * Returns a stub for the given provider, bound to the channel currently returned by
     * {@link GrpcChannelFactoryService#get(ProviderModel)}.
     *
     * @param stubFactory creates a new stub on a given channel
     * @param provider    provider the stub talks to
     * @param clazz       stub class; part of the cache key and used to cast the cached stub
     * @return a {@link Mono} emitting the stub
     * @throws NullPointerException if any argument is {@code null}
     */
    public <T extends AbstractAsyncStub<T>> Mono<T> getClient(Function<ManagedChannel, T> stubFactory,
                                                              ProviderModel provider,
                                                              Class<T> clazz) {
        Objects.requireNonNull(stubFactory, "StubFactory is required.");
        Objects.requireNonNull(provider, "Provider is required.");
        Objects.requireNonNull(clazz, "Clazz is required.");
        CacheKey<String> cacheKey = new CacheKey<>(provider.getId(), provider.getTenantId());
        Pair<Object, CacheKey<String>> key = Pair.of(clazz, cacheKey);
        return grpcChannelFactoryService.get(provider)
                .flatMap(managedChannel -> getOrCreateStub(key, managedChannel, stubFactory, clazz));
    }

    /**
     * Looks up the stub cached under {@code key} and makes sure it is bound to {@code managedChannel}.
     * On a miss the stub is built on {@code managedChannel} and cached. If the cached stub was built on
     * another channel, it is rebuilt on {@code managedChannel} and the entry is replaced.
     */
    private <T extends AbstractAsyncStub<T>> Mono<T> getOrCreateStub(Pair<Object, CacheKey<String>> key,
                                                                     ManagedChannel managedChannel,
                                                                     Function<ManagedChannel, T> stubFactory,
                                                                     Class<T> stubClass) {
        return CacheMono.lookup(this::getFromCacheById, key)
                // Build the stub on the channel we already hold instead of asking the channel
                // factory for it a second time.
                .onCacheMissResume(() -> Mono.fromSupplier(() ->
                        Pair.<Object, ManagedChannel>of(stubFactory.apply(managedChannel), managedChannel)))
                .andWriteWith(this::putByIdToCache)
                .map(cached -> {
                    if (cached.getRight().equals(managedChannel)) {
                        return stubClass.cast(cached.getLeft());
                    }
                    // The channel changed since this stub was cached: rebuild it on the current channel.
                    // The entry is overwritten synchronously; returning the lazy putByIdToCache Mono here
                    // without subscribing to it would never update the cache.
                    T stub = stubFactory.apply(managedChannel);
                    stubByProviderCache.put(key, Pair.<Object, ManagedChannel>of(stub, managedChannel));
                    return stub;
                });
    }

    /**
     * Cache reader for {@link CacheMono}: emits the cached entry wrapped in an onNext signal,
     * or completes empty on a miss.
     */
    private Mono<Signal<? extends Pair<Object, ManagedChannel>>> getFromCacheById(Pair<Object, CacheKey<String>> key) {
        var value = stubByProviderCache.getIfPresent(key);
        if (value != null) {
            return Mono.just(Signal.next(value));
        }
        return Mono.empty();
    }

    /**
     * Cache writer for {@link CacheMono}: stores the value carried by an onNext signal. Other signals
     * (error / complete) are ignored. The write is lazy and only happens when the returned Mono is
     * subscribed.
     */
    private Mono<Void> putByIdToCache(Pair<Object, CacheKey<String>> key,
                                      Signal<? extends Pair<Object, ManagedChannel>> value) {
        return Mono.fromRunnable(() -> {
            // hasValue() is true only for onNext signals with a non-null value.
            if (value.hasValue()) {
                stubByProviderCache.put(key, value.get());
            }
        });
    }
}
