package ru.yandex.webmaster3.core.util.concurrent.graph;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.google.common.base.Preconditions;

import ru.yandex.webmaster3.core.util.W3CollectionUtils;
import ru.yandex.webmaster3.core.util.concurrent.ExhaustableQueue;
import ru.yandex.webmaster3.core.util.functional.Functions;

/**
 * Builder for a data-processing graph. A new processing node is added through one of the {@code process()} methods.
 * {@link Process#getInput() Process.getInput()} returns the descriptor of the node's input queue. That descriptor can be
 * passed to other {@code process()} calls so that those nodes write into this node's input.
 * Finally call {@link #build(QueueDescriptor) build()}, passing the descriptor of the queue that is the input of the
 * whole graph.
 *
 * @author avhaliullin
 */
public class GraphExecutionBuilder {
    private final String baseName;
    private final IdentityHashMap<QueueDescriptor<?>, InQueueImpl<?>> inQueues = new IdentityHashMap<>();
    private final List<Process<?>> processesList = new ArrayList<>();

    public GraphExecutionBuilder(String baseName) {
        this.baseName = baseName;
    }

    /**
     * Builds the ready-to-run GraphExecution.
     * <p>
     * First, a concrete input queue (plain or sharded) is created for every registered node, sized by the number of
     * writers that feed it. Then every node is instantiated against those queues.
     *
     * @param graphIn descriptor of the input queue of the whole graph; the caller of the returned execution is counted
     *                as an additional writer of every simple queue behind it (a multiplexed descriptor expands to
     *                several queues)
     * @return the assembled GraphExecution
     * @throws IllegalStateException    if some node's input queue has no writers at all
     * @throws IllegalArgumentException if a descriptor used by a node does not belong to this builder
     */
    public <T> GraphExecution<T> build(QueueDescriptor<T> graphIn) {
        List<GraphExecution.Process> processes = new ArrayList<>();
        // Simple queues the external caller writes into directly. The caller signals doneWriting once per occurrence
        // in this list (see toCloseOnDone below), so each occurrence must be counted as one writer.
        List<Queue<T>> graphInQueues = graphIn.simpleQueues().collect(Collectors.toList());
        for (Process<?> process : processesList) {
            Queue<?> queue = process.inQueue;
            if (process.name == null) {
                process.name = "process-" + ThreadLocalRandom.current().nextInt(0, Integer.MAX_VALUE);
            }
            int writers = 0;
            // Each writer node contributes concurrency * shards writers (one per worker of each shard);
            // presumably each worker signals doneWriting once when it finishes — see doneWritingOutQueues below.
            for (Process writerProcess : queue.writers) {
                writers += writerProcess.concurrency * writerProcess.shards;
            }
            for (Queue<T> graphInQueue : graphInQueues) {
                if (graphInQueue == queue) {
                    writers++;
                }
            }
            if (writers == 0) {
                throw new IllegalStateException("QueueDescriptor " + process.name + " have no writers");
            }
            int capacity = queueCapacity(process.batchLimit, writers);
            InQueueImpl inQueue;
            if (process.shards > 1) {
                List<ExhaustableQueue<?>> shards = new ArrayList<>();
                for (int shard = 0; shard < process.shards; shard++) {
                    shards.add(new ExhaustableQueue(capacity, writers));
                }
                inQueue = new ShardedInQueueImpl(shards, process.hashF);
            } else {
                inQueue = new SimpleInQueueImpl(new ExhaustableQueue(capacity, writers));
            }
            inQueues.put(queue, inQueue);
        }
        for (Process<?> process : processesList) {
            List<InQueueImpl> toCloseOnDone = queueImpls(Arrays.asList(process.outQueues));
            Runnable doneWritingOutQueues = () -> toCloseOnDone.forEach(InQueueImpl::doneWriting);
            InQueueImpl<?> inQ = findRegisteredInQueue(process.inQueue);
            if (inQ instanceof ShardedInQueueImpl<?>) {
                // Every shard becomes an independent node with its own consumer instance and worker pool
                ShardedInQueueImpl<?> shardedInQ = (ShardedInQueueImpl) inQ;
                for (int shardId = 0; shardId < shardedInQ.shards.size(); shardId++) {
                    ExhaustableQueue<?> inQShard = shardedInQ.shards.get(shardId);
                    GraphExecution.Process processInst = new GraphExecution.Process(
                            inQShard,
                            process.instantiate(),
                            process.batchLimit,
                            process.concurrency,
                            process.name + "-s" + shardId,
                            doneWritingOutQueues,
                            process.forceFullBatch);
                    processes.add(processInst);
                }
            } else if (inQ instanceof SimpleInQueueImpl<?>) {
                SimpleInQueueImpl<?> simpleInQ = (SimpleInQueueImpl) inQ;
                GraphExecution.Process processInst = new GraphExecution.Process(
                        simpleInQ.delegate,
                        process.instantiate(),
                        process.batchLimit,
                        process.concurrency,
                        process.name,
                        doneWritingOutQueues,
                        process.forceFullBatch);
                processes.add(processInst);
            } else {
                throw new RuntimeException("Unknown InQueue implementation " + inQ.getClass() + " " + inQ);
            }
        }
        List<InQueueImpl> toCloseOnDone = queueImpls(Collections.singleton(graphIn));

        return new GraphExecution<T>(baseName, processes, getOutQueue(graphIn),
                () -> toCloseOnDone.forEach(InQueueImpl::doneWriting),
                () -> toCloseOnDone.forEach(InQueueImpl::terminate)
        );
    }

