package ru.yandex.travel.hotels.searcher.partners;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.util.JsonFormat;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import io.opentracing.noop.NoopTracerFactory;
import lombok.AccessLevel;
import lombok.Setter;
import org.asynchttpclient.AsyncHttpClient;
import org.asynchttpclient.Dsl;
import org.asynchttpclient.Request;
import org.asynchttpclient.RequestBuilder;
import org.asynchttpclient.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.core.env.Environment;

import ru.yandex.misc.thread.factory.ThreadNameThreadFactory;
import ru.yandex.travel.commons.concurrent.TerminationSemaphore;
import ru.yandex.travel.commons.health.HealthCheckedSupplier;
import ru.yandex.travel.commons.messaging.KeyValueStorage;
import ru.yandex.travel.commons.messaging.MessageBus;
import ru.yandex.travel.commons.metrics.MetricsUtils;
import ru.yandex.travel.commons.proto.EErrorCode;
import ru.yandex.travel.commons.proto.ProtoUtils;
import ru.yandex.travel.commons.proto.TError;
import ru.yandex.travel.commons.rate.Throttler;
import ru.yandex.travel.commons.retry.Retry;
import ru.yandex.travel.commons.retry.RetryRateLimiter;
import ru.yandex.travel.commons.retry.SpeculativeRetryStrategy;
import ru.yandex.travel.hotels.common.refunds.RefundRules;
import ru.yandex.travel.hotels.common.refunds.RefundType;
import ru.yandex.travel.hotels.proto.EPartnerId;
import ru.yandex.travel.hotels.proto.ERefundType;
import ru.yandex.travel.hotels.proto.ERequestClass;
import ru.yandex.travel.hotels.proto.TOffer;
import ru.yandex.travel.hotels.proto.TOfferDataMessage;
import ru.yandex.travel.hotels.proto.TPlaceholder;
import ru.yandex.travel.hotels.proto.TRefundRule;
import ru.yandex.travel.hotels.proto.TRequestAttribution;
import ru.yandex.travel.hotels.proto.TSearchOffersReq;
import ru.yandex.travel.hotels.proto.TSearcherMessage;
import ru.yandex.travel.hotels.searcher.BatchedTaskQueue;
import ru.yandex.travel.hotels.searcher.DefaultOfferSearchService;
import ru.yandex.travel.hotels.searcher.PartnerBean;
import ru.yandex.travel.hotels.searcher.QueueConsumer;
import ru.yandex.travel.hotels.searcher.Task;
import ru.yandex.travel.hotels.searcher.cold.ColdService;
import ru.yandex.travel.hotels.searcher.logging.HttpRequestLogger;

public abstract class AbstractPartnerTaskHandler<T extends AbstractPartnerTaskHandlerProperties> implements PartnerTaskHandler, DisposableBean, InitializingBean {
    /** Pseudo-purpose key under which HTTP metrics are registered without a "purpose" tag. */
    public static final String ALL_PURPOSES = "__all__";
    /** Request-id header name handed to the {@link HttpRequestLogger} created in afterPropertiesSet(). */
    private static final String HTTP_REQUEST_ID_HEADER = "x-ya-request-id";
    // NOTE(review): not used in the visible part of this file; presumably the upper bound for a valid offer
    // price (cf. invalidOfferPriceCounter) — confirm.
    private static final Long MAX_VALID_PRICE = 1_000_000_000L;
    /** Single-threaded scheduler whose thread is named "{partner}-handler-thread". */
    final ScheduledExecutorService executor;
    /** Async HTTP client whose thread pool is named after the partner id. */
    private final AsyncHttpClient client;
    /** One queue consumer per request class, built by queueConsumersForRequestClasses(). */
    private final ImmutableMap<ERequestClass, QueueConsumer> queueConsumers;

    // Acquired for every batch started in startNow(); released once the batch's YT writes finish. Acquisition
    // failing is treated as "shutdown has been initiated".
    private final TerminationSemaphore runningTaskTerminationSemaphore;
    // NOTE(review): not used in the visible part of this file.
    private final TerminationSemaphore droppedTaskTerminationSemaphore;
    /** Compact protobuf JSON printer that also prints fields holding default values. */
    private final JsonFormat.Printer jsonPrinter;
    /** Rate limiter shared by all speculative retries of this partner. */
    private final RetryRateLimiter speculativeRetryRateLimiter;
    /** Sequence number of started batches; used to number batches in log messages. */
    private final AtomicLong batchCount;
    /** Task counters keyed by task event, then by request class. */
    private final ImmutableMap<TaskEvent, ImmutableMap<ERequestClass, Counter>> taskCounters;
    /** Successfully handled tasks that produced no offers. */
    private final Counter emptyOfferListsCounter;
    /** Offers with unknown pansion in successfully handled, non-empty tasks. */
    private final Counter unknownPansionOffersCounter;
    // NOTE(review): incremented outside the visible part of this file.
    private final Counter invalidOfferPriceCounter;
    /** Speculative retries that passed the throttler and were actually executed. */
    private final Counter speculativeRetryCounter;
    // Per-request-class task timers; recorded outside the visible part of this file.
    private final ImmutableMap<ERequestClass, Timer> taskTotalTimers;
    private final ImmutableMap<ERequestClass, Timer> taskStallTimers;
    private final ImmutableMap<ERequestClass, Timer> taskExecuteTimers;
    /** Number of offers produced by successfully handled tasks, per request class. */
    private final ImmutableMap<ERequestClass, Counter> offerCounters;
    /** HTTP request counters: request class -> purpose (including ALL_PURPOSES) -> status bucket. */
    private final ImmutableMap<ERequestClass, ImmutableMap<String, ImmutableMap<HttpStatus, Counter>>> httpCounters;
    /** HTTP byte counters: request class -> purpose (including ALL_PURPOSES). */
    private final ImmutableMap<ERequestClass, ImmutableMap<String, Counter>> httpByteCounters;
    /** HTTP call timers: request class -> purpose (including ALL_PURPOSES). */
    private final ImmutableMap<ERequestClass, ImmutableMap<String, Timer>> httpTimers;
    /** Batch sizes of started batches, per request class. */
    private final ImmutableMap<ERequestClass, DistributionSummary> batchDistributionSummaries;
    // Maps 2..5 to HTTP_2XX..HTTP_5XX; presumably looked up with (status code / 100) — confirm at the call site.
    private final ImmutableMap<Integer, HttpStatus> integerToStatus;
    /** HTTP request logger; created in afterPropertiesSet() once {@code env} has been injected. */
    protected HttpRequestLogger httpLogger;
    /** Logger of the concrete handler class. */
    protected Logger logger;
    /** Shared logger named "ru.yandex.travel.hotels.searcher.ResultLogger". */
    protected Logger resultLogger;
    /** Partner id taken from the {@link PartnerBean} annotation of the concrete class. */
    protected EPartnerId partnerId;
    /** Partner configuration passed to the constructor. */
    protected T config;
    /** Runs partner requests with speculative retries (using a no-op tracer). */
    private final Retry retryHelper;
    // NOTE(review): used outside the visible part of this file (presumably by putTaskMessages) — confirm.
    @Autowired
    @Qualifier("healthCheckedKeyValueStorageSupplier")
    HealthCheckedSupplier<KeyValueStorage> storageSupplier;
    // Storage that receives TOfferDataMessage entries for the offers of each task.
    @Autowired
    @Qualifier("healthCheckedSearchFlowOfferDataStorageSupplier")
    @Setter(AccessLevel.PACKAGE)
    HealthCheckedSupplier<KeyValueStorage> searchFlowOfferDataStorageSupplier;
    // NOTE(review): used outside the visible part of this file.
    @Autowired
    @Qualifier("messageBus")
    @Setter(AccessLevel.PACKAGE)
    private MessageBus messageBus;
    // Spring environment; passed to the HttpRequestLogger.
    @Autowired
    @Setter(AccessLevel.PACKAGE)
    private Environment env;
    // Computes the cache lifetime of successfully handled tasks.
    @Autowired
    @Setter(AccessLevel.PACKAGE)
    private ColdService coldService;

