package ru.yandex.solomon.coremon.meta.db;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.OptionalLong;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import javax.annotation.Nullable;

import com.google.common.collect.Multimap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.misc.actor.ActorRunnerImpl;
import ru.yandex.misc.actor.Tasks;
import ru.yandex.misc.concurrent.CompletableFutures;
import ru.yandex.monitoring.coremon.EShardState;
import ru.yandex.monlib.metrics.labels.Label;
import ru.yandex.monlib.metrics.labels.Labels;
import ru.yandex.solomon.core.ShardIsNotLocalException;
import ru.yandex.solomon.core.urlStatus.UrlStatusTypeException;
import ru.yandex.solomon.coremon.meta.CoremonMetric;
import ru.yandex.solomon.coremon.meta.CoremonMetricArray;
import ru.yandex.solomon.coremon.meta.MetricMeta;
import ru.yandex.solomon.coremon.meta.file.FileMetricsCollection;
import ru.yandex.solomon.coremon.meta.mem.MemOnlyMetricsCollectionImpl;
import ru.yandex.solomon.labels.LabelKeys;
import ru.yandex.solomon.memory.layout.MemMeasurable;
import ru.yandex.solomon.memory.layout.MemoryCounter;
import ru.yandex.solomon.proto.UrlStatusType;
import ru.yandex.solomon.util.async.InFlightLimiter;
import ru.yandex.solomon.util.collection.queue.ArrayListLockQueueMem;
import ru.yandex.solomon.util.time.InstantUtils;
import ru.yandex.stockpile.client.shard.StockpileLocalId;


/**
 * Metabase storage of a single coremon shard.
 *
 * <p>Keeps an in-memory index of the shard's metrics ({@link FileMetricsCollection} plus a
 * {@link MemOnlyMetricsCollectionImpl}) loaded from {@link MetricsDao}, and persists new or
 * updated metrics through a background write queue. While the index is being (re)loaded the
 * storage is read-only: writes are queued and flushed once the load finishes or fails.
 *
 * @author Sergey Polovko
 */
public class MetabaseShardStorageImpl implements MetabaseShardStorage {
    private static final Logger logger = LoggerFactory.getLogger(MetabaseShardStorageImpl.class);

    /** Initial delay before retrying a failed load; also the upper bound of the random jitter. */
    private static final long DEFAULT_BACKOFF_SLEEP_MILLIS = 5_000;
    /** Upper bound of the exponentially growing retry delay (jitter is added on top). */
    private static final long MAX_BACKOFF_SLEEP_MILLIS = 60_000;

    /**
     * Shared by all storages (static): every (re)load runs through this limiter,
     * presumably capping the number of loads in flight at 1000.
     */
    public static final InFlightLimiter IN_FLIGHT_LIMITER = new InFlightLimiter(1000);

    /**
     * Upper limit for the number of metrics waiting in the write queue; requests arriving
     * above it are rejected with {@code IPC_QUEUE_OVERFLOW}.
     */
    private static final int MAX_METRICS_IN_WRITE_QUEUE = 1_000_000;


    private final String shardId;
    private final MetricsDao metricsDao;
    private final Executor executor;

    /** Last failure seen by a load attempt, if any. */
    @Nullable
    private Throwable lastLoadException;
    /** Completes when the load started by the constructor finishes (success or stop). */
    private final CompletableFuture<Void> loadFuture;
    // for manager UI
    @SuppressWarnings("unused")
    @Nullable
    private Instant lastLoadSuccess;

    /** Future of the currently running load attempt; cancelled by {@link #stop()}. */
    @Nullable
    private volatile CompletableFuture<?> metricsFindFuture;
    /** When true, writes are only queued and {@link #remove} is rejected. */
    private volatile boolean readOnly = false;
    private volatile boolean stop = false;
    /** Current retry delay, grows by 1.5x on every failed load and resets on success. */
    private long backoffSleepMillis = DEFAULT_BACKOFF_SLEEP_MILLIS;

    private final StockpileMetricIdFactory stockpileMetricIdFactory;

    private volatile MemOnlyMetricsCollectionImpl memOnlyMetrics;
    private volatile FileMetricsCollection fileMetrics;
    private final AtomicLong rowsTotal = new AtomicLong();
    private final AtomicLong rowsLoaded = new AtomicLong();
    private final AtomicReference<EShardState> state = new AtomicReference<>(EShardState.NEW);

