package ru.yandex.infra.stage;

import java.time.Duration;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.infra.controller.metrics.GaugeRegistry;
import ru.yandex.infra.controller.metrics.GolovanableGauge;
import ru.yandex.infra.controller.util.ExceptionUtils;
import ru.yandex.infra.controller.yp.YpObjectStatusRepository;
import ru.yandex.infra.stage.concurrent.SerialExecutor;
import ru.yandex.infra.stage.dto.StageStatus;
import ru.yandex.infra.stage.protobuf.Converter;
import ru.yandex.infra.stage.util.AdaptiveRateLimiter;
import ru.yandex.infra.stage.util.GeneralUtils;
import ru.yandex.infra.stage.util.StoppableBackOff;
import ru.yandex.infra.stage.util.StoppableBackOffExecution;
import ru.yandex.yp.client.api.TStageSpec;
import ru.yandex.yp.client.api.TStageStatus;
import ru.yandex.yp.model.YpErrorCodes;

/**
 * Sends stage statuses to YP, coalescing updates per stage.
 *
 * <p>A send chain is started for a stage only when no status is already scheduled for it. If {@link #save} is
 * called while a send for the same stage is still in progress, the newly provided status replaces the scheduled one.
 * When the in-flight request succeeds, that newer status is sent next. Outgoing requests are gated by an
 * {@link AdaptiveRateLimiter}. Failed attempts are handed back to {@link SerialExecutor#executeOrRetry} together with
 * a backoff execution. NOTE(review): presumably retried until the backoff is stopped; confirm against the
 * {@code SerialExecutor} implementation.
 */
public class StageStatusSenderImpl implements StageStatusSender {
    private static final Logger LOG = LoggerFactory.getLogger(StageStatusSenderImpl.class);

    static final String METRIC_STATUS_UPDATE_SCHEDULED = "status_update_scheduled";
    static final String METRIC_STATUS_UPDATE_IN_PROGRESS = "status_update_in_progress";
    static final String METRIC_FAILED_SEND_STATUS_COUNT = "failed_send_status_count";
    static final String METRIC_SEND_STATUS_REQUESTS_TOTAL = "send_status_requests_total";
    static final String METRIC_SEND_STATUS_REQUESTS_FAILED = "send_status_requests_failed";
    static final String METRIC_SEND_STATUS_RPS_LIMIT = "send_status_rps_limit";

    private final YpObjectStatusRepository<TStageSpec, TStageStatus> ypRepository;
    private final SerialExecutor serialExecutor;
    private final Converter converter;
    private final AdaptiveRateLimiter rateLimiter;
    private final StoppableBackOff backoff;

    // Metrics may be accessed from other threads
    // Latest status waiting to be sent, per stage id; a present key means a send chain is active for that stage.
    private final Map<String, TStageStatus> scheduledStatusesByStageId = new ConcurrentHashMap<>();
    // Stage ids whose last send attempt failed and whose status is still scheduled.
    private final Set<String> stagesWithFailedSendStatus = ConcurrentHashMap.newKeySet();
    // Futures of active send chains, per stage id.
    private final Map<String, CompletableFuture<Boolean>> inProgressFutures = new ConcurrentHashMap<>();
    private final AtomicLong metricSendStatusRequestsTotal = new AtomicLong();
    private final AtomicLong metricSendStatusRequestsFailed = new AtomicLong();

    /**
     * Creates the sender and registers its gauges in {@code registry}.
     *
     * @param ypRepository        repository used to write stage statuses to YP
     * @param serialExecutor      executor that runs send attempts and their success/error handlers
     * @param registry            gauge registry receiving this sender's metrics
     * @param converter           converts {@link StageStatus} DTOs to {@link TStageStatus} protos
     * @param rateLimiter         limiter gating outgoing status requests
     * @param initialRetryTimeout initial retry delay passed to the backoff factory
     * @param maxRetryTimeout     maximum retry delay passed to the backoff factory
     */
    public StageStatusSenderImpl(YpObjectStatusRepository<TStageSpec, TStageStatus> ypRepository,
                                 SerialExecutor serialExecutor,
                                 GaugeRegistry registry,
                                 Converter converter,
                                 AdaptiveRateLimiter rateLimiter,
                                 Duration initialRetryTimeout,
                                 Duration maxRetryTimeout) {
        this.ypRepository = ypRepository;
        this.serialExecutor = serialExecutor;
        this.converter = converter;
        this.rateLimiter = rateLimiter;
        this.backoff = GeneralUtils.getTimeoutFactory(initialRetryTimeout.toMillis(), maxRetryTimeout.toMillis());

        registry.add(METRIC_STATUS_UPDATE_SCHEDULED, scheduledStatusesByStageId::size);
        registry.add(METRIC_STATUS_UPDATE_IN_PROGRESS, inProgressFutures::size);
        registry.add(METRIC_FAILED_SEND_STATUS_COUNT, stagesWithFailedSendStatus::size);
        registry.add(METRIC_SEND_STATUS_RPS_LIMIT, rateLimiter::getRate);
        registry.add(METRIC_SEND_STATUS_REQUESTS_TOTAL, new GolovanableGauge<>(metricSendStatusRequestsTotal::get, "dmmm"));
        registry.add(METRIC_SEND_STATUS_REQUESTS_FAILED, new GolovanableGauge<>(metricSendStatusRequestsFailed::get, "dmmm"));
    }

    /**
     * Schedules {@code status} to be sent to YP for {@code stageId}.
     *
     * <p>If a send for this stage is already in progress, only the scheduled status is replaced. It is picked up
     * after the current request completes successfully.
     */
    @Override
    public void save(String stageId, StageStatus status) {
        save(stageId, converter.toProto(status));
    }

