package ru.yandex.bannerstorage.messaging.services;

import java.sql.Timestamp;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.bannerstorage.messaging.services.exceptions.AbortMessageProcessingException;

/**
 * Base {@link QueueObserver} that receives messages from a queue in batches and hands each
 * non-end-dialog message to {@link #doProcessMessage(QueueOperations, QueueMessage)}.
 *
 * <p>Messages of the Service Broker {@code EndDialog} type are not passed to subclasses; their
 * sessions are closed through {@link QueueOperations#endSession}. Any failure while processing a
 * message is logged and delegated to the configured {@link QueueMessageOnErrorStrategy}.
 *
 * <p>With a single thread, messages are processed inline on the caller's thread in the order they
 * were received. With more than one thread, the working messages of a batch are processed in
 * parallel on a private fixed-size pool. End-dialog sessions are closed only after the whole batch
 * has been handled. The pool is shut down by {@link #close()}.
 *
 * @author egorovmv
 */
public abstract class AbstractQueueObserver implements QueueObserver {
    /** Default poll interval offered to subclasses: 30 seconds, in milliseconds. */
    protected static final int DEFAULT_POLL_INTERVAL_IN_MS = 30 * 1000;
    /** Timeout passed to {@link QueueOperations#receiveMessages} (presumably in milliseconds). */
    private static final int DEFAULT_RECEIVE_MESSAGES_TIMEOUT_IN_MS = 3000;

    /** Message type that signals the end of a Service Broker conversation. */
    private static final String END_DIALOG_MESSAGE_TYPE = "http://schemas.microsoft.com/SQL/ServiceBroker/EndDialog";

    private final Logger logger;
    private final String queueId;
    private final int pollIntervalInMS;
    private final int batchSize;
    private final int countOfThreads;
    private final QueueMessageOnErrorStrategy errorStrategy;
    private final ReceiveMessageStrategy receiveStrategy;
    /** Worker pool for parallel processing; {@code null} when {@code countOfThreads == 1}. */
    private final ExecutorService executorService;
    private volatile boolean isRunning;

    /**
     * Creates an observer using {@link ReceiveMessageStrategy#RECEIVE_UNTIL_EMPTY} and
     * {@code countOfThreads} worker threads.
     *
     * @param queueId          identifier of the observed queue
     * @param pollIntervalInMS poll interval in milliseconds; must be positive
     * @param batchSize        batch size passed to {@link QueueOperations#receiveMessages}; must be positive
     * @param countOfThreads   number of worker threads; {@code 1} processes messages inline
     * @param errorStrategy    strategy invoked for messages whose processing fails
     * @throws IllegalArgumentException if a numeric argument is not positive
     * @throws NullPointerException     if {@code queueId} or {@code errorStrategy} is {@code null}
     */
    AbstractQueueObserver(
            @NotNull String queueId,
            int pollIntervalInMS,
            int batchSize,
            int countOfThreads,
            @NotNull QueueMessageOnErrorStrategy errorStrategy) {
        this(queueId, pollIntervalInMS, batchSize, countOfThreads, errorStrategy, ReceiveMessageStrategy.RECEIVE_UNTIL_EMPTY);
    }

    /**
     * Creates a single-threaded observer using {@link ReceiveMessageStrategy#RECEIVE_UNTIL_EMPTY}.
     *
     * @param queueId          identifier of the observed queue
     * @param pollIntervalInMS poll interval in milliseconds; must be positive
     * @param batchSize        batch size passed to {@link QueueOperations#receiveMessages}; must be positive
     * @param errorStrategy    strategy invoked for messages whose processing fails
     * @throws IllegalArgumentException if a numeric argument is not positive
     * @throws NullPointerException     if {@code queueId} or {@code errorStrategy} is {@code null}
     */
    AbstractQueueObserver(
            @NotNull String queueId,
            int pollIntervalInMS,
            int batchSize,
            @NotNull QueueMessageOnErrorStrategy errorStrategy) {
        this(queueId, pollIntervalInMS, batchSize, 1, errorStrategy, ReceiveMessageStrategy.RECEIVE_UNTIL_EMPTY);
    }

