package ru.yandex.webmaster3.viewer.http.concurrency;

import org.joda.time.Duration;
import org.springframework.beans.factory.annotation.Required;

import ru.yandex.webmaster3.core.concurrency.AsyncTask;
import ru.yandex.webmaster3.core.solomon.SolomonPushMetricsService;
import ru.yandex.webmaster3.core.solomon.metric.SolomonCounter;
import ru.yandex.webmaster3.core.solomon.metric.SolomonGauge;
import ru.yandex.webmaster3.core.solomon.metric.SolomonKey;
import ru.yandex.webmaster3.core.solomon.metric.SolomonMetricRegistry;
import ru.yandex.webmaster3.core.util.ContextTracker;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;

/**
 * Thin wrapper around a {@link ForkJoinPool} that runs {@link Supplier} tasks with the
 * submitter's {@link ContextTracker} context installed, and reports how much wall-clock time
 * the tasks spend executing (as opposed to waiting in {@code join}/{@code get}) to Solomon.
 *
 * <p>Lifecycle: configure via the setters, then call {@link #init()} before submitting tasks.
 *
 * @author avhaliullin
 */
public class AsyncPool {
    private static final String METRIC_LABEL_NAME = "pool-cpu";

    /**
     * Per-thread start timestamp (epoch millis) of the computation interval currently being
     * measured on that thread. Absent while the thread is idle or blocked waiting for a task.
     */
    private final ThreadLocal<Long> computationsCounter = new ThreadLocal<>();
    private ForkJoinPool pool;
    private int threadsCount;
    private SolomonCounter cpuUsage;
    private SolomonGauge<Long> cpuTotal;
    private SolomonMetricRegistry solomonMetricRegistry;

    /**
     * Creates the underlying pool with {@code threadsCount} threads and registers the
     * "total" gauge and "usage" counter under the {@code async-pool-stat} indicator.
     */
    public void init() {
        pool = new ForkJoinPool(threadsCount);
        SolomonKey commonKey = SolomonKey.create(SolomonKey.LABEL_INDICATOR, "async-pool-stat");
        cpuTotal = solomonMetricRegistry.createGauge(commonKey.withLabel(METRIC_LABEL_NAME, "total"));
        // Durations are accumulated in milliseconds and reported in seconds (hence the 0.001 factor).
        cpuUsage = solomonMetricRegistry.createSimpleCounter(commonKey.withLabel(METRIC_LABEL_NAME, "usage"), 0.001d);
        // NOTE(review): presumably the thread-seconds available per metrics push interval - confirm.
        cpuTotal.set((long)(threadsCount * SolomonPushMetricsService.RESOLUTION_SECONDS));
    }

    /**
     * Submits a task to this pool.
     *
     * @param task           computation to run
     * @param contextTracker tracker whose current context is captured now and installed on the
     *                       executing thread for the duration of the task
     * @return handle for waiting on or cancelling the task
     */
    public <T> AsyncTask<T> submit(Supplier<T> task, ContextTracker contextTracker) {
        ForkJoinTask<T> fjTask = pool.submit(new TrackingForkJoinTask<>(task, contextTracker.dumpContext()));
        return new AsyncTaskImpl<>(fjTask, contextTracker);
    }

    /**
     * Forks a task via {@link ForkJoinTask#fork()}. Per the JDK contract the task is scheduled
     * in the pool the calling thread belongs to, or in the common pool when the caller is not a
     * fork/join worker - so this is meant to be called from inside a task running in this pool.
     *
     * @param task           computation to run
     * @param contextTracker tracker whose current context is captured now and installed on the
     *                       executing thread for the duration of the task
     * @return handle for waiting on or cancelling the task
     */
    public <T> AsyncTask<T> fork(Supplier<T> task, ContextTracker contextTracker) {
        ForkJoinTask<T> fjTask = new TrackingForkJoinTask<>(task, contextTracker.dumpContext());
        fjTask.fork();
        return new AsyncTaskImpl<>(fjTask, contextTracker);
    }

