package ru.yandex.direct.api.v5.ratelimit;

import java.time.Clock;
import java.time.Duration;
import java.util.Map;
import java.util.Optional;

import javax.annotation.ParametersAreNonnullByDefault;

import io.lettuce.core.SetArgs;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Component;
import org.springframework.ws.context.MessageContext;
import org.springframework.ws.server.EndpointInterceptor;

import ru.yandex.direct.api.v5.security.ApiAuthenticationSource;
import ru.yandex.direct.api.v5.ws.WsUtils;
import ru.yandex.direct.api.v5.ws.annotation.ApiMethod;
import ru.yandex.direct.common.db.PpcPropertiesSupport;
import ru.yandex.direct.common.db.PpcProperty;
import ru.yandex.direct.common.db.PpcPropertyNames;
import ru.yandex.direct.common.lettuce.LettuceConnectionProvider;
import ru.yandex.direct.common.lettuce.LettuceExecuteException;

import static ru.yandex.direct.common.configuration.RedisConfiguration.LETTUCE;


/**
 * {@link RateLimitInterceptor} ограничивает частоту запросов к некоторым методам API
 * для клиентов и/или для приложений: можно делать не больше M запросов за N секунд.
 * <p>
 * 1) Чтобы задать ограничения на клиентов, метод надо пометить аннотацией, например:
 * <code>
 * {@literal @}RateLimitByClient(periodInSeconds = 60, maxRequests = 20)
 * {@literal @}ApiMethod(service = SERVICE_NAME, operation = "get")
 * public GetResponse get({@literal @}ApiRequest GetRequest parameters) throws ApiException {
 * ...
 * }
 * </code>
 * <p>
 * Аннотацию {@link RateLimitByClient} можно указывать несколько раз.
 * <p>
 * 2) Чтобы задать ограничения на приложения, необходимо задать проперти с максимальным количеством
 * запросов для метода, добавить ее в мапу {@link RateLimitInterceptor#maxRequestsByApplicationProperties},
 * а также пометить метод аннотацией, например:
 * <code>
 * {@literal @}RateLimitByApplication(periodInSeconds = 60)
 * {@literal @}ApiMethod(service = SERVICE_NAME, operation = "get")
 * public GetResponse get({@literal @}ApiRequest GetRequest parameters) throws ApiException {
 * ...
 * }
 * </code>
 * <p>
 */

@ParametersAreNonnullByDefault
@Component
public class RateLimitInterceptor implements EndpointInterceptor {
    private static final Logger logger = LoggerFactory.getLogger(RateLimitInterceptor.class);

    private final ApiAuthenticationSource apiAuthenticationSource;
    private final Clock clock;
    private LettuceConnectionProvider lettuce;
    private Map<Pair<String/*service*/, String/*operation*/>, PpcProperty<Integer>> maxRequestsByApplicationProperties;

    RateLimitInterceptor(
            @Qualifier(LETTUCE) LettuceConnectionProvider lettuce,
            ApiAuthenticationSource apiAuthenticationSource,
            PpcPropertiesSupport ppcPropertiesSupport, Clock clock) {
        this.lettuce = lettuce;
        this.apiAuthenticationSource = apiAuthenticationSource;
        this.clock = clock;
        PpcProperty<Integer> hasSearchVolumeMaxRequestsByApplicationProperty =
                ppcPropertiesSupport.get(PpcPropertyNames.HAS_SEARCH_VOLUME_MAX_REQUESTS_BY_APPLICATION,
                        Duration.ofMinutes(5));
        maxRequestsByApplicationProperties = Map.of(Pair.of("keywordsresearch", "hasSearchVolume"),
                hasSearchVolumeMaxRequestsByApplicationProperty);
    }

    @Autowired
    public RateLimitInterceptor(
            @Qualifier(LETTUCE) LettuceConnectionProvider connectionProvider,
            ApiAuthenticationSource apiAuthenticationSource,
            PpcPropertiesSupport ppcPropertiesSupport) {
        this(connectionProvider, apiAuthenticationSource, ppcPropertiesSupport, Clock.systemDefaultZone());
    }

    @Override
    public boolean handleRequest(MessageContext messageContext, Object endpoint) {
        RateLimitByClient[] rateLimits = WsUtils.getEndpointMethodAnnotations(endpoint, RateLimitByClient.class);
        ApiMethod apiMethod = WsUtils.getEndpointMethodMeta(endpoint, ApiMethod.class);

        long clientId = apiAuthenticationSource.getSubclient().getClientId().asLong();
        checkClientRateLimits(rateLimits, apiMethod, clientId);

        Optional<RateLimitByApplication> rateLimitOpt =
                WsUtils.getEndpointMethodMetaOpt(endpoint, RateLimitByApplication.class);
        String applicationId = apiAuthenticationSource.getApplicationId();
        rateLimitOpt.ifPresent(rateLimit -> checkApplicationRateLimits(rateLimit, apiMethod, applicationId));
        return true;
    }

