package ru.yandex.intranet.d.web.security.tvm;

import java.time.Duration;
import java.util.HashSet;
import java.util.concurrent.TimeUnit;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import reactor.cache.CacheMono;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Signal;
import reactor.util.retry.Retry;
import reactor.util.retry.RetrySpec;

import ru.yandex.intranet.d.web.security.model.YaAuthenticationToken;
import ru.yandex.intranet.d.web.security.model.YaCredentials;
import ru.yandex.intranet.d.web.security.model.YaPrincipal;
import ru.yandex.intranet.d.web.security.tvm.model.CheckedServiceTicket;
import ru.yandex.intranet.d.web.security.tvm.model.CheckedUserTicket;
import ru.yandex.intranet.d.web.security.tvm.model.ValidServiceTicket;
import ru.yandex.intranet.d.web.security.tvm.model.ValidUserTicket;

/**
 * Validates TVM service and user tickets and turns successful checks into authentication tokens.
 *
 * <p>Every check result returned by {@link TvmClient}, whether the ticket turned out valid or not, is kept
 * in an in-memory cache keyed by the raw ticket string. Entries live for one minute after being written,
 * and each cache holds at most 1000 entries. The bean is only active under the {@code dev},
 * {@code testing} and {@code production} profiles.
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
@Component
@Profile({"dev", "testing", "production"})
public class TvmTicketChecker {

    /** Upper bound on the number of entries held in each ticket cache. */
    private static final long CACHE_MAX_ENTRIES = 1000L;
    /** Lifetime of a cached check result, in minutes, counted from the moment it was written. */
    private static final long CACHE_TTL_MINUTES = 1L;
    /** Maximum number of retries after a failed TVM call. */
    private static final long MAX_RETRIES = 1L;
    /** Fixed pause before retrying a failed TVM call. */
    private static final Duration RETRY_DELAY = Duration.ofSeconds(3);
    /** User-ticket scope that is always accepted. */
    private static final String PASSWORD_SCOPE = "bb:password";
    /** User-ticket scope that is always accepted. */
    private static final String SESSION_ID_SCOPE = "bb:sessionid";

    private final TvmClient tvmClient;
    private final long ownTvmId;
    private final String requiredScope;
    private final Cache<String, CheckedServiceTicket> serviceTicketCache;
    private final Cache<String, CheckedUserTicket> userTicketCache;

    /**
     * Creates a checker bound to this service's own TVM id.
     *
     * @param tvmClient     client used to check tickets
     * @param ownTvmId      TVM id that incoming service tickets must be addressed to ({@code tvm.ownId})
     * @param requiredScope additional user-ticket scope accepted alongside {@code bb:password} and
     *                      {@code bb:sessionid} ({@code oauth.requiredScope})
     */
    public TvmTicketChecker(TvmClient tvmClient, @Value("${tvm.ownId}") long ownTvmId,
                            @Value("${oauth.requiredScope}") String requiredScope) {
        this.tvmClient = tvmClient;
        this.ownTvmId = ownTvmId;
        this.requiredScope = requiredScope;
        this.serviceTicketCache = newTicketCache();
        this.userTicketCache = newTicketCache();
    }

    /**
     * Authenticates a request that carries both a user ticket and a service ticket.
     *
     * <p>Both checks are combined with {@code Mono.zip}. The result is empty in any of these cases:
     * <ul>
     * <li>the service ticket is invalid or not addressed to this service;</li>
     * <li>the user ticket is invalid or has a zero default uid;</li>
     * <li>the user ticket carries none of {@code bb:password}, {@code bb:sessionid} or the configured
     * required scope.</li>
     * </ul>
     * Errors from the TVM client propagate once the retry policy gives up.
     *
     * @param userTicket    raw TVM user ticket
     * @param serviceTicket raw TVM service ticket
     * @return a token whose principal holds the user's default uid, the service ticket's source and the
     *         user's scopes, or an empty {@code Mono} when authentication fails
     */
    public Mono<YaAuthenticationToken> checkUser(String userTicket, String serviceTicket) {
        Mono<ValidServiceTicket> validService = checkServiceTicketCached(serviceTicket)
                .flatMap(this::checkServiceResult);
        Mono<ValidUserTicket> validUser = checkUserTicketCached(userTicket)
                .flatMap(this::checkUserResult);
        return Mono.zip(validService, validUser)
                .filter(pair -> hasAcceptableScope(pair.getT2()))
                .map(pair -> {
                    ValidServiceTicket service = pair.getT1();
                    ValidUserTicket user = pair.getT2();
                    YaPrincipal principal = new YaPrincipal(String.valueOf(user.getDefaultUid()),
                            service.getSource(), null, null, new HashSet<>(user.getScopes()));
                    return new YaAuthenticationToken(principal, new YaCredentials());
                });
    }

    /**
     * Authenticates a request that carries only a service ticket.
     *
     * @param serviceTicket raw TVM service ticket
     * @return a token with no uid whose principal holds the ticket's source and scopes, or an empty
     *         {@code Mono} when the ticket is invalid or not addressed to this service
     */
    public Mono<YaAuthenticationToken> checkService(String serviceTicket) {
        return checkServiceTicketCached(serviceTicket)
                .flatMap(this::checkServiceResult)
                .map(this::toServiceToken);
    }

    /**
     * Tells whether a user ticket grants at least one scope this service accepts.
     * The password scope is checked first, then the session scope, then the configured scope.
     */
    private boolean hasAcceptableScope(ValidUserTicket user) {
        return user.getScopes().contains(PASSWORD_SCOPE)
                || user.getScopes().contains(SESSION_ID_SCOPE)
                || user.getScopes().contains(requiredScope);
    }

    /** Wraps a validated service ticket into a token that has no user id. */
    private YaAuthenticationToken toServiceToken(ValidServiceTicket service) {
        YaPrincipal principal = new YaPrincipal(null, service.getSource(), null, null,
                new HashSet<>(service.getScopes()));
        return new YaAuthenticationToken(principal, new YaCredentials());
    }

    /** Keeps the valid part of a service ticket check only when its destination is this service. */
    private Mono<ValidServiceTicket> checkServiceResult(CheckedServiceTicket ticket) {
        return ticket.getValid()
                .filter(valid -> valid.getDestination() == ownTvmId)
                .map(Mono::just)
                .orElseGet(Mono::empty);
    }

    /** Keeps the valid part of a user ticket check only when it names a non-zero default uid. */
    private Mono<ValidUserTicket> checkUserResult(CheckedUserTicket ticket) {
        return ticket.getValid()
                .filter(valid -> valid.getDefaultUid() != 0L)
                .map(Mono::just)
                .orElseGet(Mono::empty);
    }

    /**
     * Builds the retry policy for TVM calls: at most one retry, after a fixed 3-second delay,
     * for errors accepted by {@link #isRetryable(Throwable)}.
     */
    private Retry retry() {
        return RetrySpec.fixedDelay(MAX_RETRIES, RETRY_DELAY).filter(TvmTicketChecker::isRetryable);
    }

    /**
     * Decides whether a failed TVM call is worth retrying.
     * An HTTP response error is retried only for a 5xx status. Any other error is retried unless it is an
     * {@link IllegalStateException}.
     */
    private static boolean isRetryable(Throwable error) {
        if (!(error instanceof WebClientResponseException)) {
            return !(error instanceof IllegalStateException);
        }
        int status = ((WebClientResponseException) error).getRawStatusCode();
        return status >= 500 && status < 600;
    }

    /**
     * Returns the cached check result for a service ticket.
     * On a cache miss, the ticket is checked remotely and the emitted result is stored.
     */
    private Mono<CheckedServiceTicket> checkServiceTicketCached(String serviceTicket) {
        return CacheMono.lookup(this::getServiceTicketCheckFromCache, serviceTicket)
                .onCacheMissResume(() -> checkServiceTicketWithRetries(serviceTicket))
                .andWriteWith(this::putCheckedServiceTicketToCache);
    }

    /** Reads a service ticket check result from the cache as a value signal, or empty if absent. */
    private Mono<Signal<? extends CheckedServiceTicket>> getServiceTicketCheckFromCache(String k) {
        CheckedServiceTicket cached = serviceTicketCache.getIfPresent(k);
        if (cached == null) {
            return Mono.empty();
        }
        return Mono.just(Signal.next(cached));
    }

    /** Asks TVM to check a service ticket addressed to this service, applying the retry policy. */
    private Mono<CheckedServiceTicket> checkServiceTicketWithRetries(String serviceTicket) {
        String destination = String.valueOf(ownTvmId);
        return tvmClient.checkServiceTicket(destination, serviceTicket).retryWhen(retry());
    }

    /** Stores a service ticket check result; completion and error signals are not cached. */
    private Mono<Void> putCheckedServiceTicketToCache(String key,
                                                      Signal<? extends CheckedServiceTicket> value) {
        return Mono.fromRunnable(() -> {
            CheckedServiceTicket ticket = value.hasValue() ? value.get() : null;
            if (ticket != null) {
                serviceTicketCache.put(key, ticket);
            }
        });
    }

    /**
     * Returns the cached check result for a user ticket.
     * On a cache miss, the ticket is checked remotely and the emitted result is stored.
     */
    private Mono<CheckedUserTicket> checkUserTicketCached(String userTicket) {
        return CacheMono.lookup(this::getUserTicketCheckFromCache, userTicket)
                .onCacheMissResume(() -> checkUserTicketWithRetries(userTicket))
                .andWriteWith(this::putCheckedUserTicketToCache);
    }

    /** Reads a user ticket check result from the cache as a value signal, or empty if absent. */
    private Mono<Signal<? extends CheckedUserTicket>> getUserTicketCheckFromCache(String k) {
        CheckedUserTicket cached = userTicketCache.getIfPresent(k);
        if (cached == null) {
            return Mono.empty();
        }
        return Mono.just(Signal.next(cached));
    }

    /** Asks TVM to check a user ticket, applying the retry policy. */
    private Mono<CheckedUserTicket> checkUserTicketWithRetries(String userTicket) {
        Mono<CheckedUserTicket> remoteCheck = tvmClient.checkUserTicket(userTicket);
        return remoteCheck.retryWhen(retry());
    }

    /** Stores a user ticket check result; completion and error signals are not cached. */
    private Mono<Void> putCheckedUserTicketToCache(String key,
                                                   Signal<? extends CheckedUserTicket> value) {
        return Mono.fromRunnable(() -> {
            CheckedUserTicket ticket = value.hasValue() ? value.get() : null;
            if (ticket != null) {
                userTicketCache.put(key, ticket);
            }
        });
    }

    /** Creates a size-bounded cache whose entries expire a fixed time after being written. */
    private static <V> Cache<String, V> newTicketCache() {
        return CacheBuilder.newBuilder()
                .maximumSize(CACHE_MAX_ENTRIES)
                .expireAfterWrite(CACHE_TTL_MINUTES, TimeUnit.MINUTES)
                .build();
    }

}
