package ru.yandex.webmaster3.core.util.concurrent.graph;

import java.io.Closeable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.webmaster3.core.util.concurrent.ExhaustableQueue;

/**
 * Граф обработки данных. Процессы связаны блокирующими конечными очередями.
 * Жизненный цикл выглядит следующим образом:
 * <ol>
 * <li>{@link #start()} - поднимаются все потоки, граф готов к обработке данных</li>
 * <li>Обработка - на этом этапе можно вызывать {@link #put(In)}, посылая на обработку данные</li>
 * <li>{@link #doneWritingAndAwaitTermination()} - входная очередь помечается как "исчерпанная", метод блокируется,
 * пока все потоки не закончат разгребать свои очереди.</li>
 * </ol>
 *
 * {@link #terminateAbruptly()} позволяет досрочно грохнуть процесс.
 *
 * @author avhaliullin
 */
/*
 * TODO: awaitTermination слабо отражает семантику метода. Возможно, стоит переименовать, или даже разбить на два метода
 * TODO: Было бы удобно исполнять это все внутри try-with-resources, хоть может быть и странно ждать заверешения внутри close()
 * TODO: Поработать над диагностикой, чтобы можно было выявлять узкие места процессов
 */
public class GraphExecution<In> implements Closeable {
    private static final Logger log = LoggerFactory.getLogger(GraphExecution.class);

    private final String baseName;
    private final GraphOutQueue<In> inQueue;
    private final List<Thread> threads;
    private final List<Process> processes;
    private final Runnable onDoneWriting;
    private final Runnable onTerminate;

    private AtomicBoolean start = new AtomicBoolean(false);
    private AtomicBoolean finished = new AtomicBoolean(false);
    private AtomicReference<Exception> exception = new AtomicReference<>(null);

    GraphExecution(String baseName, List<Process> processes, GraphOutQueue<In> inQueue, Runnable onDoneWriting, Runnable onTerminate) {
        this.processes = new ArrayList<>(processes);
        this.baseName = baseName;
        this.inQueue = inQueue;
        this.onDoneWriting = onDoneWriting;
        this.onTerminate = onTerminate;
        threads = new ArrayList<>();
        for (Process process : processes) {
            if (process.concurrency <= 1) {
                threads.add(makeThread(process, null));
            } else {
                for (int i = 0; i < process.concurrency; i++) {
                    threads.add(makeThread(process, i));
                }
            }
        }
    }

    public void put(In e) throws InterruptedException {
        inQueue.put(e);
    }

    static class Process {
        private final ExhaustableQueue queue;
        private final BlockingBatchConsumer consumer;
        private final int batchLimit;
        private final int concurrency;
        private final String name;
        private final Runnable doneWritingOutQueues;
        private final boolean forceFullBatch;

        public Process(ExhaustableQueue<?> queue, BlockingBatchConsumer<?> consumer, int batchLimit,
                       int concurrency, String name, Runnable doneWritingOutQueues, boolean forceFullBatch) {
            this.queue = queue;
            this.consumer = consumer;
            this.batchLimit = batchLimit;
            this.concurrency = concurrency;
            this.name = name;
            this.doneWritingOutQueues = doneWritingOutQueues;
            this.forceFullBatch = forceFullBatch;
        }
    }

    public boolean isRunning() {
        return start.get() && !finished.get();
    }

    public void start() {
        if (start.compareAndSet(false, true)) {
            for (Thread t : threads) {
                t.start();
            }
        }
        log.info("Started execution {}", baseName);
    }

    public void doneWritingAndAwaitTermination() throws InterruptedException, ExecutionException {
        onDoneWriting.run();
        for (Thread thread : threads) {
            thread.join();
        }
        log.info("Execution finished");
        finished.set(true);
        Exception e = exception.get();
        if (e != null) {
            throw new ExecutionException(e);
        }
    }

    public void terminateAbruptly() {
        if (finished.compareAndSet(false, true)) {
            for (Thread thread : threads) {
                thread.interrupt();
            }
            onTerminate.run();
            for (Process process : processes) {
                process.queue.terminate();
            }
        }
    }

    private void onFail(Exception e) {
        log.error("Reported graph execution fail", e);
        exception.compareAndSet(null, e);
        terminateAbruptly();
    }

    @Override
    public void close() {
        terminateAbruptly();
    }

    private Thread makeThread(Process process, Integer index) {
        Thread result = new Thread(new GraphThread(process));
        result.setDaemon(true);
        result.setName(baseName + "-" + process.name + (index == null ? "" : "-" + index));
        return result;
    }

    private class GraphThread implements Runnable {
        private final Process process;
        private final List buffer;
        private long processed = 0L;

        public GraphThread(Process process) {
            this.process = process;
            this.buffer = new ArrayList<>(process.batchLimit);
        }

        @Override
        public void run() {
            int queueCapacity = process.queue.capacity();
            try {
                while (!process.queue.isExhausted()) {
                    if (Thread.interrupted()) {
                        throw new InterruptedException();
                    }
                    do {
                        process.queue.pullBatch(buffer::add, process.batchLimit - buffer.size());
                    } while (process.forceFullBatch &&
                            buffer.size() < process.batchLimit &&
                            !process.queue.isExhausted());
                    processed += buffer.size();
                    int queueSize = process.queue.size();
                    log.trace("Pulled {} records, totally processed {}, queue size {}/{}", buffer.size(), processed, queueSize, queueCapacity);
                    if (!buffer.isEmpty()) {
                        process.consumer.accept(buffer);
                    }
                    buffer.clear();
                }
                process.doneWritingOutQueues.run();
            } catch (Exception e) {
                onFail(e);
            }
        }
    }
}
