package ru.yandex.travel.hotels.searcher;

import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * A bounded queue of {@link Task}s that groups tasks by {@link Task.GroupingKey} and hands them out
 * in batches, where every task in a batch shares the same grouping key.
 *
 * <p>All mutating operations synchronize on the internal map, so an instance may be shared between
 * threads. {@link #getSize()} and {@link #isEmpty()} read the counter without taking the lock, so the
 * value they return may already be stale by the time the caller acts on it.
 */
public class BatchedTaskQueue {
    /** A group of tasks that were removed from the queue together under a single grouping key. */
    public static class Batch {
        private final Task.GroupingKey key;
        private final List<Task> tasks;

        /**
         * @param key   the grouping key shared by the tasks
         * @param tasks the tasks in this batch; stored as-is, not copied
         */
        public Batch(Task.GroupingKey key, List<Task> tasks) {
            this.key = key;
            this.tasks = tasks;
        }

        /** Returns the grouping key this batch was built for. */
        public Task.GroupingKey getKey() {
            return key;
        }

        /** Returns the tasks of this batch (the list passed to the constructor, not a copy). */
        public List<Task> getTasks() {
            return tasks;
        }
    }

    // Pending tasks per grouping key, in offer order. Guarded by its own monitor.
    // Invariant: no list stored here is ever empty (getBatch removes a key when it takes all its tasks).
    private final HashMap<Task.GroupingKey, LinkedList<Task>> map;
    // Total number of tasks across all lists in `map`. Only written while holding the map lock;
    // AtomicInteger is used so getSize()/isEmpty() can read it without locking.
    private final AtomicInteger size;
    private final int maxSize;

    /** Creates a queue whose capacity is {@link Integer#MAX_VALUE} tasks. */
    public BatchedTaskQueue() {
        this(Integer.MAX_VALUE);
    }

    /**
     * Creates a queue that holds at most {@code maxSize} tasks.
     *
     * @param maxSize capacity; once this many tasks are queued, {@link #offer(Task)} rejects new ones
     */
    public BatchedTaskQueue(int maxSize) {
        map = new HashMap<>();
        size = new AtomicInteger(0);
        this.maxSize = maxSize;
    }

    /**
     * Appends {@code task} to the list for its grouping key, unless the queue is full.
     *
     * @param task the task to enqueue
     * @return {@code true} if the task was added, {@code false} if the queue already held
     *     {@code maxSize} tasks
     */
    public boolean offer(Task task) {
        synchronized (map) {
            if (size.get() >= this.maxSize) {
                return false;
            }
            map.computeIfAbsent(task.getGroupingKey(), key -> new LinkedList<>()).add(task);
            size.incrementAndGet();
        }
        return true;
    }

    /**
     * Removes and returns up to {@code batchSize} tasks that share one grouping key.
     *
     * <p>The key chosen is the one whose first queued task has the smallest
     * {@code getCreatedAtNanos()} value; ties are broken by map iteration order. If that key has at
     * most {@code batchSize} pending tasks, all of them are returned and the key is removed.
     * Otherwise the first {@code batchSize} tasks (in offer order) are returned and the rest stay
     * queued.
     *
     * @param batchSize maximum number of tasks to return; must not be negative
     * @return the batch, or {@code null} if the queue is empty
     * @throws IllegalArgumentException if {@code batchSize} is negative
     */
    public Batch getBatch(int batchSize) {
        // Validate up front: previously a negative size only failed (inside ArrayList) when the
        // queue happened to be non-empty.
        if (batchSize < 0) {
            throw new IllegalArgumentException("batchSize must not be negative: " + batchSize);
        }
        synchronized (map) {
            if (size.get() == 0) {
                return null;
            }
            // peek() is safe: lists in the map are never empty (see field comment).
            // NOTE(review): if getCreatedAtNanos() comes from System.nanoTime(), its Javadoc
            // recommends comparing by subtraction rather than natural long order — confirm.
            Map.Entry<Task.GroupingKey, LinkedList<Task>> oldestEntry = map.entrySet().stream()
                    .min(Comparator.comparingLong(entry -> entry.getValue().peek().getCreatedAtNanos()))
                    .orElseThrow(() -> new IllegalStateException(
                            "size is " + size.get() + " but no grouping key is queued"));
            LinkedList<Task> pending = oldestEntry.getValue();
            if (pending.size() <= batchSize) {
                // Hand over the whole list; it is no longer reachable from the map.
                map.remove(oldestEntry.getKey());
                size.addAndGet(-pending.size());
                return new Batch(oldestEntry.getKey(), pending);
            }
            List<Task> result = new ArrayList<>(batchSize);
            for (int i = 0; i < batchSize; i++) {
                result.add(pending.poll());
            }
            size.addAndGet(-batchSize);
            return new Batch(oldestEntry.getKey(), result);
        }
    }

    /**
     * Removes and returns every queued task, leaving the queue empty.
     *
     * @return all tasks that were queued, grouped by key (key order follows map iteration order)
     */
    public List<Task> drain() {
        synchronized (map) {
            // Presize inside the lock so the capacity matches what is actually drained.
            List<Task> result = new ArrayList<>(size.get());
            for (LinkedList<Task> queue : map.values()) {
                result.addAll(queue);
            }
            map.clear();
            size.set(0);
            return result;
        }
    }

    /** Returns the number of queued tasks, read without locking. */
    public int getSize() {
        return size.get();
    }

    /** Returns whether no tasks are queued, read without locking. */
    public boolean isEmpty() {
        return size.get() == 0;
    }
}