    /**
     * Creates the handler and registers all of its metrics.
     *
     * <p>The partner id is read from the {@link PartnerBean} annotation of the concrete class. Besides plain field
     * initialisation, this builds one {@link QueueConsumer} per request class. It then registers counters, timers,
     * distribution summaries and gauges in {@link Metrics#globalRegistry}, tagged with the partner id and, where
     * applicable, the request class. {@link ERequestClass#UNRECOGNIZED} never gets its own metrics.
     *
     * <p>NOTE(review): {@code getHttpCallPurposes()} is invoked from this constructor, i.e. before any subclass
     * constructor has run. An implementation must therefore not depend on subclass instance state.
     *
     * @param config partner configuration; also supplies queue, throttling and speculative-retry settings
     */
    AbstractPartnerTaskHandler(T config) {
        partnerId = this.getClass().getDeclaredAnnotation(PartnerBean.class).value();
        String partner = partnerId.toString();
        executor = Executors.newSingleThreadScheduledExecutor(
                new ThreadNameThreadFactory(partner + "-handler-thread"));
        this.config = config;
        this.logger = LoggerFactory.getLogger(this.getClass());
        this.resultLogger = LoggerFactory.getLogger("ru.yandex.travel.hotels.searcher.ResultLogger");
        this.jsonPrinter = JsonFormat.printer().omittingInsignificantWhitespace().includingDefaultValueFields();
        this.speculativeRetryRateLimiter = new RetryRateLimiter(config.getSpeculativeRetryRateLimit());
        this.retryHelper = new Retry(NoopTracerFactory.create());
        this.client = Dsl.asyncHttpClient(Dsl.config()
                .setThreadPoolName(partner)
                .build());

        this.queueConsumers = AbstractPartnerTaskHandler.queueConsumersForRequestClasses(
                config,
                partnerId,
                this::onRateLimitTaskDroppedAction,
                this::onConcurrentLimitTaskDroppedAction,
                this::startNow
        );

        runningTaskTerminationSemaphore = new TerminationSemaphore(logger, "Main");
        droppedTaskTerminationSemaphore = new TerminationSemaphore(logger, "Dropped");

        // Task counters.
        emptyOfferListsCounter = Metrics.counter("searcher.partners.tasks.emptyOfferLists", "partner", partner);
        unknownPansionOffersCounter = Metrics.counter("searcher.partners.tasks.unknownPansionOffers", "partner",
                partner);
        invalidOfferPriceCounter = Metrics.counter("searcher.partners.tasks.invalidOfferPrices", "partner", partner);
        speculativeRetryCounter = Metrics.counter("searcher.partners.tasks.speculativeRetries", "partner", partner);
        EnumMap<TaskEvent, ImmutableMap<ERequestClass, Counter>> countersByEvent = new EnumMap<>(TaskEvent.class);
        for (TaskEvent event : TaskEvent.values()) {
            // Locale.ROOT keeps metric names independent of the JVM default locale (e.g. Turkish dotless 'i').
            String metricName = "searcher.partners.tasks." + event.toString().toLowerCase(Locale.ROOT) + "Count";
            countersByEvent.put(event, buildPerRequestClass(clazz -> Counter.builder(metricName)
                    .tag("partner", partner)
                    .tag("class", clazz.toString())
                    .register(Metrics.globalRegistry)));
        }
        taskCounters = ImmutableMap.copyOf(countersByEvent);

        offerCounters = buildPerRequestClass(clazz -> Counter.builder("searcher.partners.tasks.offers")
                .tag("partner", partner)
                .tag("class", clazz.toString())
                .register(Metrics.globalRegistry));

        // Task timers.
        taskTotalTimers = buildPerRequestClass(
                clazz -> registerPartnerTimer(Timer.builder("searcher.partners.tasks.totalTime"), partner, clazz));
        taskStallTimers = buildPerRequestClass(
                clazz -> registerPartnerTimer(Timer.builder("searcher.partners.tasks.stallTime"), partner, clazz));
        taskExecuteTimers = buildPerRequestClass(
                clazz -> registerPartnerTimer(Timer.builder("searcher.partners.tasks.executeTime"), partner, clazz));

        // Batch distribution summaries.
        batchDistributionSummaries = buildPerRequestClass(
                clazz -> DistributionSummary.builder("searcher.partners.tasks.batchSize")
                        .tag("partner", partner)
                        .tag("class", clazz.toString())
                        .serviceLevelObjectives(MetricsUtils.countSla())
                        .publishPercentiles(MetricsUtils.iqrPercentiles())
                        .register(Metrics.globalRegistry));

        // ALL_PURPOSES registers an extra metric per class without a "purpose" tag.
        List<String> purposes = new ArrayList<>(getHttpCallPurposes());
        purposes.add(ALL_PURPOSES);
        integerToStatus = ImmutableMap.of(
                2, HttpStatus.HTTP_2XX,
                3, HttpStatus.HTTP_3XX,
                4, HttpStatus.HTTP_4XX,
                5, HttpStatus.HTTP_5XX);

        // HTTP request counters: request class -> purpose -> status bucket.
        httpCounters = buildPerRequestClass(clazz -> {
            Map<String, ImmutableMap<HttpStatus, Counter>> countersByPurpose = new HashMap<>();
            for (String purpose : purposes) {
                EnumMap<HttpStatus, Counter> countersByStatus = new EnumMap<>(HttpStatus.class);
                for (HttpStatus status : HttpStatus.values()) {
                    Counter.Builder builder = Counter.builder("searcher.partners.http.requests")
                            .tag("partner", partner)
                            .tag("status", httpStatusTagValue(status))
                            .tag("class", clazz.toString());
                    if (!purpose.equals(ALL_PURPOSES)) {
                        builder.tag("purpose", purpose);
                    }
                    countersByStatus.put(status, builder.register(Metrics.globalRegistry));
                }
                countersByPurpose.put(purpose, ImmutableMap.copyOf(countersByStatus));
            }
            return ImmutableMap.copyOf(countersByPurpose);
        });

        // HTTP byte counters: request class -> purpose.
        httpByteCounters = buildPerRequestClass(clazz -> {
            Map<String, Counter> countersByPurpose = new HashMap<>();
            for (String purpose : purposes) {
                Counter.Builder builder = Counter.builder("searcher.partners.http.bytes")
                        .tag("partner", partner)
                        .tag("class", clazz.toString());
                if (!purpose.equals(ALL_PURPOSES)) {
                    builder.tag("purpose", purpose);
                }
                countersByPurpose.put(purpose, builder.register(Metrics.globalRegistry));
            }
            return ImmutableMap.copyOf(countersByPurpose);
        });

        // HTTP timers: request class -> purpose.
        httpTimers = buildPerRequestClass(clazz -> {
            Map<String, Timer> timersByPurpose = new HashMap<>();
            for (String purpose : purposes) {
                Timer.Builder builder = Timer.builder("searcher.partners.http.time");
                if (!purpose.equals(ALL_PURPOSES)) {
                    builder.tag("purpose", purpose);
                }
                timersByPurpose.put(purpose, registerPartnerTimer(builder, partner, clazz));
            }
            return ImmutableMap.copyOf(timersByPurpose);
        });

        // Gauges.
        queueConsumers.forEach(
                (requestClass, queueConsumer) -> {
                    var throttler = queueConsumer.getThrottler();
                    var queue = queueConsumer.getQueue();
                    Gauge.builder("searcher.partners.semaphoreValue", throttler, Throttler::getSemaphoreValue)
                            .tag("partner", partner)
                            .tag("class", requestClass.toString())
                            .register(Metrics.globalRegistry);
                    Gauge.builder("searcher.partners.semaphoreLimit", throttler, Throttler::getSemaphoreLimit)
                            .tag("partner", partner)
                            .tag("class", requestClass.toString())
                            .register(Metrics.globalRegistry);
                    Gauge.builder("searcher.partners.queueSize", queue, BatchedTaskQueue::getSize)
                            .tag("partner", partner)
                            .tag("class", requestClass.toString())
                            .register(Metrics.globalRegistry);
                }
        );
        this.batchCount = new AtomicLong(0);
    }

