package ru.yandex.travel.commons.logging.ydb;

import java.time.Duration;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;

import com.google.common.base.Strings;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.core.LogEvent;
import org.apache.logging.log4j.core.appender.AbstractManager;
import org.apache.logging.log4j.core.async.InternalAsyncUtil;
import org.apache.logging.log4j.core.impl.Log4jLogEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.travel.commons.metrics.MetricsUtils;
import ru.yandex.travel.commons.network.NetworkUtils;
import ru.yandex.travel.logging.ydb.TOrderLogRecord;

import static java.util.stream.Collectors.toList;
import static ru.yandex.travel.commons.logging.CommonMdcParams.MDC_ENTITY_ID;

public class YdbLogManager extends AbstractManager {
    private static final String YDB_UTIL_OWNER_ID = "c347f21b-a792-4ee6-9331-6301f780aff7";

    private static final Counter sentEventsMeter =
            Counter.builder("logs.ydb.sentEvents").register(Metrics.globalRegistry);
    private static final Counter rejectedEventsMeter =
            Counter.builder("logs.ydb.rejectedEvents").register(Metrics.globalRegistry);
    private static final Counter failedBatchesMeter =
            Counter.builder("logs.ydb.failedBatches").register(Metrics.globalRegistry);
    private static final Timer processingTimeMeter = Timer.builder("logs.ydb.processingTime")
            .serviceLevelObjectives(List.of(4, 8, 16, 32, 64, 128, 256).stream().map(Duration::ofMillis).collect(toList()).toArray(new Duration[0]))
            .publishPercentiles(MetricsUtils.higherPercentiles())
            .register(Metrics.globalRegistry);
    private static final DistributionSummary batchSizeMeter = DistributionSummary.builder("logs.ydb.batchSize")
            .serviceLevelObjectives(1, 5, 25, 100)
            .publishPercentiles(.50, .90, .95, .99, 1)
            // solomon polls the stats every 15 seconds, so we keep 3 full buffers (3*5=15 sec) + 1 currently being
            // filled (0-5 sec)
            .distributionStatisticBufferLength(4)
            .distributionStatisticExpiry(Duration.ofSeconds(5))
            .register(Metrics.globalRegistry);

    private final BlockingQueue<LogEvent> queue;
    private final ExecutorService asyncExecutor;
    private final YdbLogTableClient logTableClient;
    private final int batchSize;
    private final Duration clientTimeout;
    private final Duration shutdownTimeout;
    private final AtomicBoolean isActive;
    private final YdbLogRecordFactory ydbLogRecordFactory;

    private static final Logger log = LoggerFactory.getLogger("ru.yandex.travel.commons.logging.ydb");

    YdbLogManager(String name, YdbLogTableClient logTableClient, YdbLogProperties properties) {
        super(null, name);
        this.queue = new ArrayBlockingQueue<>(properties.getQueueSize());
        this.asyncExecutor = Executors.newSingleThreadExecutor(new ThreadFactoryBuilder()
                .setNameFormat("ydb-async-log")
                .build());
        this.logTableClient = logTableClient;
        this.batchSize = properties.getBatchSize();
        this.clientTimeout = properties.getClientTimeout();
        this.shutdownTimeout = properties.getShutdownTimeout();
        this.isActive = new AtomicBoolean(false);
        this.ydbLogRecordFactory = new YdbLogRecordFactory(NetworkUtils.getLocalHostName());

        Gauge.builder("logs.ydb.queueSize", queue::size).register(Metrics.globalRegistry);
    }

    protected void start() {
        log.info("Starting YdbLogManager");
        if (!isActive.compareAndSet(false, true)) {
            // multiple appender instances can call this method on the manager singleton
            log.info("The manager has already been started");
            return;
        }
        asyncExecutor.submit(this::runQueueProcessingLoop);
        try {
            // testing connection for logging purposes only, failure to connect should not prevent the app from starting
            logTableClient.insertLogRecords(List.of(TOrderLogRecord.newBuilder()
                    .setOwnerId(YDB_UTIL_OWNER_ID)
                    .setTimestamp(System.currentTimeMillis())
                    .setMessageId(UUID.randomUUID().toString())
                    .setLogger(getClass().getName())
                    .setLevel(Level.INFO.name())
                    .setMessage("Connection test: " + LocalDateTime.now())
                    .setContext("{}")
                    .build())
            ).get(clientTimeout.toMillis(), TimeUnit.MILLISECONDS);
        } catch (Exception e) {
            // the code shouldn't fail startup as in that case the releaseSub method will never be called
            log.warn("Failed to test YDB connectivity", e);
        }
    }