    /**
     * Capacity of a node's input queue: 1.5 batches per writer.
     * The product is computed in double so it cannot wrap around like int multiplication;
     * the double-to-int narrowing saturates at {@link Integer#MAX_VALUE}.
     */
    private static int queueCapacity(int batchLimit, int writers) {
        return (int) (1.5 * batchLimit * writers);
    }

    /**
     * Descriptor of a queue that nodes can write into: either a single {@link Queue} or a multiplexed combination.
     */
    public static abstract class QueueDescriptor<T> {
        /**
         * Registers {@code writer} as a producer for every simple queue behind this descriptor.
         */
        protected abstract void addWriter(Process writer);

        /**
         * @return every simple queue this descriptor expands to
         */
        protected abstract Stream<Queue<T>> simpleQueues();
    }

    /**
     * A single input queue of one node, tracking the nodes that write into it.
     */
    public static class Queue<T> extends QueueDescriptor<T> {
        private final List<Process> writers = new ArrayList<>();

        @Override
        protected void addWriter(Process writer) {
            // DEV_NULL is one JVM-wide instance that nothing ever reads from; recording writers on it would pin every
            // Process (and its enclosing builder) forever and mutate an unsynchronized list from unrelated builders.
            if (this != DEV_NULL) {
                writers.add(writer);
            }
        }

        @Override
        protected Stream<Queue<T>> simpleQueues() {
            return Stream.of(this);
        }
    }

    /**
     * A descriptor that fans every written item out to all of its delegates.
     */
    static class MultiplexedQueueDescriptor<T> extends QueueDescriptor<T> {
        private final List<QueueDescriptor<T>> delegates;

        public MultiplexedQueueDescriptor(List<QueueDescriptor<T>> delegates) {
            this.delegates = delegates;
        }

        @Override
        protected void addWriter(Process writer) {
            delegates.forEach(q -> q.addWriter(writer));
        }

        @Override
        protected Stream<Queue<T>> simpleQueues() {
            return delegates.stream().flatMap(QueueDescriptor::simpleQueues);
        }

    }

    /**
     * Wraps a concrete input queue as a {@link GraphOutQueue} that writers can push to.
     */
    private static <T> GraphOutQueue<T> getOutQueue(InQueueImpl<T> inQueueImpl) {
        if (inQueueImpl instanceof SimpleInQueueImpl) {
            SimpleInQueueImpl<T> simpleInQueue = (SimpleInQueueImpl<T>) inQueueImpl;
            return GraphOutQueue.fromExhaustableQueue(simpleInQueue.delegate);
        } else if (inQueueImpl instanceof ShardedInQueueImpl) {
            ShardedInQueueImpl<T> shardedInQueue = (ShardedInQueueImpl<T>) inQueueImpl;
            return new GraphOutQueue.ShardedOutQueue<>(shardedInQueue.shards, shardedInQueue.hashF);
        } else {
            throw new RuntimeException("Unexpected queue type " + inQueueImpl.getClass() + " " + inQueueImpl);
        }
    }

    /**
     * Resolves a descriptor into the {@link GraphOutQueue} a writer should use.
     * DEV_NULL maps to a discarding queue; multiplexed descriptors are resolved recursively.
     * Only valid once {@link #build(QueueDescriptor)} has registered the input queues.
     */
    private <T> GraphOutQueue<T> getOutQueue(QueueDescriptor<T> queue) {
        if (queue == DEV_NULL) {
            return GraphOutQueue.devnull();
        }
        if (queue instanceof MultiplexedQueueDescriptor) {
            MultiplexedQueueDescriptor<T> multiplexedQ = (MultiplexedQueueDescriptor<T>) queue;
            return new GraphOutQueue.MultiplexingOutQueue<>(
                    multiplexedQ.delegates.stream().map(this::getOutQueue).collect(Collectors.toList())
            );
        } else if (queue instanceof Queue) {
            InQueueImpl<T> inQueueImpl = findRegisteredInQueue((Queue) queue);
            return getOutQueue(inQueueImpl);
        } else {
            throw new RuntimeException("Unexpected queue descriptor type " + queue.getClass() + " " + queue);
        }
    }