    /**
     * Builds an immutable map holding {@code factory.apply(clazz)} for every request class except
     * {@link ERequestClass#UNRECOGNIZED}, iterating in enum declaration order.
     */
    private static <V> ImmutableMap<ERequestClass, V> buildPerRequestClass(Function<ERequestClass, V> factory) {
        EnumMap<ERequestClass, V> result = new EnumMap<>(ERequestClass.class);
        for (ERequestClass clazz : ERequestClass.values()) {
            if (clazz != ERequestClass.UNRECOGNIZED) {
                result.put(clazz, factory.apply(clazz));
            }
        }
        return ImmutableMap.copyOf(result);
    }

    /** Tags a timer with partner and request class, adds the standard SLOs/percentiles and registers it globally. */
    private static Timer registerPartnerTimer(Timer.Builder builder, String partner, ERequestClass clazz) {
        return builder
                .tag("partner", partner)
                .tag("class", clazz.toString())
                .serviceLevelObjectives(MetricsUtils.mediumDurationSla())
                .publishPercentiles(MetricsUtils.higherPercentiles())
                .register(Metrics.globalRegistry);
    }

    /** Returns the "status" tag value of an {@link HttpStatus}: its lower-cased name without the "http_" prefix. */
    private static String httpStatusTagValue(HttpStatus status) {
        String name = status.name().toLowerCase(Locale.ROOT);
        String prefix = "http_";
        return name.startsWith(prefix) ? name.substring(prefix.length()) : name;
    }


    /**
     * Builds the queue consumers of a partner, one per request class.
     *
     * <p>Background consumers have a limited queue and do not drop tasks on rate/concurrency limit violations.
     * Interactive and calendar consumers drop such tasks through the given callbacks; only the calendar consumer
     * gets a queue size limit here. Every consumer gets its own single-threaded queue consumer executor and an
     * unconfigurable scheduled task executor sized from {@code config}.
     *
     * @param config                 partner configuration providing limits, batch size and pool sizes
     * @param partnerId              partner the consumers belong to
     * @param onRateLimitDrop        invoked for interactive/calendar tasks dropped by the rate limiter
     * @param onConcurrencyLimitDrop invoked for interactive/calendar tasks dropped by the concurrency limiter
     * @param taskStarter            starts a batch of tasks taken from a queue
     * @return consumers keyed by their request class, in background, interactive, calendar order
     */
    public static <T extends AbstractPartnerTaskHandlerProperties> ImmutableMap<ERequestClass, QueueConsumer> queueConsumersForRequestClasses(
            T config,
            EPartnerId partnerId,
            Consumer<? super Task> onRateLimitDrop,
            Consumer<? super Task> onConcurrencyLimitDrop,
            QueueConsumer.TaskStarter taskStarter
    ) {
        QueueConsumer background = new QueueConsumer.Builder(partnerId, ERequestClass.RC_BACKGROUND)
                .withTaskStarter(taskStarter)
                .withThrottlerParams(
                        config.getBackgroundRateLimit(), config.getBackgroundConcurrencyLimit(),
                        config.getRateLimiterBucket(), config.getRateLimiterWindow())
                .withLimitedQueueSize(config.getBackgroundQueueLimit())
                .withBatchSize(config.getMaxBatchSize())
                .dontDropOnLimit(config.getRateLimiterBucket(), config.getConcurrencyReschedulePeriod())
                .withQueueConsumerExecutor(Executors::newSingleThreadScheduledExecutor)
                .withTaskExecutor(threadFactory -> Executors.unconfigurableScheduledExecutorService(
                        Executors.newScheduledThreadPool(config.getBackgroundThreadPoolSize(), threadFactory)))
                .build();
        QueueConsumer interactive = new QueueConsumer.Builder(partnerId, ERequestClass.RC_INTERACTIVE)
                .withTaskStarter(taskStarter)
                .withThrottlerParams(
                        config.getInteractiveRateLimit(), config.getInteractiveConcurrencyLimit(),
                        config.getRateLimiterBucket(), config.getRateLimiterWindow())
                .withDropOnLimitActions(onRateLimitDrop, onConcurrencyLimitDrop)
                .withBatchSize(config.getMaxBatchSize())
                .withQueueConsumerExecutor(Executors::newSingleThreadScheduledExecutor)
                .withTaskExecutor(threadFactory -> Executors.unconfigurableScheduledExecutorService(
                        Executors.newScheduledThreadPool(config.getInteractiveThreadPoolSize(), threadFactory)))
                .build();
        QueueConsumer calendar = new QueueConsumer.Builder(partnerId, ERequestClass.RC_CALENDAR)
                .withTaskStarter(taskStarter)
                .withThrottlerParams(
                        config.getCalendarRateLimit(), config.getCalendarConcurrencyLimit(),
                        config.getRateLimiterBucket(), config.getRateLimiterWindow())
                .withLimitedQueueSize(config.getCalendarQueueLimit())
                .withDropOnLimitActions(onRateLimitDrop, onConcurrencyLimitDrop)
                .withBatchSize(config.getMaxBatchSize())
                .withQueueConsumerExecutor(Executors::newSingleThreadScheduledExecutor)
                .withTaskExecutor(threadFactory -> Executors.unconfigurableScheduledExecutorService(
                        Executors.newScheduledThreadPool(config.getCalendarThreadPollSize(), threadFactory)))
                .build();

        ImmutableMap.Builder<ERequestClass, QueueConsumer> consumersByClass = ImmutableMap.builder();
        for (QueueConsumer consumer : List.of(background, interactive, calendar)) {
            consumersByClass.put(consumer.getRequestClass(), consumer);
        }
        return consumersByClass.build();
    }

    /**
     * Creates the HTTP request logger. This runs after dependency injection because the logger needs the injected
     * {@code env}.
     */
    @Override
    public void afterPropertiesSet() {
        Logger requestLog = LoggerFactory.getLogger("ru.yandex.travel.hotels.searcher.HttpLogger");
        httpLogger = new HttpRequestLogger(
                HTTP_REQUEST_ID_HEADER,
                requestLog,
                env,
                isMocked(),
                partnerId.getValueDescriptor().getName());
    }

    /////////////////
    /**
     * Validates the given tasks and enqueues the valid ones for execution.
     *
     * <p>Each task is first passed to {@code checkTask}. A task for which the check throws is rejected immediately:
     * it is tracked, exported to YT and completed. All remaining tasks are then accepted and handed to the
     * per-request-class queues.
     *
     * @param tasks tasks to handle; must not be empty
     * @throws IllegalArgumentException if {@code tasks} is empty
     */
    @Override
    public void startHandle(List<Task> tasks) {
        Preconditions.checkArgument(!tasks.isEmpty());
        List<Task> acceptedTasks = new ArrayList<>(tasks.size());
        for (Task task : tasks) {
            try {
                checkTask(task);
            } catch (Throwable throwable) {
                // Exceptions without a message would otherwise be reported as "rejected: null".
                String message = throwable.getMessage();
                rejectTask(task, message != null ? message : throwable.toString());
                continue;
            }
            acceptedTasks.add(task);
        }
        acceptedTasks.forEach(this::acceptTask);
        startTasks(acceptedTasks);
    }

    /**
     * Finishes a task that failed validation: logs the reason, tracks it as REJECTED, exports it to YT with a
     * {@code "rejected: <reason>"} error and completes it.
     */
    private void rejectTask(Task task, String reason) {
        logger.info("Task {}: rejected: {}", task.getId(), reason);
        track(task, TaskEvent.REJECTED);
        RuntimeException rejection = new RuntimeException("rejected: " + reason);
        exportToYT(rejection, task, false);
        task.onComplete();
    }

