package ru.yandex.intranet.d.web.security.blackbox;

import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.client.WebClientResponseException;
import reactor.cache.CacheMono;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Signal;
import reactor.util.retry.Retry;
import reactor.util.retry.RetrySpec;

import ru.yandex.intranet.d.web.security.blackbox.model.BlackboxException;
import ru.yandex.intranet.d.web.security.blackbox.model.CheckedOAuthToken;
import ru.yandex.intranet.d.web.security.blackbox.model.CheckedSessionId;
import ru.yandex.intranet.d.web.security.model.YaAuthenticationToken;
import ru.yandex.intranet.d.web.security.model.YaCredentials;
import ru.yandex.intranet.d.web.security.model.YaPrincipal;
import ru.yandex.intranet.d.web.security.tvm.TvmClient;
import ru.yandex.intranet.d.web.security.tvm.model.TvmTicket;

/**
 * Blackbox authentication checker.
 *
 * <p>Validates session cookies and OAuth tokens against Blackbox and turns successful checks into
 * {@link YaAuthenticationToken} instances. Three in-memory caches are kept:
 * <ul>
 *     <li>the TVM service ticket for the Blackbox destination (one hour, at most 10 entries);</li>
 *     <li>session id check results (one minute, at most 1000 entries);</li>
 *     <li>OAuth token check results (one minute, at most 1000 entries).</li>
 * </ul>
 * Check results are cached regardless of whether the credential turned out to be valid.
 * Calls to TVM and Blackbox are retried at most once, after a fixed three second delay,
 * when the failure is classified as retryable.
 *
 * @author Dmitriy Timashov <dm-tim@yandex-team.ru>
 */
@Component
@Profile({"dev", "testing", "production"})
public class BlackboxAuthChecker {

    private final TvmClient tvmClient;
    private final BlackboxClient blackboxClient;
    private final long ownTvmId;
    private final String requiredScope;
    private final long blackboxTvmId;
    private final String host;
    // Keyed by destination TVM id; only blackboxTvmId is ever looked up or stored.
    private final Cache<Long, String> serviceTicketCache;
    private final Cache<SessionIdKey, CheckedSessionId> sessionIdCache;
    private final Cache<OAuthTokenKey, CheckedOAuthToken> oauthTokenCache;

    public BlackboxAuthChecker(TvmClient tvmClient,
                               BlackboxClient blackboxClient,
                               @Value("${tvm.ownId}") long ownTvmId,
                               @Value("${oauth.requiredScope}") String requiredScope,
                               @Value("${blackbox.tvmId}") long blackboxTvmId,
                               @Value("${blackbox.host}") String host) {
        this.tvmClient = tvmClient;
        this.blackboxClient = blackboxClient;
        this.ownTvmId = ownTvmId;
        this.requiredScope = requiredScope;
        this.blackboxTvmId = blackboxTvmId;
        this.host = host;
        this.serviceTicketCache = expiringCache(1, TimeUnit.HOURS, 10);
        this.sessionIdCache = expiringCache(1, TimeUnit.MINUTES, 1000);
        this.oauthTokenCache = expiringCache(1, TimeUnit.MINUTES, 1000);
    }

    /**
     * Checks a session cookie without an SSL session cookie.
     *
     * @param sessionId session cookie value
     * @param userIp client IP address
     * @return authentication token for a valid secure session, empty otherwise
     */
    public Mono<YaAuthenticationToken> checkSessionId(String sessionId, String userIp) {
        return checkSessionId(sessionId, null, userIp);
    }

    /**
     * Checks a session cookie (and optional SSL session cookie) against Blackbox.
     *
     * @param sessionId session cookie value
     * @param sslSessionId SSL session cookie value, may be null
     * @param userIp client IP address
     * @return authentication token for a valid session that Blackbox reports as secure, empty otherwise
     */
    public Mono<YaAuthenticationToken> checkSessionId(String sessionId, String sslSessionId, String userIp) {
        return getServiceTicketCached()
                .flatMap(ticket -> checkSessionIdCached(sessionId, sslSessionId, userIp, ticket))
                .flatMap(this::toSessionAuthentication);
    }

    /**
     * Checks an OAuth token against Blackbox.
     *
     * @param oauthToken OAuth token value
     * @param userIp client IP address
     * @return authentication token for a valid token that carries the required scope, empty otherwise
     */
    public Mono<YaAuthenticationToken> checkOauthToken(String oauthToken, String userIp) {
        return getServiceTicketCached()
                .flatMap(ticket -> checkOAuthTokenCached(oauthToken, userIp, ticket))
                .flatMap(this::toOAuthAuthentication);
    }

    // Valid and secure session -> principal carrying only the uid; anything else -> empty.
    private Mono<YaAuthenticationToken> toSessionAuthentication(CheckedSessionId checked) {
        return Mono.justOrEmpty(checked.getValid())
                .filter(valid -> valid.isSecure())
                .map(valid -> authenticated(new YaPrincipal(valid.getUid(), null, null, null, Set.of())));
    }

