package ru.yandex.concurrent;

import java.util.Iterator;
import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.function.BinaryOperator;
import java.util.function.Consumer;
import java.util.function.Supplier;

/**
 * Accumulates values coming from many threads by giving every thread its own
 * private partial result, so that {@link #accept(Object)} never contends with
 * other writers.
 *
 * <p>{@link #accept(Object)} folds a value into the calling thread's partial
 * result. {@link #get()} combines the partial results of every thread that has
 * ever called {@code accept}, in the order in which those threads first did so.
 * If no thread has called {@code accept} yet, it returns {@code zero}.
 *
 * <p>Each thread's partial result starts from {@code zero}. For {@code get()}
 * to be independent of how values were spread across threads, {@code zero}
 * should therefore be an identity of {@code accumulatorFunction}, and the
 * function should be associative and commutative.
 *
 * <p>{@code get()} is not atomic with respect to concurrent {@code accept}
 * calls: it reads each partial result once and may or may not observe values
 * accepted while it is running.
 *
 * <p>A partial-result node is kept for every thread that ever called
 * {@code accept}, for the whole lifetime of this accumulator.
 *
 * @param <T> the type of the accumulated values
 */
public class ConcurrentAccumulator<T> implements Consumer<T>, Supplier<T> {
    /** Per-thread partial results, in the order each thread first accepted a value. */
    private final Queue<Node<T>> nodes = new ConcurrentLinkedQueue<>();
    private final BinaryOperator<T> accumulatorFunction;
    private final T zero;
    /** Lazily creates the calling thread's node and registers it in {@link #nodes}. */
    private final ThreadLocalNode<T> node;

    /**
     * Creates an empty accumulator.
     *
     * @param accumulatorFunction combines an accumulated value with a new value;
     *     must not be {@code null}
     * @param zero the initial value of every thread's partial result, and the
     *     value returned by {@link #get()} when nothing has been accepted
     * @throws NullPointerException if {@code accumulatorFunction} is {@code null}
     */
    public ConcurrentAccumulator(
        final BinaryOperator<T> accumulatorFunction,
        final T zero)
    {
        // Fail fast here rather than on the first accept()/get() far from the bug.
        this.accumulatorFunction =
            Objects.requireNonNull(accumulatorFunction, "accumulatorFunction");
        this.zero = zero;
        node = new ThreadLocalNode<>(nodes, zero);
    }

    /**
     * Folds {@code value} into the calling thread's partial result.
     *
     * @param value the value to accumulate
     */
    @Override
    public void accept(final T value) {
        Node<T> node = this.node.get();
        // Only the owning thread ever writes its node, so this read-modify-write
        // needs no lock. The volatile write publishes the result to get().
        node.value = accumulatorFunction.apply(node.value, value);
    }

    /**
     * Combines the current partial results of all threads.
     *
     * @return the combined value, or {@code zero} if no thread has accepted
     *     anything yet
     */
    @Override
    public T get() {
        Iterator<Node<T>> iter = nodes.iterator();
        if (iter.hasNext()) {
            // Seed with the first node rather than zero: every node already
            // started from zero, so zero must not be applied one extra time.
            T result = iter.next().value;
            while (iter.hasNext()) {
                result = accumulatorFunction.apply(result, iter.next().value);
            }
            return result;
        } else {
            return zero;
        }
    }

    /** One thread's partial result. It is written by its owner thread only. */
    private static class Node<T> {
        // volatile so that get() running on other threads sees the latest value.
        private volatile T value;

        Node(final T value) {
            this.value = value;
        }
    }

    /**
     * Gives each thread its own {@link Node}, created on the thread's first
     * access and added to the shared queue at that time.
     */
    private static class ThreadLocalNode<T> extends ThreadLocal<Node<T>> {
        private final Queue<Node<T>> nodes;
        private final T zero;

        ThreadLocalNode(final Queue<Node<T>> nodes, final T zero) {
            this.nodes = nodes;
            this.zero = zero;
        }

        @Override
        protected Node<T> initialValue() {
            Node<T> result = new Node<>(zero);
            nodes.add(result);
            return result;
        }
    }
}