    /** Starts (or restarts) the measured interval for the current thread. */
    private void continueComputation() {
        computationsCounter.set(System.currentTimeMillis());
    }

    /**
     * Ends the measured interval for the current thread, if there is one, and adds its length
     * to the usage counter.
     *
     * @return {@code true} if an interval was active (and therefore should be resumed by the
     *         caller once it stops waiting), {@code false} if nothing was being measured
     */
    private boolean pauseComputation() {
        Long startedAt = computationsCounter.get();
        if (startedAt == null) {
            return false;
        }
        // Drop the marker first so it cannot be left behind if the metric update fails.
        computationsCounter.remove();
        long duration = System.currentTimeMillis() - startedAt;
        cpuUsage.add(duration);
        return true;
    }

    /**
     * Handle returned to callers. While waiting for the result it suspends the caller's time
     * accounting and clears the caller's context, because the fork/join framework may run other
     * tasks inline on the waiting thread; both are put back once the wait is over.
     */
    private class AsyncTaskImpl<T> implements AsyncTask<T> {
        private final ForkJoinTask<T> fjTask;
        private final ContextTracker contextTracker;

        public AsyncTaskImpl(ForkJoinTask<T> fjTask, ContextTracker contextTracker) {
            this.fjTask = fjTask;
            this.contextTracker = contextTracker;
        }

        @Override
        public T join() {
            boolean wasComputing = pauseComputation();
            try {
                ContextTracker.Context context = contextTracker.dumpContext();
                context.clearContext();
                try {
                    return fjTask.join();
                } finally {
                    context.restoreContext();
                }
            } finally {
                // Resume only what was actually paused: a caller that was not being measured
                // must not get a dangling start timestamp, or its next wait would report the
                // whole gap in between as pool usage.
                if (wasComputing) {
                    continueComputation();
                }
            }
        }

        @Override
        public T get(Duration timeout) throws InterruptedException, ExecutionException, TimeoutException {
            // Same bookkeeping as join(): the wait is not computation, and tasks helped inline
            // must not see or clobber the waiter's context.
            boolean wasComputing = pauseComputation();
            try {
                ContextTracker.Context context = contextTracker.dumpContext();
                context.clearContext();
                try {
                    return fjTask.get(timeout.getMillis(), TimeUnit.MILLISECONDS);
                } finally {
                    context.restoreContext();
                }
            } finally {
                if (wasComputing) {
                    continueComputation();
                }
            }
        }

        @Override
        public void cancel() {
            // ForkJoinTask documents that the mayInterruptIfRunning flag has no effect in its
            // default implementation; a task that is already running is not interrupted.
            fjTask.cancel(true);
        }
    }

    /**
     * Fork/join task that installs the captured context around the supplier call and measures
     * the time spent executing it.
     */
    private class TrackingForkJoinTask<T> extends ForkJoinTask<T> {
        private final Supplier<T> task;
        private final ContextTracker.Context context;
        private volatile T result;

        public TrackingForkJoinTask(Supplier<T> task, ContextTracker.Context context) {
            this.task = task;
            this.context = context;
        }

        @Override
        public T getRawResult() {
            return result;
        }

        @Override
        protected void setRawResult(T value) {
            this.result = value;
        }

        @Override
        protected boolean exec() {
            continueComputation();
            try {
                context.restoreContext();
                result = task.get();
                return true;
            } finally {
                // The context must be cleared even if updating the metric fails, otherwise it
                // would stay attached to the worker thread.
                try {
                    pauseComputation();
                } finally {
                    context.clearContext();
                }
            }
        }
    }

    @Required
    public void setThreadsCount(int threadsCount) {
        this.threadsCount = threadsCount;
    }

    @Required
    public void setSolomonMetricRegistry(SolomonMetricRegistry solomonMetricRegistry) {
        this.solomonMetricRegistry = solomonMetricRegistry;
    }
}