    // write queue
    /** Coalesces write-drain scheduling so that only one doWrite() loop is active. */
    private final Tasks writeTasks = new Tasks();
    private final ArrayListLockQueueMem<WriteRequest> writeQueue = new ArrayListLockQueueMem<>();
    /** Number of metrics in queued (not yet dequeued) write requests. */
    private final AtomicInteger writeQueueMetrics = new AtomicInteger();


    /**
     * Creates the storage and immediately starts loading the shard's metrics.
     *
     * @param shardId         shard identifier, used in collection names, logs and error messages
     * @param metricsDao      persistent storage of the shard's metrics
     * @param metricIdFactory allocator of stockpile metric ids for new metrics
     * @param executor        executor for background loading and writing
     */
    public MetabaseShardStorageImpl(String shardId, MetricsDao metricsDao, StockpileMetricIdFactory metricIdFactory, Executor executor) {
        this.shardId = shardId;
        this.metricsDao = metricsDao;
        this.stockpileMetricIdFactory = metricIdFactory;
        this.executor = executor;

        this.memOnlyMetrics = new MemOnlyMetricsCollectionImpl();
        // NOTE(review): reloads name the collection "shard " + shardId while this one uses the
        // bare id; confirm which name FileMetricsCollection expects.
        this.fileMetrics = new FileMetricsCollection(shardId, executor);
        this.loadFuture = reload();
    }

    /**
     * Reloads the shard's metrics from {@link MetricsDao}, replacing the in-memory index.
     * Failed attempts are retried with exponential backoff until one succeeds or the storage
     * is stopped.
     *
     * @return future completed normally when a load succeeds, or exceptionally with
     *         {@link ShardIsNotLocalException} if the storage was stopped first
     */
    @Override
    public CompletableFuture<Void> reload() {
        var doneFuture = new CompletableFuture<Void>();
        reload(doneFuture);
        return doneFuture;
    }

    /**
     * Runs one load attempt through {@link #IN_FLIGHT_LIMITER} and settles {@code doneFuture}
     * (directly, or via {@link #onLoadError} which may schedule another attempt).
     */
    private void reload(CompletableFuture<Void> doneFuture) {
        // prevent modifications while reloading collection
        readOnly = (fileMetrics != null);

        IN_FLIGHT_LIMITER.run(() -> {
            if (stop) {
                doneFuture.completeExceptionally(shardWasStoppedException());
                return CompletableFuture.completedFuture(null);
            }

            var future = CompletableFutures.safeCall(metricsDao::getMetricCount)
                    .handle((metricsCount, throwable) -> {
                        // an unknown count is not fatal: load without a size hint
                        if (throwable != null) {
                            logger.error("cannot get metrics count for shard {}", shardId, throwable);
                            return doLoad(OptionalLong.empty());
                        } else {
                            return doLoad(OptionalLong.of(metricsCount));
                        }
                    })
                    .thenCompose(f -> f)
                    .handle((ignore, e) -> {
                        if (e != null) {
                            onLoadError(e, doneFuture);
                        } else {
                            lastLoadSuccess = Instant.now();
                            doneFuture.complete(null);
                        }
                        return null;
                    });
            metricsFindFuture = future;
            return future;
        });
    }

    /**
     * Loads all metrics and installs them as the new index.
     *
     * @param metricsCount expected number of metrics, or empty if it could not be obtained
     */
    private CompletableFuture<?> doLoad(OptionalLong metricsCount) {
        // do not try to load data if we already know that the shard is empty
        if (metricsCount.isPresent() && metricsCount.getAsLong() == 0) {
            // a successful load: reset the retry delay just like the non-empty path does
            backoffSleepMillis = DEFAULT_BACKOFF_SLEEP_MILLIS;
            rowsTotal.set(0);
            rowsLoaded.set(0);

            installLoadedMetrics(new FileMetricsCollection("shard " + shardId, executor));
            return CompletableFuture.completedFuture(null);
        }

        state.set(EShardState.LOADING);
        return findMetrics(metricsCount)
                .thenAccept(metrics -> {
                    backoffSleepMillis = DEFAULT_BACKOFF_SLEEP_MILLIS;
                    state.set(EShardState.INDEXING);

                    var countByShardId = new Int2IntOpenHashMap(64);
                    for (int index = 0; index < metrics.size(); index++) {
                        countByShardId.addTo(metrics.getShardId(index), 1);
                    }
                    stockpileMetricIdFactory.updateStats(countByShardId);

                    installLoadedMetrics(new FileMetricsCollection("shard " + shardId, executor, metrics));
                });
    }

