package ru.yandex.intranet.d.web.security.tvm;

import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import reactor.cache.CacheMono;
import reactor.core.Exceptions;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Signal;
import reactor.util.retry.Retry;
import reactor.util.retry.RetrySpec;

import ru.yandex.intranet.d.web.security.tvm.model.TvmTicket;

/**
 * Service tickets cache.
 * <p>
 * Keeps TVM service tickets in an in-memory cache keyed by the destination TVM id. On a cache miss the
 * ticket is requested through the supplied function (with retries) and, when obtained, stored in the cache.
 * Errors and empty results are never cached.
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
public class ServiceTicketsCache {

    private final long ownTvmId;
    private final long maxRetryAttempts;
    private final Duration fixedRetryDelay;
    private final BiFunction<String, List<String>, Mono<Map<String, TvmTicket>>> ticketSupplier;
    private final Cache<Long, String> serviceTicketCache;

    /**
     * Creates the cache.
     *
     * @param ownTvmId TVM id passed to the ticket supplier as the first argument (as a string).
     * @param expireAfterWrite how long a cached ticket stays in the cache after it was written.
     * @param maximumSize maximum number of cached tickets.
     * @param maxRetryAttempts maximum number of retry attempts for a ticket request.
     * @param fixedRetryDelay fixed delay between retry attempts.
     * @param ticketSupplier function requesting tickets; it is called with the own TVM id and
     *                       a single-element list holding the destination TVM id, both as strings.
     * @throws NullPointerException if {@code expireAfterWrite}, {@code fixedRetryDelay}
     *                              or {@code ticketSupplier} is {@code null}.
     */
    public ServiceTicketsCache(long ownTvmId,
                               Duration expireAfterWrite,
                               long maximumSize,
                               long maxRetryAttempts,
                               Duration fixedRetryDelay,
                               BiFunction<String, List<String>, Mono<Map<String, TvmTicket>>> ticketSupplier) {
        // Fail fast on misconfiguration: otherwise a null supplier or delay would only surface
        // on the first cache miss, inside the reactive pipeline, where it is subject to retries.
        Objects.requireNonNull(expireAfterWrite, "expireAfterWrite");
        this.ownTvmId = ownTvmId;
        this.maxRetryAttempts = maxRetryAttempts;
        this.fixedRetryDelay = Objects.requireNonNull(fixedRetryDelay, "fixedRetryDelay");
        this.ticketSupplier = Objects.requireNonNull(ticketSupplier, "ticketSupplier");
        this.serviceTicketCache = CacheBuilder.newBuilder()
                .expireAfterWrite(expireAfterWrite.toMillis(), TimeUnit.MILLISECONDS)
                .maximumSize(maximumSize)
                .build();
    }

    /**
     * Returns a service ticket for the given destination, either from the cache or freshly requested.
     *
     * @param targetTvmId destination TVM id.
     * @return the service ticket; the result fails with {@link IllegalStateException} when the supplier
     * response is empty or does not contain a ticket for the destination.
     */
    public Mono<String> getServiceTicket(long targetTvmId) {
        return CacheMono.lookup(this::getServiceTicketFromCache, targetTvmId)
                .onCacheMissResume(() -> getServiceTicketWithRetries(targetTvmId))
                .andWriteWith(this::putServiceTicketToCache);
    }

    /**
     * Cache reader for {@link CacheMono}: emits the cached ticket as a signal, or completes empty on a miss.
     */
    private Mono<Signal<? extends String>> getServiceTicketFromCache(long k) {
        return Mono.justOrEmpty(serviceTicketCache.getIfPresent(k)).map(Signal::next);
    }

    /**
     * Requests a ticket for the destination through the supplier, retrying failed requests.
     * <p>
     * Only the supplier call is retried; a response without the required ticket is reported
     * as an error right away.
     */
    private Mono<String> getServiceTicketWithRetries(long targetTvmId) {
        String source = String.valueOf(ownTvmId);
        List<String> destinations = List.of(String.valueOf(targetTvmId));
        return Mono.defer(() -> ticketSupplier.apply(source, destinations))
                .retryWhen(retryServiceTicket())
                .flatMap(tickets -> extractServiceTicket(tickets, targetTvmId))
                // A supplier that completes without a response must not turn into a silently empty
                // result: the caller would get no ticket and no error.
                .switchIfEmpty(Mono.defer(() -> missingServiceTicket(targetTvmId)));
    }

    /**
     * Picks the first present ticket issued for the destination out of the supplier response.
     */
    private static Mono<String> extractServiceTicket(Map<String, TvmTicket> tickets, long targetTvmId) {
        Optional<String> ticket = tickets.values().stream()
                .filter(t -> t.getTvmId() == targetTvmId && t.getTicket().isPresent())
                .findFirst()
                .flatMap(TvmTicket::getTicket);
        return ticket.map(Mono::just).orElseGet(() -> missingServiceTicket(targetTvmId));
    }

    /**
     * Builds the error result reported when no ticket is available for the destination.
     */
    private static Mono<String> missingServiceTicket(long targetTvmId) {
        return Mono.error(new IllegalStateException(
                "TVM ticket for destination " + targetTvmId + " is missing"));
    }

    /**
     * Cache writer for {@link CacheMono}: stores the ticket when the signal carries a value,
     * ignores error and completion signals.
     */
    private Mono<Void> putServiceTicketToCache(long key, Signal<? extends String> value) {
        return Mono.fromRunnable(() -> {
            if (!value.hasValue()) {
                return;
            }
            String ticket = value.get();
            if (ticket != null) {
                serviceTicketCache.put(key, ticket);
            }
        });
    }

    /**
     * Retry policy for ticket requests: fixed delay, limited number of attempts.
     * <p>
     * A {@link WebClientResponseException} is retried only for 5xx statuses; any other error is retried
     * unless it is itself a retry-exhausted error.
     */
    private Retry retryServiceTicket() {
        return RetrySpec.fixedDelay(maxRetryAttempts, fixedRetryDelay).filter(e -> {
            if (e instanceof WebClientResponseException webException) {
                int status = webException.getRawStatusCode();
                return status >= 500 && status < 600;
            }
            return !Exceptions.isRetryExhausted(e);
        });
    }

}