    /** Logs, tracks (ACCEPTED) and exports to YT a task that passed validation and is about to be queued. */
    private void acceptTask(Task task) {
        var offerCacheClientId = task.getRequest().getAttribution().getOfferCacheClientId();
        logger.info("Task {}: accepted for execution (OfferCacheClientId={})", task.getId(), offerCacheClientId);
        track(task, TaskEvent.ACCEPTED);
        exportToYT(null, task, true);
    }

    /**
     * Rejects tasks that are guaranteed to fail on the partner side. An example is a request that is incompatible
     * with the partner's API and would only produce an HTTP 400 error.
     * <p>
     * Tasks that are invalid from our own (not the partner's) point of view should be filtered out earlier in the
     * pipeline, in {@link DefaultOfferSearchService#validateRequest(TSearchOffersReq)}.
     * <p>
     * The base implementation rejects every task while the partner is disabled. Any exception thrown from here
     * causes {@code startHandle} to reject the task, using the exception message as the rejection reason.
     *
     * @param task the incoming task
     * @throws IllegalStateException if the partner is disabled in the configuration
     */
    protected void checkTask(Task task) {
        Preconditions.checkState(config.isEnabled(), "Partner is disabled");
    }

    /** Drop callback for tasks rejected by the rate limiter. */
    private void onRateLimitTaskDroppedAction(Task t) {
        dropThrottledTask(t, "rate limit", "Request rate limit exceeded");
    }

    /** Drop callback for tasks rejected by the concurrency limiter. */
    private void onConcurrentLimitTaskDroppedAction(Task t) {
        dropThrottledTask(t, "concurrency limit", "Too many concurrent requests");
    }

    /**
     * Shared implementation of the two drop callbacks above.
     *
     * <p>Fails the task with {@link EErrorCode#EC_RESOURCE_EXHAUSTED}, exports the drop to YT, tracks it as DROPPED
     * and completes the task.
     *
     * @param task          task that was dropped
     * @param violatedLimit human-readable name of the violated limit, used in the log line
     * @param errorMessage  message put into the task's error
     */
    private void dropThrottledTask(Task task, String violatedLimit, String errorMessage) {
        logger.info("Task {}: dropped due to {} violation", task.getId(), violatedLimit);
        task.onError(TError.newBuilder()
                .setCode(EErrorCode.EC_RESOURCE_EXHAUSTED)
                .setMessage(errorMessage));
        dropToYt(task);
        track(task, TaskEvent.DROPPED);
        task.onComplete();
    }

    /**
     * Offers every task to the queue of its request class and then nudges all queue consumers to run.
     * A task whose queue refuses it (size limit) is failed with EC_RESOURCE_EXHAUSTED, tracked as DROPPED,
     * exported to YT and completed.
     */
    private void startTasks(List<Task> tasks) {
        for (Task task : tasks) {
            var requestClass = task.getRequest().getRequestClass();
            boolean enqueued = queueConsumers.get(requestClass).getQueue().offer(task);
            if (enqueued) {
                continue;
            }
            task.onError(TError.newBuilder()
                    .setCode(EErrorCode.EC_RESOURCE_EXHAUSTED)
                    .setMessage("Request queue size limit exceeded"));
            track(task, TaskEvent.DROPPED);
            logger.info("Task {}: dropped due to {} queue size limit violation", task.getId(), requestClass);
            dropToYt(task);
            task.onComplete();
        }
        queueConsumers.values().forEach(QueueConsumer::trySchedulingQueueConsumer);
    }

    /**
     * Executes one batch of tasks that share a grouping key. Used as the task starter of every queue consumer.
     *
     * <p>The caller is expected to hold a throttler slot for the batch. On the normal path it is released
     * unconditionally once execution completes. If shutdown has already been initiated, i.e. the running-task
     * semaphore cannot be acquired, every task of the batch is aborted and the slot is released immediately.
     *
     * <p>Interactive batches may be retried speculatively. Each retry must pass the throttler on its own; a
     * throttler refusal surfaces as {@link RateLimitException}.
     */
    private void startNow(Task.GroupingKey groupingKey, List<Task> batch, QueueConsumer queueConsumer) {
        var throttler = queueConsumer.getThrottler();
        if (!runningTaskTerminationSemaphore.acquire()) {
            try {
                batch.forEach(this::abortTaskOnShutdown);
            } finally {
                // The normal path releases the slot once execution completes; without this, a batch refused
                // during shutdown would keep its throttler slot occupied forever.
                throttler.release();
            }
            return;
        }
        long batchNum = batchCount.incrementAndGet();
        logger.info("Will execute {} tasks in a batch #{}", batch.size(), batchNum);
        batchDistributionSummaries.get(queueConsumer.getRequestClass()).record(batch.size());
        batch.forEach(task -> {
            track(task, TaskEvent.STARTED);
            logger.debug("Task {}: handling started as '{}'", task.getId(), queueConsumer.getRequestClass());
        });

        var firstTry = new AtomicBoolean(true);
        String requestId = generateIdForRequest(batch);
        var executionResult = retryHelper.withSpeculativeRetry(
                "PartnerRequest",
                ignored -> {
                    if (firstTry.compareAndSet(true, false)) {
                        return execute(groupingKey, batch, requestId);
                    } else {
                        // speculative retries are only used for interactive requests
                        var decision = throttler.acquire(System.currentTimeMillis());
                        switch (decision) {
                            case RATE_LIMIT:
                            case CONCURRENCY_LIMIT:
                                throw new RateLimitException();
                            case PASS:
                                // NOTE(review): the slot is released as soon as execute() returns, not when the
                                // retried request finishes — confirm this is the intended concurrency accounting.
                                try {
                                    speculativeRetryCounter.increment();
                                    return execute(groupingKey, batch, requestId);
                                } finally {
                                    throttler.release();
                                }
                        }
                        throw new IllegalStateException("Unexpected throttler decision: " + decision);
                    }
                },
                null,
                SpeculativeRetryStrategy.<Void>builder()
                        .shouldRetryOnException(e -> e instanceof RateLimitException)
                        .retryDelay(config.getSpeculativeRetryTimeout())
                        .numRetries(queueConsumer.getRequestClass() == ERequestClass.RC_INTERACTIVE ?
                                config.getSpeculativeRetriesCount() + 1 : 1) // It's actually numTries, not numRetries
                        .build(),
                speculativeRetryRateLimiter);

        executionResult.whenComplete((ignored, throwable) -> {
                    throttler.release();
                    recordBatchOutcome(batch, queueConsumer.getRequestClass(), throwable);
                    completeBatch(batch, batchNum, throwable);
                })
                .exceptionally(t -> {
                    String taskIds = batch.stream().map(Task::getId).collect(Collectors.joining(", "));
                    logger.error("Unexpected exception while processing execution result for tasks {}. Exception: {}",
                            taskIds, t.getMessage(), t);
                    return null;
                });
    }

    /** Fails, drops (YT + DROPPED event) and completes a task that cannot start because shutdown was initiated. */
    private void abortTaskOnShutdown(Task task) {
        String message = String.format("Task %s: failed to start as shutdown has been initiated", task.getId());
        logger.error(message);
        task.onError(TError.newBuilder()
                .setCode(EErrorCode.EC_ABORTED)
                .setMessage(message));
        dropToYt(task);
        track(task, TaskEvent.DROPPED);
        task.onComplete();
    }

    /**
     * Records the execution outcome on every task of a batch.
     *
     * <p>If the partner call failed, each task gets an error derived from {@code throwable}. Otherwise, tasks
     * without an offer error get their cache lifetime set and their offer metrics updated.
     */
    private void recordBatchOutcome(List<Task> batch, ERequestClass requestClass, Throwable throwable) {
        if (throwable != null) {
            batch.forEach(task -> {
                logger.error("Task {}: handling completed exceptionally", task.getId(), throwable);
                task.onError(ProtoUtils.errorFromThrowable(throwable, task.isIncludeDebug()));
            });
            return;
        }
        batch.forEach(task -> {
            TError taskError = task.getOfferError();
            if (taskError != null) {
                logger.error("Task {}: handling completed exceptionally: {}", task.getId(), taskError.getMessage());
                return;
            }
            logger.debug("Task {}: handling completed successfully with {} offers", task.getId(),
                    task.getOfferCount());
            task.setCacheLifetime(coldService.getOfferLifetime(task, config.getMaxNonEmptyLifetime()));
            if (task.isEmpty()) {
                emptyOfferListsCounter.increment();
            } else {
                unknownPansionOffersCounter.increment(task.getUnknownPansionCount());
            }
            offerCounters.get(requestClass).increment(task.getOfferCount());
        });
    }