    long incrAndGetRedisKey(String redisKey, long expireAfterSecs) {
        // Устанавливаем начальное кол-во запросов, если еще не устанавливали ранее (NX),
        // с expiration-ом (EX) через expireAfterSecs секунд

        try {
            lettuce.call("redis:set",
                    com -> com.set(redisKey, "0", SetArgs.Builder.nx().ex(expireAfterSecs))
            );
            return lettuce.call("redis:incr", com -> com.incr(redisKey));
        } catch (LettuceExecuteException e) {
            logger.error("Can't increment and get redisKey: {}", redisKey, e);
            // Считаем что отказ Redis-а должен быть относительно нечастым событием, поэтому выбираем
            // счастье пользователей, разрешая им немного превысить ограничение по кол-ву запросов
            return 0;
        }
    }

    /**
     * Увеличить кол-во сделанных запросов клиентом и вернуть текущее значение, начиная
     * с момента startOfPeriodInSec
     */
    private long incrAndGetRequestMadeByClient(
            ApiMethod apiMethod, long clientId, RateLimitByClient rateLimit, long startOfPeriodInSec) {
        String redisKey = String.format("api5-ratelimit-%s-%s-%s-%s-%s",
                apiMethod.service(),
                apiMethod.operation(),
                rateLimit.periodInSeconds(),
                clientId,
                startOfPeriodInSec);

        return incrAndGetRedisKey(redisKey, rateLimit.periodInSeconds());
    }

    private void checkClientRateLimits(RateLimitByClient[] rateLimits, ApiMethod apiMethod, long clientId) {
        long currentTime = clock.instant().getEpochSecond();

        for (RateLimitByClient rateLimit : rateLimits) {
            /*
             * Эта реализация проверяет, что клиент делает не больше rateLimit.maxRequests()
             * в окне времени от startOfPeriod до startOfPeriod + rateLimit.periodInSeconds().
             * Это окно времени меняется каждые rateLimit.periodInSeconds() секунд.
             */
            long startOfPeriod = currentTime - currentTime % rateLimit.periodInSeconds();

            long requestsMade = incrAndGetRequestMadeByClient(apiMethod, clientId, rateLimit, startOfPeriod);
            if (requestsMade > rateLimit.maxRequests()) {
                throw new RateLimitExceededException();
            }
        }
    }

    /**
     * Увеличить кол-во сделанных запросов приложением и вернуть текущее значение, начиная
     * с момента startOfPeriodInSec
     */
    private long incrAndGetRequestMadeByApplication(
            ApiMethod apiMethod, String applicationId, RateLimitByApplication rateLimit, long startOfPeriodInSec) {
        String redisKey = String.format("api5-ratelimit-app-%s-%s-%s-%s-%s",
                apiMethod.service(),
                apiMethod.operation(),
                rateLimit.periodInSeconds(),
                applicationId,
                startOfPeriodInSec);

        return incrAndGetRedisKey(redisKey, rateLimit.periodInSeconds());
    }

    private void checkApplicationRateLimits(RateLimitByApplication rateLimit, ApiMethod apiMethod, String applicationId) {
        long currentTime = clock.instant().getEpochSecond();

        /*
         * Эта реализация проверяет, что приложение делает не больше количества запросов,
         * указанного в проперти для метода apiMethod в окне времени от startOfPeriod до
         * startOfPeriod + rateLimit.periodInSeconds().
         * Это окно времени меняется каждые rateLimit.periodInSeconds() секунд.
         */
        long startOfPeriod = currentTime - currentTime % rateLimit.periodInSeconds();
        PpcProperty<Integer> maxRequestsByApplicationProperty =
                maxRequestsByApplicationProperties.get(Pair.of(apiMethod.service(), apiMethod.operation()));
        if (maxRequestsByApplicationProperty != null && maxRequestsByApplicationProperty.get() != null) {
            long requestsMadeByApplication = incrAndGetRequestMadeByApplication(apiMethod,
                    applicationId, rateLimit, startOfPeriod);
            if (requestsMadeByApplication > maxRequestsByApplicationProperty.get()) {
                throw new RateLimitExceededException();
            }
        }
    }

    @Override
    public boolean handleResponse(MessageContext messageContext, Object endpoint) throws Exception {
        return true;
    }

    @Override
    public boolean handleFault(MessageContext messageContext, Object endpoint) throws Exception {
        return true;
    }

    @Override
    public void afterCompletion(MessageContext messageContext, Object endpoint, Exception ex) throws Exception {
        /* всю работу делает handleRequest, здесь делать нечего */
    }
}