    /**
     * Expands descriptors into the concrete input queues that must be notified when their writer finishes.
     * One entry is returned per occurrence of a simple queue, so duplicates are preserved.
     */
    private List<InQueueImpl> queueImpls(Collection<QueueDescriptor> queues) {
        return queues.stream()
                .flatMap(qd -> ((QueueDescriptor<Object>) qd).simpleQueues())
                // DEV_NULL is never backed by a registered queue, so there is nothing to close for it
                .filter(q -> q != DEV_NULL)
                .map(this::findRegisteredInQueue)
                .collect(Collectors.toList());
    }

    /**
     * @throws IllegalArgumentException if {@code queue} is not the input queue of a node of this builder
     */
    private <T> InQueueImpl<T> findRegisteredInQueue(Queue<T> queue) {
        InQueueImpl<T> result = (InQueueImpl<T>) inQueues.get(queue);
        if (result == null) {
            throw new IllegalArgumentException("QueueDescriptor doesn't belong to this builder");
        }
        return result;
    }

    private <T> Queue<T> newQueue() {
        Queue<T> q = new Queue<>();
        return q;
    }

    /**
     * A processing node. Creating an instance registers it in the enclosing builder and registers it as a writer of
     * each of its output queues.
     */
    public abstract class Process<In> {
        private final Queue<In> inQueue;
        private final QueueDescriptor[] outQueues;
        private String name;
        private int concurrency = 1;
        private int batchLimit = 1000;
        private boolean forceFullBatch = false;
        private int shards = 1;
        private ToIntFunction<In> hashF;

        protected Process(Queue<In> inQueue, QueueDescriptor... outQueues) {
            this.inQueue = inQueue;
            this.outQueues = outQueues;
            processesList.add(this);
            for (QueueDescriptor queue : outQueues) {
                queue.addWriter(this);
            }
        }

        /**
         * Sets the node name. If it is not set, {@code build()} assigns a random one.
         */
        public Process<In> name(String name) {
            this.name = name;
            return this;
        }

        /**
         * Process data not as it arrives, but wait until a buffer of {@code batchLimit} items has accumulated
         * or the data has run out.
         */
        public Process<In> forceFullBatch() {
            this.forceFullBatch = true;
            return this;
        }

        /**
         * Sets the number of execution threads for this node (per shard if the node is sharded).
         *
         * @throws IllegalArgumentException if {@code concurrency} is not positive
         */
        public Process<In> concurrency(int concurrency) {
            Preconditions.checkArgument(concurrency > 0, "Concurrency should be >0, found: " + concurrency);
            this.concurrency = concurrency;
            return this;
        }

        /**
         * Splits this node's input into {@code shards} separate queues, each processed by its own node instance
         * named {@code name + "-s" + shardId}. With {@code shards == 1} the node stays unsharded.
         *
         * @param hashF passed to the sharded out-queue; presumably used to route each element to a shard
         */
        public Process<In> sharded(int shards, ToIntFunction<In> hashF) {
            Preconditions.checkArgument(shards > 0, "Shards count should be >0, found: " + shards);
            Preconditions.checkNotNull(hashF, "No hash function specified");
            this.shards = shards;
            this.hashF = hashF;
            return this;
        }

        /**
         * Sets the maximum size of a data batch passed to the node at once. It also sizes the node's input queue.
         *
         * @throws IllegalArgumentException if {@code batchLimit} is not positive
         */
        public Process<In> batchLimit(int batchLimit) {
            Preconditions.checkArgument(batchLimit > 0, "Batch limit should be >0, found: " + batchLimit);
            this.batchLimit = batchLimit;
            return this;
        }

        /**
         * @return descriptor of this node's input queue
         */
        public Queue<In> getInput() {
            return inQueue;
        }

        /**
         * Creates a fresh consumer for this node. It is called once per shard during {@code build()}.
         */
        protected abstract BlockingBatchConsumer<In> instantiate();
    }

    /**
     * Adds a node with no output queues.
     */
    public <In> Process<In> process(Supplier<BlockingBatchConsumer<In>> factory) {
        Queue<In> inQueue = newQueue();
        return new Process<In>(inQueue) {
            @Override
            protected BlockingBatchConsumer<In> instantiate() {
                return factory.get();
            }
        };
    }