    /**
     * Stores the batch results in both key-value storages, then completes and tracks every task. Each task is then
     * exported to YT, and {@code runningTaskTerminationSemaphore} is released once all YT writes have finished.
     *
     * <p>NOTE(review): if logResults/onComplete throws for some task, the YT writes and the semaphore release below
     * are skipped. Only the {@code exceptionally} handler logs the failure.
     */
    private void completeBatch(List<Task> batch, long batchNum, Throwable executionError) {
        CompletableFuture<Void> storageFuture = putDataToKeyValueStorage(batch);
        CompletableFuture<Void> searchFlowOfferDataStorageFuture = putDataToSearchFlowOfferDataStorage(batch);
        CompletableFuture.allOf(storageFuture, searchFlowOfferDataStorageFuture)
                .whenComplete((ignored, storageError) -> {
                    // KV's throwable is ignored, because we have offers anyway (per-task storage errors were
                    // already attached to the tasks by the storage helpers).
                    batch.forEach(task -> {
                        logResults(task);
                        task.onComplete();
                        if (task.hasError()) {
                            track(task, TaskEvent.FAILED);
                        } else {
                            track(task, TaskEvent.COMPLETED);
                        }
                    });
                    List<CompletableFuture<Void>> ytWriteFutures = new ArrayList<>(batch.size());
                    batch.forEach(task -> ytWriteFutures.add(exportToYT(executionError, task, false)));
                    CompletableFuture.allOf(ytWriteFutures.toArray(new CompletableFuture[0]))
                            .whenComplete((ignore, ytError) -> {
                                runningTaskTerminationSemaphore.release();
                                if (ytError != null) {
                                    logger.warn("Could not properly complete all YT writes", ytError);
                                }
                            });
                    logger.info("Batch #{} completed", batchNum);
                })
                .exceptionally(t -> {
                    String taskIds = batch.stream().map(Task::getId).collect(Collectors.joining(", "));
                    logger.error("Unexpected exception while handling tasks {}", taskIds, t);
                    return null;
                });
    }

    /**
     * Writes the pending messages of every task in {@code batch} to the main KeyValue storage.
     *
     * <p>Each task is tracked with {@code KV_PUT_FINISHED} once its write finishes (successfully or not); on
     * failure the error is logged and recorded on the task via {@code task.onError}.
     *
     * @param batch tasks whose messages should be stored
     * @return future that completes when every per-task write has finished, exceptionally if any failed
     */
    private CompletableFuture<Void> putDataToKeyValueStorage(List<Task> batch) {
        CompletableFuture<?>[] perTaskPuts = batch.stream()
                .map(task -> putTaskMessages(task).whenComplete((ignored, failure) -> {
                    track(task, TaskEvent.KV_PUT_FINISHED);
                    if (failure == null) {
                        return;
                    }
                    logger.error("Unable to put task's messages to KeyValue storage", failure);
                    task.onError(ProtoUtils.errorFromThrowable(failure, task.isIncludeDebug()));
                }))
                .toArray(CompletableFuture[]::new);
        return CompletableFuture.allOf(perTaskPuts);
    }

    /**
     * Writes the per-offer data of every task in {@code batch} to the Search flow offer-data storage.
     *
     * <p>Each task is tracked with {@code SEARCH_FLOW_KV_PUT_COMPLETED} or {@code SEARCH_FLOW_KV_PUT_FAILED};
     * on failure the error is also logged and recorded on the task via {@code task.onError}.
     *
     * @param batch tasks whose offers should be stored
     * @return future that completes when every per-task write has finished, exceptionally if any failed
     */
    private CompletableFuture<Void> putDataToSearchFlowOfferDataStorage(List<Task> batch) {
        return CompletableFuture.allOf(batch.stream()
                .map(task -> putTaskDataToSearchFlowOfferDataStorage(task)
                        .whenComplete((res, t) -> {
                            if (t != null) {
                                track(task, TaskEvent.SEARCH_FLOW_KV_PUT_FAILED);
                                // Distinct message with the task id, so failures of this storage are not
                                // mistaken for failures of the main KeyValue storage.
                                logger.error("Task {}: unable to put offer data to Search flow KV storage",
                                        task.getId(), t);
                                task.onError(ProtoUtils.errorFromThrowable(t, task.isIncludeDebug()));
                            } else {
                                track(task, TaskEvent.SEARCH_FLOW_KV_PUT_COMPLETED);
                            }
                        }))
                .toArray(CompletableFuture[]::new));
    }

    /**
     * Stores one {@code TOfferDataMessage} (offer id and landing info) per offer of {@code task} in the Search
     * flow offer-data storage, keyed by offer id, with the configured search-flow lifetime.
     *
     * <p>Tracks {@code SEARCH_FLOW_KV_PUT_STARTED} before issuing the writes; every failed write is logged
     * together with its offer id.
     *
     * @param task task whose offers are stored
     * @return future that completes when all writes have finished, exceptionally if any failed
     */
    private CompletableFuture<Void> putTaskDataToSearchFlowOfferDataStorage(Task task) {
        track(task, TaskEvent.SEARCH_FLOW_KV_PUT_STARTED);
        List<CompletableFuture<?>> offerWrites = new ArrayList<>();
        for (var offer : task.getOfferList()) {
            CompletableFuture<?> write = searchFlowOfferDataStorageSupplier.get()
                    .thenCompose(storage -> storage.put(
                            offer.getId(),
                            TOfferDataMessage.newBuilder()
                                    .setOfferId(offer.getId())
                                    .setLandingInfo(offer.getLandingInfo())
                                    .build(),
                            config.getSearchFlowOfferDataLifetime()))
                    .whenComplete((ignored, failure) -> {
                        if (failure != null) {
                            logger.error("Task {}: error while putting data of Offer {} to Search flow KV",
                                    task.getId(), offer.getId(), failure);
                        }
                    });
            offerWrites.add(write);
        }
        return CompletableFuture.allOf(offerWrites.toArray(new CompletableFuture[0]));
    }

    /**
     * Puts every message from {@code task.getMessagesToPutToStorage()} into the main KeyValue storage.
     *
     * <p>Tracks {@code KV_PUT_STARTED} before issuing the writes. Each failed write is logged with its key.
     *
     * @param task task whose messages are stored
     * @return future that completes when all writes have finished (immediately when there is nothing to
     *         store), exceptionally if any write failed
     */
    private CompletableFuture<Void> putTaskMessages(Task task) {
        track(task, TaskEvent.KV_PUT_STARTED);
        Map<String, Message> messagesToPut = task.getMessagesToPutToStorage();
        if (messagesToPut.isEmpty()) {
            return CompletableFuture.completedFuture(null);
        }
        logger.debug("Task {}: putting {} items to KeyValue storage", task.getId(), messagesToPut.size());
        // Wildcard element type rather than the raw CompletableFuture: no raw-type/unchecked warnings.
        List<CompletableFuture<?>> futureList = new ArrayList<>(messagesToPut.size());
        for (Map.Entry<String, Message> item : messagesToPut.entrySet()) {
            futureList.add(
                    storageSupplier.get()
                            .thenCompose(storage -> storage.put(item.getKey(), item.getValue(), null))
                            .whenComplete((r, t) -> {
                                if (t != null) {
                                    logger.error("Task {}: error while putting OfferData {} to KV",
                                            task.getId(), item.getKey(), t);
                                }
                            }));
        }
        return CompletableFuture.allOf(futureList.toArray(new CompletableFuture[0]));
    }