    private void save(String stageId, TStageStatus status) {
        // put() returning null means no send chain is active for this stage yet, so start one.
        if (scheduledStatusesByStageId.put(stageId, status) == null) {
            inProgressFutures.put(stageId, saveStatus(stageId));
        } else {
            LOG.info("Postpone sending status for stage {} to yp, as not received result from previous request", stageId);
        }
    }

    /**
     * Drops any scheduled (not yet sent) status for {@code stageId} and clears its failed-send mark.
     *
     * <p>A request already in flight is not aborted.
     */
    @Override
    public void cancelScheduledStatusUpdate(String stageId) {
        scheduledStatusesByStageId.remove(stageId);
        stagesWithFailedSendStatus.remove(stageId);
    }

    @VisibleForTesting
    public Map<String, CompletableFuture<Boolean>> getInProgressFutures() {
        return inProgressFutures;
    }

    /**
     * Starts a send chain for {@code stageId} on the serial executor, using a fresh backoff execution.
     *
     * @return future returned by {@link SerialExecutor#executeOrRetry} for this chain
     */
    private CompletableFuture<Boolean> saveStatus(String stageId) {
        // Remembers which status object the latest attempt actually sent, so the success handler can detect
        // whether a newer one was scheduled in the meantime.
        AtomicReference<TStageStatus> statusInProgress = new AtomicReference<>();

        final StoppableBackOffExecution backOffExecution = backoff.startStoppable();
        return serialExecutor.executeOrRetry(() -> ypUpdateStageStatus(stageId, statusInProgress),
                response -> processSendStatusSuccess(stageId, statusInProgress, response),
                error -> processSendStatusError(stageId, error, backOffExecution),
                backOffExecution);
    }

    /** Returns true if {@code error} carries a YP REQUEST_THROTTLED or ACCOUNT_LIMIT_EXCEEDED error code. */
    private boolean isYPLimitsReached(Throwable error) {
        return ExceptionUtils.tryExtractYpError(error)
                .filter(ypError -> ypError.getCode() == YpErrorCodes.REQUEST_THROTTLED ||
                        ypError.getCode() == YpErrorCodes.ACCOUNT_LIMIT_EXCEEDED).isPresent();
    }

    /**
     * Performs one send attempt with the currently scheduled status for {@code stageId}.
     *
     * @return {@code false} immediately if nothing is scheduled (e.g. cancelled); a failed future if the rate
     *         limiter denies the request; otherwise the repository's save future
     */
    private CompletableFuture<Boolean> ypUpdateStageStatus(String stageId, AtomicReference<TStageStatus> statusInProgress) {
        TStageStatus newStatus = scheduledStatusesByStageId.get(stageId);
        statusInProgress.set(newStatus);
        if (newStatus == null) {
            return CompletableFuture.completedFuture(false);
        }
        if (!rateLimiter.tryAcquire()) {
            LOG.info("Skipping updated stage statuses due to rate limits. Queue size {}, InProgress {}, Failed {}",
                    scheduledStatusesByStageId.size(), inProgressFutures.size(), stagesWithFailedSendStatus.size());
            return CompletableFuture.failedFuture(new RuntimeException("Status update postponed due to rate limits for stage: " + stageId));
        }

        LOG.info("Sending status for stage: {}", stageId);
        metricSendStatusRequestsTotal.incrementAndGet();

        // Paired with the decrement below, once the request completes (successfully or not).
        rateLimiter.incrementAndGet();
        return ypRepository.saveStatus(stageId, newStatus)
                .whenComplete((x, error) -> rateLimiter.decrementAndGet());
    }

    /**
     * Completes a send chain after a successful attempt.
     *
     * <p>If a different status was scheduled while the request was in flight, a new chain is started for it.
     */
    private void processSendStatusSuccess(String stageId, AtomicReference<TStageStatus> statusInProgress, boolean statusWasSaved) {
        stagesWithFailedSendStatus.remove(stageId);
        inProgressFutures.remove(stageId);
        if (statusWasSaved) {
            LOG.info("Stage status was sent successfully: {}", stageId);
        } else {
            LOG.info("Stage {} has been removed, will drop stored status on next spec update", stageId);
        }

        TStageStatus targetStatus = scheduledStatusesByStageId.remove(stageId);
        TStageStatus storedStatus = statusInProgress.get();
        // If the target status was updated while sending the previous one, send it too. Reference comparison:
        // any newly scheduled status object triggers a resend.
        // targetStatus is null when the update was cancelled meanwhile. There is nothing to resend then, and
        // ConcurrentHashMap rejects null values (putIfAbsent(stageId, null) would throw NullPointerException).
        if (targetStatus != null && targetStatus != storedStatus) {
            if (scheduledStatusesByStageId.putIfAbsent(stageId, targetStatus) == null) {
                inProgressFutures.put(stageId, saveStatus(stageId));
            }
        }
    }

    /**
     * Records a failed attempt.
     *
     * <p>If the stage's update was cancelled in the meantime, the chain is finished and its backoff stopped.
     * Otherwise the stage is marked as failed. NOTE(review): presumably the executor then retries per
     * {@code backOffExecution}; confirm.
     */
    private void processSendStatusError(String stageId, Throwable error, StoppableBackOffExecution backOffExecution) {
        metricSendStatusRequestsFailed.incrementAndGet();
        LOG.error("Failed to send status for stage: {}", stageId, ExceptionUtils.stripCompletionException(error));
        if (isYPLimitsReached(error)) {
            rateLimiter.registerFailedResponse();
        }

        if (scheduledStatusesByStageId.get(stageId) == null) {
            inProgressFutures.remove(stageId);
            stagesWithFailedSendStatus.remove(stageId);
            backOffExecution.stop();
            return;
        }

        stagesWithFailedSendStatus.add(stageId);
    }
}