    /**
     * Swaps in a freshly loaded index, leaves read-only mode, flushes writes queued meanwhile
     * and marks the shard READY.
     */
    private void installLoadedMetrics(FileMetricsCollection newFileMetrics) {
        FileMetricsCollection oldFileMetrics = this.fileMetrics;
        this.fileMetrics = newFileMetrics;
        this.memOnlyMetrics = new MemOnlyMetricsCollectionImpl();
        oldFileMetrics.close();
        readOnly = false;

        // check write queue because we could have writes while shard was in read only state
        scheduleDoWrite();

        state.set(EShardState.READY);
        if (stop) {
            // stop() closed whatever collection was current at that time; close the new one too
            fileMetrics.close();
        }
    }

    /**
     * Reads all metrics of the shard into a single array.
     *
     * @param metricsCount expected number of metrics, used for pre-sizing and as a sanity check
     * @return future with the loaded metrics (ownership goes to the caller); fails if fewer
     *         metrics than expected were read or if the storage was stopped meanwhile
     */
    private CompletableFuture<CoremonMetricArray> findMetrics(OptionalLong metricsCount) {
        rowsTotal.set(metricsCount.orElse(0));
        rowsLoaded.set(0);

        // initial capacity of 10_000 is used when the expected count is unknown
        CoremonMetricArray metrics = new CoremonMetricArray(Math.toIntExact(metricsCount.orElse(10_000L)));
        Consumer<CoremonMetricArray> consumer = (chunk) -> {
            metrics.addAll(chunk);
            rowsLoaded.addAndGet(chunk.size());
        };

        try {
            return metricsDao.findMetrics(consumer, metricsCount)
                    .thenApplyAsync(rowsRead -> {
                        if (metricsCount.isPresent() && rowsLoaded.get() < metricsCount.getAsLong()) {
                            var msg = String.format("loaded %d metrics which is less than estimated %d", rowsLoaded.get(), metricsCount.getAsLong());
                            throw new RuntimeException(msg);
                        }

                        // release capacity reserved by the size hint if much of it is unused
                        double percentOfWastedSpace = ((double) (metrics.capacity() - metrics.size()) * 100) / metrics.capacity();
                        if (percentOfWastedSpace > 10.0) {
                            metrics.shrinkToFit();
                        }

                        if (stop) {
                            throw shardWasStoppedException();
                        }

                        return metrics;
                    }, executor)
                    .exceptionally(throwable -> {
                        metrics.close();
                        throw new RuntimeException(throwable);
                    });
        } catch (Throwable t) {
            // findMetrics() failed synchronously, so no callback above will release the array
            metrics.close();
            throw new RuntimeException("cannot find metrics in " + shardId, t);
        }
    }

    /**
     * Handles a failed load attempt: records the error, re-enables writes and, unless the
     * storage is stopped, schedules another attempt with jittered exponential backoff.
     */
    private void onLoadError(Throwable throwable, CompletableFuture<Void> doneFuture) {
        logger.error("Shard {} load failed", shardId, throwable);
        lastLoadException = throwable;
        // keep accepting writes against the current index while the load is being retried
        readOnly = false;

        if (!stop) {
            // writes could have been queued while the shard was in read only state
            scheduleDoWrite();

            long jitterMillis = ThreadLocalRandom.current().nextLong(DEFAULT_BACKOFF_SLEEP_MILLIS);
            long sleepMillis = this.backoffSleepMillis + jitterMillis;

            CompletableFuture.delayedExecutor(sleepMillis, TimeUnit.MILLISECONDS, executor).execute(() -> {
                if (!stop) {
                    reload(doneFuture);
                } else {
                    doneFuture.completeExceptionally(shardWasStoppedException());
                }
            });

            // exponential increase delay time with base 1.5, upto 1 minute
            this.backoffSleepMillis = Math.min(this.backoffSleepMillis + this.backoffSleepMillis / 2, MAX_BACKOFF_SLEEP_MILLIS);
        } else {
            doneFuture.completeExceptionally(shardWasStoppedException());
        }
    }

    /**
     * Queues already prepared metrics for writing.
     *
     * @param newMetrics metrics to write; this method takes ownership of the array
     * @return future completed once the metrics are persisted and indexed
     */
    @Override
    // takes ownership of newMetrics
    public CompletableFuture<Void> write(CoremonMetricArray newMetrics) {
        return writeInBackground(new WriteRequest(newMetrics));
    }