    /**
     * Adds a node that writes into one output queue.
     */
    public <In, Q1> Process<In> process(QueueDescriptor<Q1> q1, Functions.F1<GraphOutQueue<Q1>, BlockingBatchConsumer<In>> factory) {
        Queue<In> inQueue = newQueue();
        return new Process<In>(inQueue, q1) {
            @Override
            protected BlockingBatchConsumer<In> instantiate() {
                return factory.apply(getOutQueue(q1));
            }
        };
    }

    /**
     * Adds a node that writes into two output queues.
     */
    public <In, Q1, Q2> Process<In> process(QueueDescriptor<Q1> q1, QueueDescriptor<Q2> q2, Functions.F2<GraphOutQueue<Q1>, GraphOutQueue<Q2>, BlockingBatchConsumer<In>> factory) {
        Queue<In> inQueue = newQueue();
        return new Process<In>(inQueue, q1, q2) {
            @Override
            protected BlockingBatchConsumer<In> instantiate() {
                return factory.apply(getOutQueue(q1), getOutQueue(q2));
            }
        };
    }

    /**
     * Adds a node that writes into three output queues.
     */
    public <In, Q1, Q2, Q3> Process<In> process(QueueDescriptor<Q1> q1, QueueDescriptor<Q2> q2, QueueDescriptor<Q3> q3, Functions.F3<GraphOutQueue<Q1>, GraphOutQueue<Q2>, GraphOutQueue<Q3>, BlockingBatchConsumer<In>> factory) {
        Queue<In> inQueue = newQueue();
        return new Process<In>(inQueue, q1, q2, q3) {
            @Override
            protected BlockingBatchConsumer<In> instantiate() {
                return factory.apply(getOutQueue(q1), getOutQueue(q2), getOutQueue(q3));
            }
        };
    }

    /**
     * Adds a node that writes into four output queues.
     */
    public <In, Q1, Q2, Q3, Q4> Process<In> process(QueueDescriptor<Q1> q1, QueueDescriptor<Q2> q2, QueueDescriptor<Q3> q3, QueueDescriptor<Q4> q4, Functions.F4<GraphOutQueue<Q1>, GraphOutQueue<Q2>, GraphOutQueue<Q3>, GraphOutQueue<Q4>, BlockingBatchConsumer<In>> factory) {
        Queue<In> inQueue = newQueue();
        return new Process<In>(inQueue, q1, q2, q3, q4) {
            @Override
            protected BlockingBatchConsumer<In> instantiate() {
                return factory.apply(getOutQueue(q1), getOutQueue(q2), getOutQueue(q3), getOutQueue(q4));
            }
        };
    }

    /**
     * Combines several descriptors into one; writing to the result writes to every one of them.
     */
    @SafeVarargs
    public static <T> QueueDescriptor<T> multiplex(QueueDescriptor<T> q1, QueueDescriptor<T>... qs) {
        return new MultiplexedQueueDescriptor<>(W3CollectionUtils.varargToList(q1, qs));
    }

    public static GraphExecutionBuilder newBuilder(String name) {
        return new GraphExecutionBuilder(name);
    }

    /**
     * Shared output descriptor whose writes are discarded. It is never backed by a real queue.
     */
    public static final Queue<Object> DEV_NULL = new Queue<>();

    public static <T> Queue<T> devnullQueue() {
        return (Queue<T>) DEV_NULL;
    }

    /**
     * A concrete node input: one or more ExhaustableQueues that lifecycle signals are broadcast to.
     */
    private static abstract class InQueueImpl<In> {
        void doneWriting() {
            foreachDelegate(ExhaustableQueue::doneWriting);
        }

        void terminate() {
            foreachDelegate(ExhaustableQueue::terminate);
        }

        protected abstract void foreachDelegate(Consumer<ExhaustableQueue<In>> consumer);
    }

    private static class SimpleInQueueImpl<In> extends InQueueImpl<In> {
        private final ExhaustableQueue<In> delegate;

        public SimpleInQueueImpl(ExhaustableQueue<In> delegate) {
            this.delegate = delegate;
        }

        @Override
        protected void foreachDelegate(Consumer<ExhaustableQueue<In>> consumer) {
            consumer.accept(delegate);
        }
    }

    private static class ShardedInQueueImpl<In> extends InQueueImpl<In> {
        private final List<ExhaustableQueue<In>> shards;
        private final ToIntFunction<In> hashF;

        public ShardedInQueueImpl(List<ExhaustableQueue<In>> shards, ToIntFunction<In> hashF) {
            this.shards = shards;
            this.hashF = hashF;
        }

        @Override
        protected void foreachDelegate(Consumer<ExhaustableQueue<In>> consumer) {
            shards.forEach(consumer);
        }
    }
}