    /**
     * Creates a single-threaded observer with an explicit receive strategy.
     *
     * @param queueId          identifier of the observed queue
     * @param pollIntervalInMS poll interval in milliseconds; must be positive. It also defines the
     *                         snapshot cut-off under {@link ReceiveMessageStrategy#SNAPSHOT}.
     * @param batchSize        batch size passed to {@link QueueOperations#receiveMessages}; must be positive
     * @param errorStrategy    strategy invoked for messages whose processing fails
     * @param receiveStrategy  how received messages are filtered (see {@link #processMessages})
     * @throws IllegalArgumentException if a numeric argument is not positive
     * @throws NullPointerException     if any reference argument is {@code null}
     */
    protected AbstractQueueObserver(
            @NotNull String queueId,
            int pollIntervalInMS,
            int batchSize,
            @NotNull QueueMessageOnErrorStrategy errorStrategy,
            @NotNull ReceiveMessageStrategy receiveStrategy) {
        this(queueId, pollIntervalInMS, batchSize, 1, errorStrategy, receiveStrategy);
    }

    private AbstractQueueObserver(
            @NotNull String queueId,
            int pollIntervalInMS,
            int batchSize,
            int countOfThreads,
            @NotNull QueueMessageOnErrorStrategy errorStrategy,
            @NotNull ReceiveMessageStrategy receiveStrategy) {
        if (pollIntervalInMS <= 0) {
            throw new IllegalArgumentException("pollIntervalInMS");
        }
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize");
        }
        // Checked explicitly so the failure names the argument instead of surfacing as an
        // anonymous IllegalArgumentException from Executors.newFixedThreadPool.
        if (countOfThreads <= 0) {
            throw new IllegalArgumentException("countOfThreads");
        }
        this.logger = LoggerFactory.getLogger(getClass());
        this.queueId = Objects.requireNonNull(queueId, "queueId");
        this.pollIntervalInMS = pollIntervalInMS;
        this.batchSize = batchSize;
        this.countOfThreads = countOfThreads;
        this.errorStrategy = Objects.requireNonNull(errorStrategy, "errorStrategy");
        this.receiveStrategy = Objects.requireNonNull(receiveStrategy, "receiveStrategy");