    /**
     * Queues updated metrics, new direct metrics and new aggregates for writing. New metrics
     * get stockpile ids allocated here.
     *
     * @return future completed once the metrics are persisted and indexed; completed
     *         immediately when there is nothing to write, failed if the request cannot be built
     */
    @Override
    public CompletableFuture<Void> write(
        Collection<CoremonMetric> updatedMetrics,
        Collection<? extends MetricMeta> newMetrics,
        Multimap<String, ? extends MetricMeta> newAggregates)
    {
        try {
            WriteRequest writeRequest = createWriteRequest(updatedMetrics, newMetrics, newAggregates);
            if (writeRequest == null) {
                return CompletableFuture.completedFuture(null);
            }

            return writeInBackground(writeRequest);
        } catch (Throwable e) {
            return CompletableFuture.failedFuture(e);
        }
    }

    /**
     * Enqueues the request and, unless the shard is read-only, makes sure a drain loop runs.
     * Rejects the request when the queue is overfull or the storage is stopped.
     */
    private CompletableFuture<Void> writeInBackground(WriteRequest writeRequest) {
        // not precise limit
        if (writeQueueMetrics.get() > MAX_METRICS_IN_WRITE_QUEUE) {
            writeRequest.completeExceptionally(new UrlStatusTypeException(UrlStatusType.IPC_QUEUE_OVERFLOW));
            return writeRequest;
        }

        if (!stop) {
            writeQueueMetrics.addAndGet(writeRequest.size());
            writeQueue.enqueue(writeRequest);

            // in read-only mode the queue is flushed when loading finishes
            if (!readOnly) {
                scheduleDoWrite();
            }
        } else {
            writeRequest.release();
            writeRequest.completeExceptionally(new RuntimeException("shard " + shardId + " was stopped"));
        }

        return writeRequest;
    }

    /** Schedules {@link #doWrite()} on the executor if {@code writeTasks} says one is needed. */
    private void scheduleDoWrite() {
        if (writeTasks.addTask()) {
            ActorRunnerImpl.schedule(executor, this::doWrite);
        }
    }

    /**
     * Drains the write queue, keeping at most one asynchronous write in flight. If a batch
     * cannot be started, its requests are failed so their callers do not wait forever.
     */
    private void doWrite() {
        while (writeTasks.fetchTask()) {
            List<WriteRequest> writeRequests = List.of();
            try {
                writeRequests = writeQueue.dequeueAll();
                if (writeRequests.isEmpty()) {
                    // no requests were submitted
                    continue;
                }

                startWrite(writeRequests);

                // keep only one async operation in flight
                break;
            } catch (Throwable t) {
                logger.error("unhandled exception in doWrite()", t);
                // the write chain was never started, so nobody else will complete these
                for (WriteRequest writeRequest : writeRequests) {
                    writeRequest.completeExceptionally(t);
                }
            }
        }
    }

    /**
     * Combines the requests into one deduplicated batch and starts persisting it. When the
     * write finishes the requests are completed and the drain loop is rescheduled if needed.
     *
     * @throws RuntimeException if the batch could not be built or the write could not be started
     */
    private void startWrite(List<WriteRequest> writeRequests) {
        final int totalCount = writeRequests.stream()
            .mapToInt(WriteRequest::size)
            .sum();

        writeQueueMetrics.addAndGet(-totalCount);

        // keep reference to metrics collections to avoid collections changing in the process
        // of inserting new metrics in YDB
        final FileMetricsCollection fileMetrics = this.fileMetrics;

        // combine all write requests into single metrics array
        // and perform their deduplication
        final CoremonMetricArray metrics = combine(writeRequests, fileMetrics, totalCount);

        try {
            metricsDao.replaceMetrics(metrics)
                .thenCompose(newMetrics -> {
                    try {
                        return fileMetrics.putAll(newMetrics);
                    } finally {
                        // released even if putAll() throws synchronously
                        newMetrics.closeSilent();
                    }
                })
                .whenComplete((aVoid, throwable) -> {
                    metrics.closeSilent();

                    if (writeTasks.checkTask()) {
                        ActorRunnerImpl.schedule(executor, this::doWrite);
                    }

                    for (WriteRequest writeRequest : writeRequests) {
                        if (throwable != null) {
                            writeRequest.completeExceptionally(throwable);
                        } else {
                            writeRequest.complete(null);
                        }
                    }
                });
        } catch (Throwable t) {
            metrics.close();
            throw new RuntimeException("cannot replace metrics in " + shardId, t);
        }
    }

