package ru.yandex.stockpile.server.shard;

import java.util.Deque;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.Executor;
import java.util.concurrent.Flow;
import java.util.concurrent.SubmissionPublisher;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.misc.actor.Tasks;
import ru.yandex.solomon.config.protobuf.stockpile.EInvalidArchiveStrategy;
import ru.yandex.solomon.memory.layout.MemMeasurable;
import ru.yandex.solomon.memory.layout.MemoryCounter;
import ru.yandex.solomon.util.CloseableUtils;
import ru.yandex.stockpile.api.EProjectId;
import ru.yandex.stockpile.memState.MetricIdAndData;
import ru.yandex.stockpile.server.StockpileProjects;
import ru.yandex.stockpile.server.shard.MergeProcessMetrics.MergeKindMetrics;

import static com.google.common.base.Preconditions.checkArgument;

/**
 * @author Vladimir Gordiychuk
 */
public class MergeMerger extends SubmissionPublisher<MergeTaskResult> implements Flow.Processor<List<MetricIdAndData>, MergeTaskResult>, MemMeasurable {
    private static final Logger logger = LoggerFactory.getLogger(MergeMerger.class);
    private static final long SELF_SIZE = MemoryCounter.objectSelfSizeLayout(MergeMerger.class);
    private static final long TASK_SELF_SIZE = MemoryCounter.objectSelfSizeLayout(Task.class);
    private static final long MAX_IN_FLIGHT = Runtime.getRuntime().availableProcessors();

    private final Deque<CompletableFuture<MergeTaskResult>> work = new ConcurrentLinkedDeque<>();
    private final AtomicLong inFlight = new AtomicLong(0);

    private final int shardId;
    private final Executor executor;
    private final boolean allowDecim;
    private final InvalidArchiveStrategy invalidArchiveStrategy;
    private final long splitDelayMillis;
    private final MergeKindMetrics metrics;
    private final AtomicLong memoryUse = new AtomicLong(SELF_SIZE);
    private final long now;

    private Tasks tasks = new Tasks();
    private AtomicLong capacity = new AtomicLong();
    private AtomicLong mergedMetrics = new AtomicLong();
    private Flow.Subscription subscription;
    private int subscribersCount = 0;
    private volatile boolean complete = false;

    public MergeMerger(ShardThread shardThread, long now, boolean allowDecim, long splitDelayMillis, InvalidArchiveStrategy invalidArchiveStrategy, MergeKindMetrics metrics) {
        this(shardThread.shard.shardId, shardThread.shard.mergeExecutor, now, allowDecim, splitDelayMillis, invalidArchiveStrategy, metrics);
    }

    public MergeMerger(int shardId, Executor executor, long now, boolean allowDecim, long splitDelayMillis, InvalidArchiveStrategy invalidArchiveStrategy, MergeKindMetrics metrics) {
        super(executor, 2048);
        this.shardId = shardId;
        this.executor = executor;
        this.allowDecim = allowDecim;
        this.splitDelayMillis = splitDelayMillis;
        this.invalidArchiveStrategy = invalidArchiveStrategy;
        this.metrics = metrics;
        this.now = now;
    }

    @Override
    public void subscribe(Flow.Subscriber<? super MergeTaskResult> subscriber) {
        subscribersCount++;
        super.subscribe(new SubscriberWrapper(subscriber));
    }

    @Override
    public void onSubscribe(Flow.Subscription subscription) {
        int subscribers = getNumberOfSubscribers();
        checkArgument(subscribers > 0, subscribers);
        this.subscription = subscription;
        this.capacity.set(MAX_IN_FLIGHT * 3);
        requestMore();
    }

    @Override
    public void onNext(List<MetricIdAndData> item) {
        if (item.isEmpty()) {
            subscription.request(1);
            return;
        }

        var task = new Task(item);
        if (StockpileProjects.shouldBeDropped(task.getProject())) {
            CloseableUtils.close(item);
            mergedMetrics.addAndGet(task.source.size());
            subscription.request(1);
            return;
        }

        addUsedMemory(task.memorySizeIncludingSelf());
        var future = CompletableFuture.supplyAsync(task, executor);
        future.whenComplete((r, e) -> {
            inFlight.decrementAndGet();
            scheduleProcessing();
        });
        work.addLast(future);
    }

    public long getMergedMetrics() {
        return mergedMetrics.get();
    }

    private void scheduleProcessing() {
        try {
            if (tasks.addTask()) {
                processMerged();
            }
        } catch (Throwable e) {
            closeExceptionally(e);
        }
    }