    // Valid token with the required scope -> principal with uid, client and scopes; anything else -> empty.
    private Mono<YaAuthenticationToken> toOAuthAuthentication(CheckedOAuthToken checked) {
        return Mono.justOrEmpty(checked.getValid())
                .filter(valid -> valid.getScopes().contains(requiredScope))
                .map(valid -> authenticated(new YaPrincipal(valid.getUid(), null, valid.getClientId(),
                        valid.getClientName(), valid.getScopes())));
    }

    private static YaAuthenticationToken authenticated(YaPrincipal principal) {
        return new YaAuthenticationToken(principal, new YaCredentials());
    }

    private Mono<String> getServiceTicketCached() {
        return CacheMono.lookup(this::readServiceTicket, blackboxTvmId)
                .onCacheMissResume(this::fetchServiceTicket)
                .andWriteWith(this::writeServiceTicket);
    }

    private Mono<Signal<? extends String>> readServiceTicket(long destinationId) {
        return readFromCache(serviceTicketCache, destinationId);
    }

    /**
     * Requests a service ticket for the Blackbox destination from TVM.
     * Fails with {@link IllegalStateException} when the response has no ticket for that destination.
     */
    private Mono<String> fetchServiceTicket() {
        String source = String.valueOf(ownTvmId);
        List<String> destinations = List.of(String.valueOf(blackboxTvmId));
        return tvmClient.tickets(source, destinations)
                .retryWhen(retryServiceTicket())
                .flatMap(tickets -> tickets.values().stream()
                        .filter(candidate -> candidate.getTvmId() == blackboxTvmId)
                        .map(TvmTicket::getTicket)
                        .flatMap(Optional::stream)
                        .findFirst()
                        .map(Mono::just)
                        .orElseGet(() -> Mono.error(new IllegalStateException(
                                "TVM ticket for blackbox destination is missing"))));
    }

    private Mono<Void> writeServiceTicket(long destinationId, Signal<? extends String> signal) {
        return writeToCache(serviceTicketCache, destinationId, signal);
    }

    private Mono<CheckedSessionId> checkSessionIdCached(String sessionId, String sslSessionId, String userIp,
                                                        String serviceTicket) {
        SessionIdKey cacheKey = new SessionIdKey(sessionId, sslSessionId, userIp);
        return CacheMono.lookup(this::readSessionIdCheck, cacheKey)
                .onCacheMissResume(() -> fetchSessionIdCheck(sessionId, sslSessionId, userIp, serviceTicket))
                .andWriteWith(this::writeSessionIdCheck);
    }

    private Mono<Signal<? extends CheckedSessionId>> readSessionIdCheck(SessionIdKey cacheKey) {
        return readFromCache(sessionIdCache, cacheKey);
    }

    private Mono<CheckedSessionId> fetchSessionIdCheck(String sessionId, String sslSessionId,
                                                       String userIp, String serviceTicket) {
        Mono<CheckedSessionId> request = blackboxClient.sessionId(serviceTicket, sessionId, userIp, host,
                sslSessionId);
        return request.retryWhen(retryBlackbox());
    }

    private Mono<Void> writeSessionIdCheck(SessionIdKey cacheKey, Signal<? extends CheckedSessionId> signal) {
        return writeToCache(sessionIdCache, cacheKey, signal);
    }

    private Mono<CheckedOAuthToken> checkOAuthTokenCached(String oauthToken, String userIp,
                                                          String serviceTicket) {
        OAuthTokenKey cacheKey = new OAuthTokenKey(oauthToken, userIp);
        return CacheMono.lookup(this::readOAuthTokenCheck, cacheKey)
                .onCacheMissResume(() -> fetchOAuthTokenCheck(oauthToken, userIp, serviceTicket))
                .andWriteWith(this::writeOAuthTokenCheck);
    }

    private Mono<Signal<? extends CheckedOAuthToken>> readOAuthTokenCheck(OAuthTokenKey cacheKey) {
        return readFromCache(oauthTokenCache, cacheKey);
    }

    private Mono<CheckedOAuthToken> fetchOAuthTokenCheck(String oauthToken, String userIp,
                                                         String serviceTicket) {
        Mono<CheckedOAuthToken> request = blackboxClient.oauth(serviceTicket, oauthToken, userIp,
                List.of(requiredScope));
        return request.retryWhen(retryBlackbox());
    }

    private Mono<Void> writeOAuthTokenCheck(OAuthTokenKey cacheKey, Signal<? extends CheckedOAuthToken> signal) {
        return writeToCache(oauthTokenCache, cacheKey, signal);
    }