    /**
     * Partner-specific execution of a batch of tasks; presumably all tasks share {@code groupingKey} — confirm
     * against the caller.
     *
     * @param groupingKey grouping key of the batch
     * @param tasks       tasks to execute
     * @param requestId   request id for the batch (presumably produced by {@code generateIdForRequest})
     * @return future that completes when execution has finished
     */
    abstract CompletableFuture<Void> execute(Task.GroupingKey groupingKey, List<Task> tasks, String requestId);

    /**
     * Creates the base HTTP request towards the partner: configured base URL, read and request timeouts from
     * config (in milliseconds), and the request-id header.
     *
     * @param tasks     tasks served by the request (unused by the base implementation)
     * @param requestId value for the request-id header
     * @return a builder that subclasses can extend further
     * @throws ArithmeticException if a configured timeout does not fit into an {@code int} of milliseconds
     */
    RequestBuilder buildHttpRequest(List<Task> tasks, String requestId) {
        int readTimeoutMillis = Math.toIntExact(config.getHttpReadTimeout().toMillis());
        int requestTimeoutMillis = Math.toIntExact(config.getHttpRequestTimeout().toMillis());
        return new RequestBuilder()
                .setUrl(config.getBaseUrl())
                .setReadTimeout(readTimeoutMillis)
                .setRequestTimeout(requestTimeoutMillis)
                .setHeader(HTTP_REQUEST_ID_HEADER, requestId);
    }

    /**
     * Generates a fresh request id and stores it on every task as its HTTP request id.
     * An empty task list denotes an auxiliary request, which is only logged.
     *
     * @param tasks tasks the request will serve; may be empty
     * @return the generated request id
     */
    String generateIdForRequest(List<Task> tasks) {
        String requestId = ProtoUtils.randomId();
        if (tasks.isEmpty()) {
            logger.debug("Preparing auxiliary request {}", requestId);
            return requestId;
        }
        if (logger.isDebugEnabled()) {
            logger.debug("Request {} will handle tasks {}", requestId,
                    tasks.stream().map(Task::getId).collect(Collectors.joining(", ")));
        }
        for (Task task : tasks) {
            task.setHttpRequestId(requestId);
        }
        return requestId;
    }

    /**
     * Sends {@code request} through the AHC client and instruments the exchange.
     *
     * <p>When the request completes:
     * <ul>
     *     <li>the elapsed time is recorded in the per-purpose and aggregated timers;</li>
     *     <li>the exchange is passed to {@code httpLogger}, unless the API is mocked;</li>
     *     <li>byte and status counters are updated.</li>
     * </ul>
     * Unknown purposes are logged and left out of the metrics.
     *
     * @param request      request to send; its {@code HTTP_REQUEST_ID_HEADER} value is used in log messages
     * @param only2xx      if {@code true}, a response whose status is outside 200-299 makes the returned future
     *                     complete exceptionally with a {@link RuntimeException}
     * @param purpose      metric/log label of the call
     * @param requestClass request class selecting the metric series
     * @return future of the HTTP response
     */
    CompletableFuture<Response> runHttpRequest(Request request, boolean only2xx, String purpose,
                                               ERequestClass requestClass) {
        Timer.Sample started = Timer.start(Metrics.globalRegistry);
        String httpRequestId = request.getHeaders().get(HTTP_REQUEST_ID_HEADER);
        logger.debug("{} HTTP Request {}: Sending HTTP request: {}", purpose, httpRequestId, request.getUrl());
        if (isMocked()) {
            logger.warn("HTTP Request {} will be done to mocked API", httpRequestId);
        }
        CompletableFuture<Response> future = client.executeRequest(request).toCompletableFuture();
        return future.whenComplete((response, throwable) -> {
            try {
                started.stop(getHttpTimer(purpose, requestClass));
                started.stop(getHttpTimer(ALL_PURPOSES, requestClass));
            } catch (IllegalArgumentException e) {
                logger.warn("Unable to measure HTTP time due to unknown purpose");
            }
            if (throwable != null) {
                // Parameterized logging; the trailing throwable is logged with its stack trace.
                logger.warn("{} HTTP request {}: Request failed", purpose, httpRequestId, throwable);
                if (!isMocked()) {
                    httpLogger.logRequestResponse(request, null, purpose, throwable);
                }
                trackHttpRequest(purpose, statusForFailure(throwable), requestClass);
                return;
            }
            int statusCode = response.getStatusCode();
            logger.debug("{} HTTP Request {}: Request completed; Code {}", purpose, httpRequestId, statusCode);
            if (!isMocked()) {
                httpLogger.logRequestResponse(request, response, purpose, null);
            }
            trackHttpByte(purpose, response.getResponseBodyAsBytes().length, requestClass);
            HttpStatus status = integerToStatus.get(statusCode / 100);
            if (status != null) {
                trackHttpRequest(purpose, status, requestClass);
            }
            if (only2xx && (statusCode < 200 || statusCode >= 300)) {
                // Log the body with context instead of using it as a bare log message.
                logger.error("{} HTTP Request {}: Bad HTTP status code {}; body: {}", purpose, httpRequestId,
                        statusCode, response.getResponseBody());
                throw new RuntimeException("Bad HTTP status code: " + statusCode);
            }
        });
    }

    /**
     * Classifies a transport failure for HTTP metrics.
     *
     * <p>A {@link TimeoutException} maps to {@code TIMEOUT}, whether it is the failure itself or the immediate
     * cause of a wrapper such as {@code CompletionException}. Anything else maps to {@code ERROR}.
     */
    private static HttpStatus statusForFailure(Throwable throwable) {
        if (throwable instanceof TimeoutException || throwable.getCause() instanceof TimeoutException) {
            return HttpStatus.TIMEOUT;
        }
        return HttpStatus.ERROR;
    }

    /////////////////
    /**
     * Publishes the outcome of {@code task} to YT via {@code messageBus}.
     *
     * <p>Tasks that carry a test context are skipped. If {@code error} is non-null it is first recorded on the
     * task. The message contains the request, with its attribution reduced to the offer-cache client id, plus
     * one of the following:
     * <ul>
     *     <li>the task's error (cache lifetime set to the configured error lifetime);</li>
     *     <li>a placeholder (cache lifetime set to the configured placeholder lifetime);</li>
     *     <li>the task's result, with landing info, availability group key, external id and wifi flag
     *     stripped from every offer.</li>
     * </ul>
     *
     * @param error       failure to record on the task before exporting, or {@code null}
     * @param task        task to export
     * @param placeholder whether to export a placeholder instead of the result; ignored if the task has an error
     * @return future of the message-bus send; already completed when the export is skipped
     * @throws IllegalStateException if the task has no cache lifetime at send time
     */
    private CompletableFuture<Void> exportToYT(Throwable error, Task task, boolean placeholder) {
        if (task.getCallContext().getTestContext() != null) {
            logger.debug("Task has a test context, so export to YT is skipped");
            return CompletableFuture.completedFuture(null);
        }
        String exportKind = placeholder ? "placeholder" : "result";
        logger.debug("Task {}: exporting {} to YT", task.getId(), exportKind);
        TSearcherMessage.Builder builder = TSearcherMessage.newBuilder();
        // Only the offer-cache client id of the original attribution is exported.
        TRequestAttribution attribution = TRequestAttribution.newBuilder()
                .setOfferCacheClientId(task.getRequest().getAttribution().getOfferCacheClientId())
                .build();
        builder.setRequest(task.getRequest().toBuilder().setAttribution(attribution).build());
        if (error != null) {
            task.onError(ProtoUtils.errorFromThrowable(error, task.isIncludeDebug()));
        }
        if (task.hasError()) {
            task.setCacheLifetime(config.getErrorLifetime());
            task.dumpResultTo(builder.getResponseBuilder());
        } else if (placeholder) {
            task.setCacheLifetime(config.getPlaceholderLifetime());
            builder.getResponseBuilder().setPlaceholder(TPlaceholder.newBuilder());
        } else {
            task.dumpResultTo(builder.getResponseBuilder());
            // NOTE(review): if getCacheTimeSec() is a generated protobuf getter it never returns null, so this
            // check cannot fail; hasCacheTimeSec() may have been intended. Confirm before tightening it.
            Preconditions.checkState(builder.getResponseBuilder().getOffersOrBuilder().getCacheTimeSec() != null);
        }

        if (builder.getResponseBuilder().hasOffers()) {
            // Strip per-offer fields that are not exported to YT.
            builder.getResponseBuilder().getOffersBuilder().getOfferBuilderList()
                    .forEach(offerBuilder -> {
                        offerBuilder.clearLandingInfo();
                        offerBuilder.clearAvailabilityGroupKey();
                        offerBuilder.clearExternalId();
                        offerBuilder.clearWifiIncluded();
                    });
        }

        Preconditions.checkState(task.getCacheLifetime() != null);
        CompletableFuture<Void> result = messageBus.send(builder.build(), task.getCacheLifetime());
        // Report failures instead of claiming completion for every outcome.
        result.whenComplete((r, t) -> {
            if (t != null) {
                logger.warn("Task {}: failed exporting {} to YT", task.getId(), exportKind, t);
            } else {
                logger.debug("Task {}: completed exporting {} to YT", task.getId(), exportKind);
            }
        });
        return result;
    }