    private void processMerged() throws Throwable {
        while (tasks.fetchTask()) {
            submitCompleted();
            requestMore();
            if (work.isEmpty() && complete) {
                close();
            }
        }
    }

    private void submitCompleted() throws Throwable {
        int lag = estimateMaximumLag();
        while (lag < getMaxBufferCapacity()) {
            var future = work.peek();
            if (future == null || !future.isDone()) {
                return;
            }

            work.poll();

            inFlight.incrementAndGet();
            subscription.request(1);
            try {
                var merged = future.join();
                long usage = merged.memorySizeIncludingSelf();
                addUsedMemory(usage * (subscribersCount - 1));
                lag = submit(merged);
                metrics.mergeMaxLag.record(lag);
            } catch (Throwable e) {
                metrics.invalidArchives.inc();
                logger.error("Failed merge at shard " + shardId, e);
                if (invalidArchiveStrategy.strategy == EInvalidArchiveStrategy.DROP) {
                    continue;
                }

                throw new RuntimeException(e);
            }
        }
    }

    private void requestMore() {
        long capacity = this.capacity.get();
        long inFlight = this.inFlight.get();
        long delta = MAX_IN_FLIGHT - inFlight;

        if (delta <= 0) {
            return;
        }

        long min = Math.min(delta, capacity);
        if (min <= 0) {
            return;
        }

        this.inFlight.addAndGet(min);
        this.capacity.addAndGet(-min);
        this.subscription.request(min);
    }

    @Override
    public void onError(Throwable throwable) {
        closeExceptionally(throwable);
    }

    @Override
    public void onComplete() {
        complete = true;
        scheduleProcessing();
    }

    @Override
    public long memorySizeIncludingSelf() {
        return memoryUse.get();
    }

    private void addUsedMemory(long memory) {
        long total = memoryUse.addAndGet(memory);
        if (total < 0) {
            throw new IllegalStateException("Negative count " + total + " memory usage, added: " + memory);
        }
    }

    private class Task implements MemMeasurable, Supplier<MergeTaskResult> {
        private final long createAtNanos = System.nanoTime();
        private final long memoryUsage;
        private final List<MetricIdAndData> source;

        Task(List<MetricIdAndData> source) {
            this.source = source;
            this.memoryUsage = MemoryCounter.listDataSizeWithContent(source);
        }

        @Override
        public long memorySizeIncludingSelf() {
            return TASK_SELF_SIZE + memoryUsage;
        }

        @Override
        public MergeTaskResult get() {
            var task = new MergeTask(shardId, source, now, splitDelayMillis, allowDecim, metrics.getTaskMetrics());
            var result = task.run();
            memoryUse.addAndGet(-this.memorySizeIncludingSelf() + result.memorySizeIncludingSelf());
            metrics.addMergeTime(System.nanoTime() - createAtNanos);
            mergedMetrics.addAndGet(source.size());
            return result;
        }

        private EProjectId getProject() {
            for (int index = source.size() - 1; index >= 0; index--) {
                EProjectId projectId = source.get(index).archive().getOwnerProjectIdOrUnknown();
                if (projectId != EProjectId.UNKNOWN) {
                    return projectId;
                }
            }
            return EProjectId.UNKNOWN;
        }
    }

    private class SubscriberWrapper implements Flow.Subscriber<MergeTaskResult> {
        private final Flow.Subscriber<? super MergeTaskResult> subscriber;

        public SubscriberWrapper(Flow.Subscriber<? super MergeTaskResult> subscriber) {
            this.subscriber = subscriber;
        }

        @Override
        public void onSubscribe(Flow.Subscription subscription) {
            subscriber.onSubscribe(new SubscriptionWrapper(subscription));
        }

        @Override
        public void onNext(MergeTaskResult item) {
            addUsedMemory(-item.memorySizeIncludingSelf());
            subscriber.onNext(item);
        }

        @Override
        public void onError(Throwable throwable) {
            subscriber.onError(throwable);
        }

        @Override
        public void onComplete() {
            subscriber.onComplete();
        }
    }

    private class SubscriptionWrapper implements Flow.Subscription {
        private final Flow.Subscription subscription;

        public SubscriptionWrapper(Flow.Subscription subscription) {
            this.subscription = subscription;
        }

        @Override
        public void request(long n) {
            if (n > 0) {
                subscription.request(n);
                scheduleProcessing();
            }
        }

        @Override
        public void cancel() {
            subscription.cancel();
        }
    }
}