    /** Builds a size-bounded cache whose entries expire a fixed time after being written. */
    private static <K, V> Cache<K, V> expiringCache(long ttl, TimeUnit unit, long maxSize) {
        return CacheBuilder.newBuilder()
                .expireAfterWrite(ttl, unit)
                .maximumSize(maxSize)
                .build();
    }

    /** Cache reader for {@link CacheMono}: a cached value becomes an onNext signal, a miss becomes empty. */
    private static <K, V> Mono<Signal<? extends V>> readFromCache(Cache<K, V> cache, K key) {
        return Mono.justOrEmpty(cache.getIfPresent(key)).map(Signal::next);
    }

    /** Cache writer for {@link CacheMono}: stores only non-null onNext values; other signals are ignored. */
    private static <K, V> Mono<Void> writeToCache(Cache<K, V> cache, K key, Signal<? extends V> signal) {
        return Mono.fromRunnable(() -> {
            V value = signal.hasValue() ? signal.get() : null;
            if (value != null) {
                cache.put(key, value);
            }
        });
    }

    // One retry after 3 seconds for retryable TVM failures.
    private Retry retryServiceTicket() {
        return RetrySpec.fixedDelay(1, Duration.ofSeconds(3)).filter(BlackboxAuthChecker::isRetryableTvmError);
    }

    // One retry after 3 seconds for retryable Blackbox failures.
    private Retry retryBlackbox() {
        return RetrySpec.fixedDelay(1, Duration.ofSeconds(3)).filter(BlackboxAuthChecker::isRetryableBlackboxError);
    }

    /**
     * HTTP responses are retried only on 5xx status codes. IllegalStateException is never retried.
     * Any other error is retried.
     */
    private static boolean isRetryableTvmError(Throwable error) {
        if (error instanceof WebClientResponseException) {
            return isServerError((WebClientResponseException) error);
        }
        return !(error instanceof IllegalStateException);
    }

    /**
     * Same rules as for TVM, except that Blackbox application errors are retried only for DB_EXCEPTION.
     * HTTP errors are still classified first.
     */
    private static boolean isRetryableBlackboxError(Throwable error) {
        if (!(error instanceof WebClientResponseException) && error instanceof BlackboxException) {
            return ((BlackboxException) error).getId() == BlackboxException.DB_EXCEPTION;
        }
        return isRetryableTvmError(error);
    }

    private static boolean isServerError(WebClientResponseException exception) {
        int status = exception.getRawStatusCode();
        return status >= 500 && status < 600;
    }

    /** Cache key for session id checks; toString masks the cookie values. */
    private static final class SessionIdKey {

        private final String sessionId;
        private final String sslSessionId;
        private final String userIp;

        private SessionIdKey(String sessionId, String sslSessionId, String userIp) {
            this.sessionId = sessionId;
            this.sslSessionId = sslSessionId;
            this.userIp = userIp;
        }

        public String getSessionId() {
            return sessionId;
        }

        public Optional<String> getSslSessionId() {
            return Optional.ofNullable(sslSessionId);
        }

        public String getUserIp() {
            return userIp;
        }

        @Override
        public boolean equals(Object other) {
            if (other == this) {
                return true;
            }
            if (!(other instanceof SessionIdKey)) {
                return false;
            }
            SessionIdKey key = (SessionIdKey) other;
            return Objects.equals(sessionId, key.sessionId)
                    && Objects.equals(sslSessionId, key.sslSessionId)
                    && Objects.equals(userIp, key.userIp);
        }

        @Override
        public int hashCode() {
            return Objects.hash(sessionId, sslSessionId, userIp);
        }

        @Override
        public String toString() {
            return "SessionIdKey{" +
                    "sessionId='***'" +
                    ", sslSessionId='***'" +
                    ", userIp='" + userIp + '\'' +
                    '}';
        }

    }

    /** Cache key for OAuth token checks; toString masks the token. */
    private static final class OAuthTokenKey {

        private final String oauthToken;
        private final String userIp;

        private OAuthTokenKey(String oauthToken, String userIp) {
            this.oauthToken = oauthToken;
            this.userIp = userIp;
        }

        public String getOauthToken() {
            return oauthToken;
        }

        public String getUserIp() {
            return userIp;
        }

        @Override
        public boolean equals(Object other) {
            if (other == this) {
                return true;
            }
            if (!(other instanceof OAuthTokenKey)) {
                return false;
            }
            OAuthTokenKey key = (OAuthTokenKey) other;
            return Objects.equals(oauthToken, key.oauthToken)
                    && Objects.equals(userIp, key.userIp);
        }

        @Override
        public int hashCode() {
            return Objects.hash(oauthToken, userIp);
        }

        @Override
        public String toString() {
            return "OAuthTokenKey{" +
                    "oauthToken='***'" +
                    ", userIp='" + userIp + '\'' +
                    '}';
        }

    }

}
