package ru.yandex.intranet.d.datasource.security;

import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import reactor.core.publisher.Mono;
import reactor.util.retry.Retry;
import reactor.util.retry.RetrySpec;

import ru.yandex.intranet.d.web.security.tvm.TvmClient;
import ru.yandex.intranet.d.web.security.tvm.model.TvmTicket;

/**
 * YDB TVM auth provider.
 *
 * <p>Supplies YDB with a TVM service ticket for the {@code ydb.tvmDestination} TVM id, requested
 * on behalf of {@code tvm.ownId}. The first ticket is fetched during bean initialization. After
 * that a background scheduler refreshes it: one hour after a successful fetch, or one minute after
 * a failed one. {@link #getToken()} only reads the most recently fetched ticket and never calls TVM
 * itself.
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
@Component
@Profile({"dev", "testing", "production", "load-testing"})
public class YdbTvmAuthProvider implements YdbAuthProvider {

    private static final Logger LOG = LoggerFactory.getLogger(YdbTvmAuthProvider.class);

    private final TvmClient tvmClient;
    private final long tvmDestination;
    private final long ownTvmId;
    // Keyed by tvmDestination; only that single key is ever written. Entries have no expiry, so the
    // last successfully fetched ticket keeps being served while reloads fail.
    // NOTE(review): confirm this is acceptable given the TVM ticket lifetime.
    private final Cache<Long, String> ticketCache;
    private final ScheduledExecutorService scheduler;

    /**
     * Creates the provider. No TVM request is made here; the first ticket is fetched in
     * {@link #postConstruct()}.
     *
     * @param tvmClient      client used to request service tickets
     * @param tvmDestination TVM id of the YDB destination ({@code ydb.tvmDestination})
     * @param ownTvmId       TVM id of this application ({@code tvm.ownId})
     */
    public YdbTvmAuthProvider(TvmClient tvmClient,
                              @Value("${ydb.tvmDestination}") long tvmDestination,
                              @Value("${tvm.ownId}") long ownTvmId) {
        this.tvmClient = tvmClient;
        this.tvmDestination = tvmDestination;
        this.ownTvmId = ownTvmId;
        this.ticketCache = CacheBuilder.newBuilder()
                .maximumSize(10)
                .build();
        ThreadFactory threadFactory = new ThreadFactoryBuilder()
                .setDaemon(true)
                .setNameFormat("ydb-tvm-auth-pool-%d")
                .setUncaughtExceptionHandler(
                        (t, e) -> LOG.error("Uncaught exception in scheduler thread {}", t, e))
                .build();
        ScheduledThreadPoolExecutor scheduledThreadPoolExecutor = new ScheduledThreadPoolExecutor(2,
                threadFactory);
        scheduledThreadPoolExecutor.setRemoveOnCancelPolicy(true);
        // By default delayed tasks still run after shutdown(). Without this, the pending reload
        // (up to an hour away) keeps the executor alive and preDestroy() always waits out its
        // whole awaitTermination timeout.
        scheduledThreadPoolExecutor.setExecuteExistingDelayedTasksAfterShutdownPolicy(false);
        this.scheduler = scheduledThreadPoolExecutor;
    }

    /**
     * Returns the cached YDB TVM ticket.
     *
     * @return the most recently fetched ticket
     * @throws IllegalStateException if no ticket has been fetched successfully yet
     */
    @Override
    public String getToken() {
        String ticket = ticketCache.getIfPresent(tvmDestination);
        if (ticket == null) {
            throw new IllegalStateException("YDB TVM ticket is missing");
        }
        return ticket;
    }

    /**
     * Does nothing; the background scheduler is stopped in {@link #preDestroy()}.
     */
    @Override
    public void close() {
    }

    /**
     * Reports whether a ticket is available.
     *
     * @return {@code true} if a ticket has been fetched and is cached
     */
    @Override
    public boolean isOk() {
        return ticketCache.getIfPresent(tvmDestination) != null;
    }

    /**
     * Fetches the initial ticket and starts the periodic reload. A failed fetch does not fail bean
     * initialization; it only shortens the delay before the next attempt to one minute.
     */
    @PostConstruct
    public void postConstruct() {
        LOG.info("Preparing YDB TVM ticket...");
        refreshTicketAndReschedule();
    }

    /**
     * Stops the reload scheduler. Waits up to one second for an in-flight reload to finish, then
     * interrupts whatever is still running.
     */
    @PreDestroy
    public void preDestroy() {
        LOG.info("Stopping YDB TVM auth provider...");
        scheduler.shutdown();
        try {
            scheduler.awaitTermination(1, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        scheduler.shutdownNow();
        LOG.info("Stopped YDB TVM auth provider");
    }

    /**
     * Scheduled entry point: refreshes the ticket and schedules the next run.
     */
    private void reloadTicket() {
        LOG.info("Reloading YDB TVM ticket...");
        refreshTicketAndReschedule();
    }

    /**
     * Fetches a ticket, caches it on success, and always schedules the next reload, even when the
     * fetch throws.
     */
    private void refreshTicketAndReschedule() {
        boolean ticketLoaded = false;
        try {
            Optional<String> ticket = loadTicket();
            ticket.ifPresent(t -> ticketCache.put(tvmDestination, t));
            ticketLoaded = ticket.isPresent();
        } catch (Exception e) {
            LOG.error("Failed to reload YDB TVM ticket", e);
        } finally {
            scheduleNextReload(ticketLoaded);
        }
    }

    /**
     * Logs the outcome of a fetch and schedules the next reload: in one hour after success, in one
     * minute after failure. Nothing is scheduled once the scheduler has been shut down.
     *
     * @param ticketLoaded whether the last fetch produced a ticket
     */
    @SuppressWarnings("FutureReturnValueIgnored")
    private void scheduleNextReload(boolean ticketLoaded) {
        if (ticketLoaded) {
            LOG.info("YDB TVM ticket successfully obtained");
        } else {
            LOG.info("Failed to obtain YDB TVM ticket, rescheduling...");
        }
        if (scheduler.isShutdown()) {
            // preDestroy() already ran; schedule() would throw RejectedExecutionException here.
            return;
        }
        scheduler.schedule(this::reloadTicket, 1, ticketLoaded ? TimeUnit.HOURS : TimeUnit.MINUTES);
    }

    /**
     * Requests a ticket for {@code tvmDestination} from TVM, blocking until the call completes.
     * Server errors (5xx) and other transient errors are retried up to three times, 100 ms apart
     * (see {@link #retryServiceTicket()}).
     *
     * @return the ticket, or empty if it could not be obtained; never {@code null}
     */
    private Optional<String> loadTicket() {
        return tvmClient.tickets(String.valueOf(ownTvmId), List.of(String.valueOf(tvmDestination)))
                .retryWhen(retryServiceTicket()).flatMap(tickets -> {
                    Optional<String> ticket = tickets.values().stream()
                            .filter(t -> t.getTvmId() == tvmDestination && t.getTicket().isPresent())
                            .findFirst().flatMap(TvmTicket::getTicket);
                    if (ticket.isEmpty()) {
                        return Mono.error(new IllegalStateException(
                                "TVM ticket for YDB destination is missing"));
                    }
                    return Mono.just(ticket.get());
                }).map(Optional::of)
                // Without this, an upstream that completes empty makes block() return null.
                .defaultIfEmpty(Optional.empty())
                .doOnError(e -> LOG.error("Failed to load YDB TVM ticket", e))
                .onErrorReturn(Optional.empty()).block();
    }

    /**
     * Builds the retry policy for TVM ticket requests. It allows up to three retries with a fixed
     * 100 ms delay. An HTTP error is retried only if its status is 5xx. An
     * {@link IllegalStateException} (e.g. the ticket is missing from the response) is never
     * retried. Any other error is retried.
     *
     * @return the retry specification
     */
    private Retry retryServiceTicket() {
        return RetrySpec.fixedDelay(3, Duration.ofMillis(100)).filter(e -> {
            if (e instanceof WebClientResponseException) {
                WebClientResponseException webException = (WebClientResponseException) e;
                return webException.getRawStatusCode() >= 500 && webException.getRawStatusCode() < 600;
            }
            return !(e instanceof IllegalStateException);
        });
    }

}
