package ru.yandex.intranet.d.services.integration.providers.security;

import java.time.Duration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import reactor.cache.CacheMono;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Signal;
import reactor.util.retry.Retry;
import reactor.util.retry.RetrySpec;

import ru.yandex.intranet.d.web.security.tvm.TvmClient;
import ru.yandex.intranet.d.web.security.tvm.model.TvmTicket;

/**
 * Provider API authentication supplier implementation.
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
@Component
@Profile({"dev", "testing", "production", "load-testing"})
public class ProviderAuthSupplierImpl implements ProviderAuthSupplier {

    private static final Logger LOG = LoggerFactory.getLogger(ProviderAuthSupplierImpl.class);

    private static final long REFRESH_DELAY_HOURS = 1L;
    private static final long REFRESH_DELAY_ON_ERROR_MINUTES = 1L;

    private final TvmClient tvmClient;
    private final long ownTvmId;
    private final Cache<Long, String> ticketCache;
    private final ScheduledExecutorService scheduler;

    public ProviderAuthSupplierImpl(TvmClient tvmClient,
                                    @Value("${tvm.ownId}") long ownTvmId) {
        this.tvmClient = tvmClient;
        this.ownTvmId = ownTvmId;
        this.ticketCache = CacheBuilder.newBuilder()
                .maximumSize(10)
                .build();
        ThreadFactory threadFactory = new ThreadFactoryBuilder()
                .setDaemon(true)
                .setNameFormat("providers-tvm-auth-pool-%d")
                .setUncaughtExceptionHandler((t, e) -> LOG.error("Uncaught exception in scheduler thread " + t, e))
                .build();
        ScheduledThreadPoolExecutor scheduledThreadPoolExecutor = new ScheduledThreadPoolExecutor(2,
                threadFactory);
        scheduledThreadPoolExecutor.setRemoveOnCancelPolicy(true);
        this.scheduler = scheduledThreadPoolExecutor;
    }

    @Override
    public Mono<String> getTicket(long destinationId) {
        return CacheMono.lookup(this::lookup, destinationId)
                .onCacheMissResume(() -> resume(destinationId))
                .andWriteWith(this::write);
    }

    @PostConstruct
    @SuppressWarnings("FutureReturnValueIgnored")
    public void postConstruct() {
        LOG.info("Scheduling reload for provider API TVM tickets...");
        scheduler.schedule(this::reloadTickets, REFRESH_DELAY_HOURS, TimeUnit.HOURS);
        LOG.info("Scheduled reload for provider API TVM tickets");
    }

    @PreDestroy
    public void preDestroy() {
        LOG.info("Stopping provider API TVM auth supplier...");
        scheduler.shutdown();
        try {
            scheduler.awaitTermination(1, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        scheduler.shutdownNow();
        LOG.info("Stopped provider API TVM auth supplier");
    }

    private Mono<Signal<? extends String>> lookup(Long key) {
        String value = ticketCache.getIfPresent(key);
        if (value != null) {
            return Mono.just(Signal.next(value));
        }
        return Mono.empty();
    }

    private Mono<String> resume(Long key) {
        return tvmClient.tickets(String.valueOf(ownTvmId), List.of(String.valueOf(key)))
                .retryWhen(retryServiceTicket()).flatMap(tickets -> {
                    Optional<String> ticket = tickets.values().stream()
                            .filter(t -> key.equals(t.getTvmId()) && t.getTicket().isPresent())
                            .findFirst().flatMap(TvmTicket::getTicket);
                    if (ticket.isEmpty()) {
                        return Mono.error(new IllegalStateException(
                                "Unable to obtain TVM ticket for destination " + key));
                    }
                    return Mono.just(ticket.get());
                });
    }

    private Mono<Void> write(Long key, Signal<? extends String> value) {
        return Mono.fromRunnable(() -> {
            if (!value.hasValue()) {
                return;
            }
            String valueToPut = value.get();
            if (valueToPut != null) {
                ticketCache.put(key, valueToPut);
            }
        });
    }

    private Retry retryServiceTicket() {
        return RetrySpec.fixedDelay(3, Duration.ofMillis(100)).filter(e -> {
            if (e instanceof WebClientResponseException) {
                WebClientResponseException webException = (WebClientResponseException) e;
                return webException.getRawStatusCode() >= 500 && webException.getRawStatusCode() < 600;
            }
            return !(e instanceof IllegalStateException);
        });
    }

    @SuppressWarnings("FutureReturnValueIgnored")
    private void reloadTickets() {
        LOG.info("Reloading provider API TVM tickets...");
        boolean allTicketsLoaded = false;
        try {
            Map<Long, Optional<String>> tickets = loadTickets(new HashSet<>(ticketCache.asMap().keySet()));
            tickets.forEach((k, v) -> v.ifPresent(s -> ticketCache.put(k, s)));
            allTicketsLoaded = tickets.values().stream().allMatch(Optional::isPresent);
        } catch (Exception e) {
            LOG.error("Failed to reload all provider API TVM tickets", e);
        } finally {
            if (allTicketsLoaded) {
                LOG.info("All provider API TVM tickets successfully reloaded, rescheduling...");
                scheduler.schedule(this::reloadTickets, REFRESH_DELAY_HOURS, TimeUnit.HOURS);
            } else {
                LOG.info("Failed to reload all provider API TVM tickets, rescheduling...");
                scheduler.schedule(this::reloadTickets, REFRESH_DELAY_ON_ERROR_MINUTES, TimeUnit.MINUTES);
            }
        }
    }

    private Map<Long, Optional<String>> loadTickets(Set<Long> destinationIds) {
        if (destinationIds.isEmpty()) {
            return Map.of();
        }
        return tvmClient.tickets(String.valueOf(ownTvmId),
                        destinationIds.stream().map(Objects::toString).collect(Collectors.toList()))
                .retryWhen(retryServiceTicket()).map(tickets -> {
                    Map<Long, Optional<String>> result = new HashMap<>();
                    tickets.forEach((k, v) -> {
                        if (v.getError().isPresent()) {
                            LOG.error("Failed to get TVM ticket for {}: {}", k, v.getError().get());
                        }
                        if (v.getTicket().isPresent()) {
                            result.put(v.getTvmId(), v.getTicket());
                        }
                    });
                    destinationIds.forEach(id -> {
                        if (!result.containsKey(id)) {
                            result.put(id, Optional.empty());
                        }
                    });
                    return result;
                }).doOnError(e -> LOG.error("Failed to load provider API TVM tickets", e))
                .onErrorResume(e -> {
                    Map<Long, Optional<String>> result = new HashMap<>();
                    destinationIds.forEach(id -> result.put(id, Optional.empty()));
                    return Mono.just(result);
                }).block();
    }

}