        if (countOfThreads == 1) {
            // Single-threaded observers process messages inline; no pool is needed.
            this.executorService = null;
        } else {
            this.executorService = Executors.newFixedThreadPool(
                    countOfThreads,
                    new ThreadFactoryBuilder()
                            .setNameFormat(getClass().getSimpleName() + "-observer-%d")
                            .build()
            );
        }
    }

    /** Returns the logger bound to the concrete subclass. */
    protected final Logger getLogger() {
        return logger;
    }

    @NotNull
    @Override
    public final String getQueueId() {
        return queueId;
    }

    @Override
    public final int getPollIntervalInMS() {
        return pollIntervalInMS;
    }

    /**
     * Hook invoked by {@link #start()} before the observer is marked as running. It does nothing
     * by default.
     */
    public void doStart() {
    }

    /** Runs {@link #doStart()} and then marks the observer as running. */
    @Override
    public final void start() {
        doStart();
        isRunning = true;
    }

    /**
     * Processes one working (non-end-dialog) message. Any {@link Throwable} thrown here is logged
     * and passed to the configured {@link QueueMessageOnErrorStrategy}.
     *
     * @param queueOperations queue operations available to the handler
     * @param message         the message to process
     */
    protected abstract void doProcessMessage(
            @NotNull QueueOperations queueOperations, @NotNull QueueMessage message);

    /**
     * Receives one batch of messages and processes it.
     *
     * <p>End-dialog messages close their session. Every other message goes to
     * {@link #doProcessMessage}. Under {@link ReceiveMessageStrategy#SNAPSHOT}, encountering a
     * working message enqueued later than {@code now - pollIntervalInMS} throws
     * {@link AbortMessageProcessingException}. In single-threaded mode, messages preceding it in
     * the batch have already been processed at that point. In multi-threaded mode, none of the
     * batch has been processed yet.
     *
     * @param queueOperations operations used to receive, process and end sessions
     * @param localState      not used by this implementation
     * @return {@code false} if no messages were received, {@code true} otherwise
     * @throws AbortMessageProcessingException under {@code SNAPSHOT}, on a too-recent message
     */
    @Override
    public final boolean processMessages(
            @NotNull QueueOperations queueOperations, @NotNull Map<String, Object> localState) {
        // Cut-off for the SNAPSHOT strategy, captured before receiving.
        Timestamp startSnapshotTime = Timestamp.from(Instant.now().minusMillis(pollIntervalInMS));

        List<QueueMessage> messages = queueOperations.receiveMessages(
                getQueueId(), batchSize, DEFAULT_RECEIVE_MESSAGES_TIMEOUT_IN_MS);
        if (messages.isEmpty()) {
            return false;
        }

        List<QueueMessage> workingMessages = new ArrayList<>();
        List<QueueMessage> closingMessages = new ArrayList<>();

        for (QueueMessage message : messages) {
            if (message.getMessageType().equalsIgnoreCase(END_DIALOG_MESSAGE_TYPE)) {
                if (countOfThreads == 1) {
                    queueOperations.endSession(message);
                } else {
                    closingMessages.add(message);
                }
            } else {
                if (receiveStrategy == ReceiveMessageStrategy.SNAPSHOT
                        && message.getEnqueuedTime().after(startSnapshotTime)) {
                    throw new AbortMessageProcessingException();
                }

                if (countOfThreads == 1) {
                    doProcessWorkingMessage(queueOperations, message);
                } else {
                    workingMessages.add(message);
                }
            }
        }

        if (countOfThreads > 1) {
            processInParallel(queueOperations, workingMessages);
            // Sessions are ended only after every working message of the batch has been handled.
            closingMessages.forEach(queueOperations::endSession);
        }

        return true;
    }

    /**
     * Processes {@code workingMessages} concurrently on the worker pool and blocks until all of
     * them have completed.
     *
     * <p>{@link CompletableFuture#join()} wraps failures in {@link CompletionException}. The
     * original exception is rethrown instead, so callers observe the same exception type as in
     * single-threaded mode, where exceptions escaping {@link #doProcessWorkingMessage} propagate
     * unwrapped.
     */
    private void processInParallel(
            @NotNull QueueOperations queueOperations, @NotNull List<QueueMessage> workingMessages) {
        try {
            CompletableFuture.allOf(
                    workingMessages.stream()
                            .map(message -> CompletableFuture.runAsync(() -> doProcessWorkingMessage(queueOperations, message), executorService))
                            .toArray(CompletableFuture[]::new)).join(); // IGNORE-BAD-JOIN DIRECT-149116
        } catch (CompletionException e) {
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw e;
        }
    }

    /**
     * Processes one working message and logs its progress. Any {@link Throwable} from
     * {@link #doProcessMessage} is logged and delegated to the error strategy. Exceptions thrown
     * by the error strategy itself propagate to the caller.
     */
    private void doProcessWorkingMessage(@NotNull QueueOperations queueOperations, @NotNull QueueMessage message) {
        try {
            logger.info("Processing message (MessageId: \"{}\")...", message.getMessageId());
            doProcessMessage(queueOperations, message);
            logger.info("Message processed (MessageId: \"{}\")", message.getMessageId());
        } catch (Throwable e) {
            // SLF4J treats a trailing Throwable argument as the exception to log with its stack trace.
            logger.error("Can't process message (MessageId: \"{}\")", message.getMessageId(), e);
            errorStrategy.processError(queueOperations, message, e);
        }
    }

    /**
     * Marks the observer as stopped and, in multi-threaded mode, shuts down the worker pool.
     *
     * <p>{@link ExecutorService#shutdown()} lets already-submitted tasks finish but rejects new
     * ones. Parallel batch processing is therefore not possible after this call. Subclasses that
     * override this method should call {@code super.close()}.
     */
    @Override
    public void close() {
        isRunning = false;
        if (executorService != null) {
            executorService.shutdown();
        }
    }

    /** Returns {@code true} between {@link #start()} and {@link #close()}. */
    @Override
    public final boolean isRunning() {
        return isRunning;
    }
}