    /**
     * Exports a task that will not be executed to YT, guarded by {@code droppedTaskTerminationSemaphore}.
     *
     * <p>The export is skipped when {@code acquire()} returns {@code false}. The acquired permit is released
     * when the export future completes. It is also released when {@code exportToYT} throws synchronously
     * (e.g. a failed precondition), so the permit is not leaked; presumably {@code awaitTermination()} of the
     * semaphore waits for all permits to be released.
     *
     * @param task dropped task to export
     */
    private void dropToYt(Task task) {
        if (!droppedTaskTerminationSemaphore.acquire()) {
            return;
        }
        CompletableFuture<Void> export;
        try {
            export = exportToYT(null, task, false);
        } catch (RuntimeException e) {
            // Give the permit back before propagating; otherwise nothing would ever release it.
            droppedTaskTerminationSemaphore.release();
            throw e;
        }
        export.whenComplete((ignored, failure) -> droppedTaskTerminationSemaphore.release());
    }

    /**
     * Builds the common part of a result-log record: current timestamp (epoch millis), environment name (first
     * active Spring profile, or {@code "dev"} when none is active) and the mocked-API flag.
     *
     * @param task task being logged (not read by this method)
     * @return a new JSON object with the common properties
     */
    private JsonObject getJsonLogMessage(Task task) {
        JsonObject message = new JsonObject();
        message.addProperty("Timestamp", System.currentTimeMillis());
        String[] activeProfiles = env.getActiveProfiles();
        String environment = activeProfiles.length == 0 ? "dev" : activeProfiles[0];
        message.addProperty("Environment", environment);
        message.addProperty("IsMocked", isMocked());
        return message;
    }

    /**
     * Writes one JSON record about the task to {@code resultLogger}. The record holds the common header, the
     * request and the dumped result as JSON, the HTTP request id, and whether the task succeeded.
     * A protobuf-to-JSON printing failure is logged and the record is not written.
     *
     * @param task finished task to log
     */
    private void logResults(Task task) {
        JsonObject message = getJsonLogMessage(task);
        JsonParser parser = new JsonParser();
        try {
            String requestJson = jsonPrinter.print(task.getRequest());
            message.add("Request", parser.parse(requestJson));
            String resultJson = jsonPrinter.print(task.dumpResult());
            message.add("Result", parser.parse(resultJson));
            message.addProperty("HttpRequestId", task.getHttpRequestId());
            boolean successful = !task.hasError();
            message.addProperty("IsSuccessful", successful);
            resultLogger.info(message.toString());
        } catch (InvalidProtocolBufferException printFailure) {
            logger.error("Unable to log results to logbroker", printFailure);
        }
    }

    /**
     * Groups tasks by the original hotel id of their request.
     *
     * @param tasks tasks to group
     * @return map from original hotel id to the tasks carrying it, in encounter order
     */
    protected Map<String, List<Task>> mapTasksByOriginalId(List<Task> tasks) {
        return tasks.stream().collect(Collectors.groupingBy(this::originalHotelIdOf));
    }

    /** Returns the original hotel id of the task's request, used as the grouping key. */
    private String originalHotelIdOf(Task task) {
        return task.getRequest().getHotelId().getOriginalId();
    }

    /**
     * Tells whether the configured base URL points at the local mock API on {@code http://localhost:4242}.
     */
    private boolean isMocked() {
        String baseUrl = config.getBaseUrl();
        return baseUrl.startsWith("http://localhost:4242");
    }

    /**
     * Lists the HTTP call purposes of this partner; the base implementation knows only {@code "main"}.
     * Presumably used to register the per-purpose metrics; subclasses with other call kinds override it.
     *
     * @return immutable single-element list
     */
    protected List<String> getHttpCallPurposes() {
        List<String> purposes = Collections.singletonList("main");
        return purposes;
    }

    /**
     * Increments the counter for {@code event} and the task's request class, and records timing for lifecycle
     * milestones.
     *
     * <p>{@code STARTED} stamps the start time and records the time since creation in the stall timer.
     * {@code COMPLETED} and {@code FAILED} stamp the completion time and record the time since creation in the
     * total timer.
     *
     * @param task  task the event belongs to
     * @param event event to record
     */
    private void track(Task task, TaskEvent event) {
        ERequestClass requestClass = task.getRequest().getRequestClass();
        taskCounters.get(event).get(requestClass).increment();
        if (event == TaskEvent.STARTED) {
            task.setStartedAtNanos(System.nanoTime());
            long sinceCreation = task.getStartedAtNanos() - task.getCreatedAtNanos();
            taskStallTimers.get(requestClass).record(sinceCreation, TimeUnit.NANOSECONDS);
        } else if (event == TaskEvent.COMPLETED || event == TaskEvent.FAILED) {
            task.setCompletedAtNanos(System.nanoTime());
            long sinceCreation = task.getCompletedAtNanos() - task.getCreatedAtNanos();
            taskTotalTimers.get(requestClass).record(sinceCreation, TimeUnit.NANOSECONDS);
        }
    }

    /////////////////

    /**
     * Counts one HTTP call with the given status, both for {@code purpose} and for the aggregated
     * {@code ALL_PURPOSES} series, so the aggregate does not have to be computed at query time.
     * Unknown purposes are logged and not counted.
     */
    private void trackHttpRequest(String purpose, HttpStatus status, ERequestClass clazz) {
        var countersByPurpose = httpCounters.get(clazz);
        if (countersByPurpose.containsKey(purpose)) {
            countersByPurpose.get(purpose).get(status).increment();
            countersByPurpose.get(ALL_PURPOSES).get(status).increment();
        } else {
            logger.warn("Purpose '{}' is unknown for metrics", purpose);
        }
    }

    /**
     * Adds {@code delta} response bytes to the summary for {@code purpose} and to the aggregated
     * {@code ALL_PURPOSES} summary. Unknown purposes are logged and not counted.
     */
    private void trackHttpByte(String purpose, int delta, ERequestClass clazz) {
        var summariesByPurpose = httpByteCounters.get(clazz);
        if (summariesByPurpose.containsKey(purpose)) {
            summariesByPurpose.get(purpose).increment(delta);
            summariesByPurpose.get(ALL_PURPOSES).increment(delta);
        } else {
            logger.warn("Purpose '{}' is unknown for metrics", purpose);
        }
    }