    /**
     * Merges the metrics of all requests into one array, keeping only the first occurrence of
     * each label set and skipping metrics already present in {@code fileMetrics} with the same
     * type. Takes (and releases) the metrics of every request.
     *
     * @param maxSize total number of metrics in the requests, used for pre-sizing
     * @return new array owned by the caller
     */
    private CoremonMetricArray combine(List<WriteRequest> requests, FileMetricsCollection fileMetrics, int maxSize) {
        ObjectOpenHashSet<Labels> keys = new ObjectOpenHashSet<>(maxSize);
        CoremonMetricArray metrics = new CoremonMetricArray(maxSize);

        try {
            for (WriteRequest r : requests) {
                try (CoremonMetricArray m = r.takeMetrics()) {
                    for (int i = 0, size = m.size(); i < size; i++) {
                        Labels labels = m.getLabels(i);
                        // add only unique and not yet known metrics
                        if (!keys.add(labels)) {
                            continue;
                        }
                        try (var metric = fileMetrics.getOrNull(labels)) {
                            // a known metric is rewritten only if its type has changed
                            if (metric == null || metric.getType() != m.getType(i)) {
                                metrics.add(m.getShardId(i), m.getLocalId(i), labels, m.getCreatedAtSeconds(i), m.getType(i));
                            }
                        }
                    }
                }
            }
            return metrics;
        } catch (Throwable t) {
            // do not leak the partially built array
            metrics.close();
            throw new RuntimeException("cannot combine write requests in " + shardId, t);
        }
    }

    /**
     * Deletes the given metrics from {@link MetricsDao} and then from the in-memory index,
     * and resets the mem-only collection.
     *
     * @return future with {@code metricsToRemove}; fails immediately with
     *         {@link IllegalStateException} while the storage is read-only
     */
    public CompletableFuture<List<CoremonMetric>> remove(List<CoremonMetric> metricsToRemove) {
        if (readOnly) {
            return CompletableFuture.failedFuture(new IllegalStateException("shard storage is in readonly state"));
        }

        List<Labels> keys = metricsToRemove.stream()
            .map(CoremonMetric::getLabels)
            .collect(Collectors.toList());

        return metricsDao.deleteMetrics(keys)
            .thenCompose(unit -> fileMetrics.removeAll(keys))
            .thenRun(() -> memOnlyMetrics = new MemOnlyMetricsCollectionImpl())
            .thenApply(ignored -> metricsToRemove);
    }

    /** Returns whether the load started by the constructor has finished (success or stop). */
    public boolean isLoaded() {
        return loadFuture.isDone();
    }

    /** Returns the expected number of metrics of the current load (0 when unknown). */
    public long getEstimatedRowsTotal() {
        return rowsTotal.get();
    }

    /** Returns the number of metrics read so far by the current load. */
    public long getRowsLoaded() {
        return rowsLoaded.get();
    }

    /** Blocks until the constructor-started load finishes; throws if the storage was stopped. */
    public void awaitLoadComplete() {
        loadFuture.join();
    }

    /** Returns the future of the load started by the constructor. */
    public CompletableFuture<Void> getLoadFuture() {
        return loadFuture;
    }

    /** Returns the error of the last failed load attempt, or {@code null} if none failed. */
    @Nullable
    public Throwable getLastLoadException() {
        return lastLoadException;
    }

    /** Returns the current in-memory index of persisted metrics. */
    public FileMetricsCollection getFileMetrics() {
        return fileMetrics;
    }

    @Override
    public MemOnlyMetricsCollectionImpl getMemOnlyMetrics() {
        return memOnlyMetrics;
    }

    public StockpileMetricIdFactory getMetricIdFactory() {
        return stockpileMetricIdFactory;
    }

    /** Returns the number of metrics in queued, not yet dequeued, write requests. */
    public long getWriteQueueMetrics() {
        return writeQueueMetrics.get();
    }

    /** Returns the memory footprint of the write queue, including queued requests. */
    public long getWriteQueueMemSize() {
        return writeQueue.memorySizeIncludingSelf();
    }