    // the method is called on manager shutdown
    @Override
    protected boolean releaseSub(long timeout, TimeUnit timeUnit) {
        log.info("Stopping YdbLogManager");
        if (!isActive.compareAndSet(true, false)) {
            // multiple shutdown attempts are ok
            log.info("The manager has already been stopped");
            return true;
        }
        boolean withErrors = false;

        // sending an InterruptedException ahead not to waste any time in 'soft' shutdown phase
        // (there is only 1 working thread that won't stop without the signal)
        asyncExecutor.shutdownNow();
        if (!MoreExecutors.shutdownAndAwaitTermination(asyncExecutor, shutdownTimeout.toMillis(),
                TimeUnit.MILLISECONDS)) {
            log.error("Failed to terminate the queue processing task in time; timeout={}", shutdownTimeout);
            withErrors = true;
        }

        try {
            logTableClient.close();
        } catch (Exception e) {
            log.error("Failed to close the ydb client", e);
            withErrors = true;
        }

        log.info("YdbLogManager is stopped; withErrors={}", withErrors);
        return !withErrors;
    }

    protected void write(LogEvent event) {
        if (!isActive.get()) {
            throw new IllegalStateException("The manager hasn't been started or has already been closed; event=" + event);
        }
        // reference impl: org.apache.logging.log4j.core.appender.AsyncAppender
        Log4jLogEvent memento = Log4jLogEvent.createMemento(event, false);
        InternalAsyncUtil.makeMessageImmutable(event.getMessage());
        if (queue.offer(memento)) {
            sentEventsMeter.increment();
        } else {
            // the queue is overflown
            rejectedEventsMeter.increment();
        }
    }

    // reference impl: org.apache.logging.log4j.core.appender.AsyncAppender
    private void runQueueProcessingLoop() {
        log.info("Starting YdbLogManager.runQueueProcessingLoop");
        // the are two expected ways to exit the main processing loop:
        // - to detected the isActive flag set to false
        // - to receive an InterruptedException on manager shutdown
        while (isActive.get()) {
            try {
                drainQueue();
            } catch (InterruptedException e) {
                // YdbLogManager(shutdownNow) tells us to stop
                break;
            } catch (Exception e) {
                log.error("Batch failed: failed to drain some log events from the queue", e);
                failedBatchesMeter.increment();
            }
        }

        log.info("Finishing YdbLogManager.runQueueProcessingLoop. Remaining events to process: {}", queue.size());
        while (!queue.isEmpty()) {
            try {
                drainQueue();
            } catch (Exception e) {
                log.error("Batch failed, ignoring pre-termination error", e);
                failedBatchesMeter.increment();
            }
        }
        log.info("YdbLogManager.runQueueProcessingLoop is finished");
    }

    protected void drainQueue() throws InterruptedException, ExecutionException, TimeoutException {
        List<TOrderLogRecord> events = takeEventsFromQueue(batchSize).stream()
                .filter(e -> !Strings.isNullOrEmpty(e.getContextData().getValue(MDC_ENTITY_ID)))
                .map(ydbLogRecordFactory::createFromLogEvent)
                .collect(toList());

        log.trace("Sending batch: {}", events.size());
        if (!events.isEmpty()) {
            batchSizeMeter.record(events.size());
            long ts1 = System.currentTimeMillis();
            try {
                getWithDelayedInterruption(logTableClient.insertLogRecords(events));
            } finally {
                long ts2 = System.currentTimeMillis();
                processingTimeMeter.record(ts2 - ts1, TimeUnit.MILLISECONDS);
            }
        }
    }

    /**
     * For operations that should be interruptable but only after the client timeout has happened.
     * We have to do so to overcome the following error:
     * <pre><code>Non retryable status: Status{code=PRECONDITION_FAILED, issues=[Pending previous query completion (S_ERROR)]}</code></pre>
     */
    private <T> T getWithDelayedInterruption(Future<T> f) throws ExecutionException, TimeoutException {
        long start = System.currentTimeMillis();
        long timeLeft;
        InterruptedException interruption = null;
        T result = null;
        boolean success = false;
        while (!success && (timeLeft = (start + clientTimeout.toMillis()) - System.currentTimeMillis()) > 0) {
            try {
                result = f.get(timeLeft, TimeUnit.MILLISECONDS);
                success = true;
            } catch (InterruptedException e) {
                if (interruption == null) {
                    interruption = e;
                }
            }
        }
        if (interruption != null) {
            Thread.currentThread().interrupt();
        }
        if (success) {
            return result;
        } else {
            throw new TimeoutException("The operation has timed out; timeout=" + clientTimeout);
        }
    }

    private List<LogEvent> takeEventsFromQueue(int maxEvents) throws InterruptedException {
        List<LogEvent> messages = new ArrayList<>();
        // we block here until an event is available (or an InterruptedException happens)
        messages.add(queue.take());
        // and then using the non-blocking api to fill the batch if possible
        queue.drainTo(messages, maxEvents - 1);
        return messages;
    }
}