    /**
     * Looks up the HTTP timer for {@code purpose} and request class {@code clazz}.
     *
     * @throws IllegalArgumentException if no timer is registered for {@code purpose}; a warning is logged first
     */
    private Timer getHttpTimer(String purpose, ERequestClass clazz) {
        var timersByPurpose = httpTimers.get(clazz);
        if (timersByPurpose.containsKey(purpose)) {
            return timersByPurpose.get(purpose);
        }
        logger.warn("Purpose '{}' is unknown for metrics", purpose);
        throw new IllegalArgumentException("Unknown Purpose");
    }

    /**
     * Hands an offer to the task unless its price fails {@code validatePrice}.
     *
     * @param task         task receiving the offer
     * @param offerBuilder offer produced by the partner response
     */
    protected void onOffer(Task task, TOffer.Builder offerBuilder) {
        boolean acceptable = validatePrice(task, offerBuilder.getPrice().getAmount());
        if (!acceptable) {
            return;
        }
        task.onOfferImpl(offerBuilder);
    }

    /**
     * Checks an offer price.
     *
     * <p>An invalid price is logged together with the task id and the original hotel id, and the
     * invalid-price counter is incremented.
     *
     * @param task  task the offer belongs to
     * @param price offer price
     * @return {@code true} if the price is valid
     */
    protected boolean validatePrice(Task task, double price) {
        if (isValidPrice(price)) {
            return true;
        }
        String originalHotelId = task.getRequest().getHotelId().getOriginalId();
        logger.warn("Task {}: Invalid price {} for hotel_id '{}'", task.getId(), price, originalHotelId);
        invalidOfferPriceCounter.increment();
        return false;
    }

    /**
     * Returns {@code true} for prices strictly between 0 and {@code MAX_VALID_PRICE}; NaN is rejected because
     * both comparisons are false for it.
     */
    private boolean isValidPrice(double price) {
        boolean positive = 0 < price;
        boolean belowMaximum = price < MAX_VALID_PRICE;
        return positive && belowMaximum;
    }

    /**
     * Converts the actualized refund rules into protobuf form.
     *
     * <p>Penalty, start and end are copied only when present; instants are converted through epoch
     * milliseconds.
     *
     * @param refundRules rules to convert; {@code actualize()} is applied first
     * @return protobuf rules in the original order
     */
    protected List<TRefundRule> mapToProtoRefundRules(RefundRules refundRules) {
        List<TRefundRule> protoRules = new ArrayList<>();
        for (var rule : refundRules.actualize().getRules()) {
            TRefundRule.Builder protoRule = TRefundRule.newBuilder();
            protoRule.setType(toProtoRefundType(rule.getType()));
            var penalty = rule.getPenalty();
            if (penalty != null) {
                protoRule.setPenalty(ProtoUtils.toTPrice(penalty));
            }
            var startsAt = rule.getStartsAt();
            if (startsAt != null) {
                protoRule.setStartsAt(ProtoUtils.timestamp(startsAt.toEpochMilli()));
            }
            var endsAt = rule.getEndsAt();
            if (endsAt != null) {
                protoRule.setEndsAt(ProtoUtils.timestamp(endsAt.toEpochMilli()));
            }
            protoRules.add(protoRule.build());
        }
        return protoRules;
    }

    /**
     * Maps a domain refund type onto its protobuf counterpart.
     *
     * @throws RuntimeException for a value without a protobuf mapping
     */
    private ERefundType toProtoRefundType(RefundType refundType) {
        ERefundType protoType;
        switch (refundType) {
            case NON_REFUNDABLE:
                protoType = ERefundType.RT_NON_REFUNDABLE;
                break;
            case REFUNDABLE_WITH_PENALTY:
                protoType = ERefundType.RT_REFUNDABLE_WITH_PENALTY;
                break;
            case FULLY_REFUNDABLE:
                protoType = ERefundType.RT_FULLY_REFUNDABLE;
                break;
            default:
                throw new RuntimeException("Unknown refundType: " + refundType);
        }
        return protoType;
    }

    /**
     * Starts shutdown.
     *
     * <p>Steps, in order:
     * <ol>
     *     <li>shuts down {@code runningTaskTerminationSemaphore};</li>
     *     <li>drains the background queue;</li>
     *     <li>for each drained task: records an {@code EC_ABORTED} error, logs it, tracks it as
     *     {@code DROPPED}, completes it and exports it to YT;</li>
     *     <li>shuts down {@code droppedTaskTerminationSemaphore}.</li>
     * </ol>
     */
    @Override
    public void shutdown() {
        runningTaskTerminationSemaphore.shutdown();
        List<Task> queuedBackgroundTasks = queueConsumers.get(ERequestClass.RC_BACKGROUND).getQueue().drain();
        for (Task queued : queuedBackgroundTasks) {
            var abortError = TError.newBuilder()
                    .setCode(EErrorCode.EC_ABORTED)
                    .setMessage("Shutdown has been initiated");
            queued.onError(abortError);
            logger.info("Task {}: dropped due to shutdown", queued.getId());
            track(queued, TaskEvent.DROPPED);
            queued.onComplete();
            dropToYt(queued);
        }
        droppedTaskTerminationSemaphore.shutdown();
    }

    /**
     * Waits for both termination semaphores, then releases resources.
     *
     * <p>After the waits it closes the AHC client (a close failure is logged) and calls {@code shutdownNow()}
     * on every queue consumer and on the default executor. It logs how many runnables each
     * {@code shutdownNow()} returned.
     *
     * @throws InterruptedException if interrupted while waiting for the semaphores
     */
    @Override
    public void awaitTermination() throws InterruptedException {
        runningTaskTerminationSemaphore.awaitTermination();
        droppedTaskTerminationSemaphore.awaitTermination();
        logger.info("Handling of all the tasks has been completed, closing AHC");
        try {
            client.close();
        } catch (IOException closeFailure) {
            logger.error("Unable to close AHC", closeFailure);
        }
        queueConsumers.forEach((requestClass, consumer) -> {
            List<Runnable> killed = consumer.shutdownNow();
            if (!killed.isEmpty()) {
                logger.warn("Shutdown killed {} running tasks for {}", killed.size(), requestClass);
            }
        });
        List<Runnable> killedInDefaultExecutor = executor.shutdownNow();
        if (!killedInDefaultExecutor.isEmpty()) {
            logger.warn("Shutdown killed {} running tasks in default executor", killedInDefaultExecutor.size());
        }
        logger.info("Partner terminated");
    }
    /////////////////

    /**
     * Spring {@code DisposableBean} hook: starts shutdown and then blocks until it has finished.
     *
     * @throws Exception if waiting for termination fails or is interrupted
     */
    @Override
    public void destroy() throws Exception {
        this.shutdown();
        this.awaitTermination();
    }

    /**
     * Task lifecycle events, counted per request class by {@code track(Task, TaskEvent)}.
     * {@code STARTED}, {@code COMPLETED} and {@code FAILED} additionally record timings.
     * NOTE(review): declaration order defines the ordinals; keep it stable in case metric maps depend on it.
     */
    enum TaskEvent {
        ACCEPTED,
        REJECTED,
        // Recorded for background tasks drained from the queue on shutdown.
        DROPPED,
        // Stamps the start time and records the time since creation.
        STARTED,
        // Writes to the main KeyValue storage.
        KV_PUT_STARTED,
        KV_PUT_FINISHED,
        // Writes to the Search flow offer-data storage.
        SEARCH_FLOW_KV_PUT_STARTED,
        SEARCH_FLOW_KV_PUT_FAILED,
        SEARCH_FLOW_KV_PUT_COMPLETED,
        // Terminal events; both stamp the completion time and record the total time since creation.
        COMPLETED,
        FAILED
    }

    /**
     * Outcome buckets for the HTTP request counters.
     */
    enum HttpStatus {
        // Response status classes, looked up via integerToStatus with statusCode / 100
        // (mapping defined elsewhere; presumably 2 -> HTTP_2XX, ... — confirm).
        HTTP_2XX,
        HTTP_3XX,
        HTTP_4XX,
        HTTP_5XX,
        // Request failed with a TimeoutException.
        TIMEOUT,
        // Request failed with any other throwable.
        ERROR,
    }
}