    /**
     * Builds a single write request out of updated metrics, new direct metrics and new
     * aggregates, allocating stockpile ids for the new ones.
     *
     * @return the request, or {@code null} if all inputs are empty
     * @throws RuntimeException if the request could not be built (the partial array is released)
     */
    @Nullable
    private WriteRequest createWriteRequest(
        Collection<CoremonMetric> updatedMetrics,
        Collection<? extends MetricMeta> newMetrics,
        Multimap<String, ? extends MetricMeta> newAggregates)
    {
        final int expectedUpdateSize = updatedMetrics.size() + newMetrics.size() + newAggregates.size();
        if (expectedUpdateSize == 0) {
            return null;
        }

        CoremonMetricArray metrics = new CoremonMetricArray(expectedUpdateSize);
        try {
            // (1) add updated metrics
            for (CoremonMetric metric : updatedMetrics) {
                metrics.add(metric);
            }

            // (2) add still unknown direct metrics
            for (MetricMeta metricData : newMetrics) {
                Labels labels = metricData.getLabels();

                // the stockpile shard is chosen by the "host" label (null if absent)
                Label hostLabel = labels.findByKey(LabelKeys.HOST);
                String host = (hostLabel == null ? null : hostLabel.getValue());
                var stockpileShard = stockpileMetricIdFactory.forHost(host);

                var metricId = stockpileShard.metricId(labels);
                metrics.add(
                    metricId.getShardId(),
                    metricId.getLocalId(),
                    labels,
                    InstantUtils.currentTimeSeconds(),
                    metricData.getType());
            }

            // (3) add still unknown aggregate metrics
            for (var entry : newAggregates.asMap().entrySet()) {
                var stockpileShard = stockpileMetricIdFactory.forHost(entry.getKey());
                for (MetricMeta metricPoint : entry.getValue()) {
                    Labels labels = metricPoint.getLabels();
                    metrics.add(
                        stockpileShard.shardId(labels),
                        StockpileLocalId.random(),
                        labels,
                        InstantUtils.currentTimeSeconds(),
                        metricPoint.getType());
                }
            }

            return new WriteRequest(metrics);
        } catch (Throwable t) {
            metrics.close();
            throw new RuntimeException("cannot create write request", t);
        }
    }

    /**
     * Stops the storage: closes the current index, cancels the running load attempt and fails
     * all queued write requests with {@link ShardIsNotLocalException}.
     */
    public void stop() {
        stop = true;
        fileMetrics.close();

        var metricsFindFuture = this.metricsFindFuture;
        if (metricsFindFuture != null) {
            metricsFindFuture.cancel(false);
        }

        var error = shardWasStoppedException();
        ArrayList<WriteRequest> writeRequests = writeQueue.dequeueAll();
        for (WriteRequest writeRequest : writeRequests) {
            writeRequest.release();
            writeRequest.completeExceptionally(error);
        }
    }

    public EShardState getState() {
        return state.get();
    }

    private ShardIsNotLocalException shardWasStoppedException() {
        return new ShardIsNotLocalException("shard " + shardId + " was stopped");
    }

    /**
     * WRITE REQUEST
     *
     * <p>A future for one write call that owns the metrics to be written. The metrics are
     * released when the request completes (either way) unless taken earlier by
     * {@link #takeMetrics()}. Not thread-safe with respect to {@code metrics}.
     */
    private static final class WriteRequest extends CompletableFuture<Void> implements MemMeasurable {
        private static final long SELF_SIZE = MemoryCounter.objectSelfSizeLayout(WriteRequest.class);

        @Nullable
        private CoremonMetricArray metrics;

        // takes ownership of newMetrics
        WriteRequest(CoremonMetricArray metrics) {
            this.metrics = metrics;
        }

        @Override
        public String toString() {
            return "WriteRequest{" + size() + " metrics}";
        }

        /** Returns the number of owned metrics (0 once released or taken). */
        public int size() {
            return metrics == null ? 0 : metrics.size();
        }

        /** Closes the owned metrics, if any; safe to call repeatedly. */
        public void release() {
            if (metrics != null) {
                metrics.close();
                metrics = null;
            }
        }

        /** Transfers ownership of the metrics to the caller; may return {@code null}. */
        public CoremonMetricArray takeMetrics() {
            CoremonMetricArray refCopy = metrics;
            metrics = null;
            return refCopy;
        }

        @Override
        public boolean complete(Void value) {
            release();
            return super.complete(value);
        }

        @Override
        public boolean completeExceptionally(Throwable t) {
            release();
            return super.completeExceptionally(t);
        }

        @Override
        public long memorySizeIncludingSelf() {
            return SELF_SIZE + (metrics == null ? 0 : metrics.memorySizeIncludingSelf());
        }
    }
}
