package ru.yandex.webmaster3.worker.queue;

import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import ru.yandex.webmaster3.core.worker.task.WorkerTaskType;
import ru.yandex.webmaster3.worker.Task;
import ru.yandex.webmaster3.worker.TaskRegistry;

import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

/**
 * Chooses which {@link WorkerTaskType} a worker should process next and keeps a pause flag per type.
 *
 * <p>Selection is a weighted random draw. A registered task type is a candidate when all of these hold:
 * <ul>
 *   <li>it is not paused;</li>
 *   <li>its queue statistics report at least one enqueued task;</li>
 *   <li>its {@code Task.getPriority(statistics)} is strictly positive.</li>
 * </ul>
 * The probability of picking a candidate is proportional to that priority.
 *
 * <p>{@link #init()} creates one un-paused flag for every type present in the registry at that moment.
 * Until then, or for a type not known to {@code init()}, {@link #isPaused} reports "not paused".
 *
 * @author tsyplyaev
 */
public class TaskScheduler {

    private final TaskRegistry taskRegistry;
    private final TaskQueueMetrics taskQueueMetrics;

    private final Random random = new Random();
    // Replaced wholesale (with an unmodifiable map) by init(); empty before that.
    private Map<WorkerTaskType, AtomicBoolean> taskPauseState = new HashMap<>();

    @Autowired
    public TaskScheduler(TaskRegistry taskRegistry, TaskQueueMetrics taskQueueMetrics) {
        this.taskRegistry = taskRegistry;
        this.taskQueueMetrics = taskQueueMetrics;
    }

    /**
     * Creates a fresh, un-paused flag for every task type currently in the registry.
     * Calling it again discards any existing pause state.
     */
    public void init() {
        this.taskPauseState = Collections.unmodifiableMap(
                taskRegistry.getTaskRegistryMap().keySet().stream()
                        .collect(Collectors.toMap(type -> type, type -> new AtomicBoolean(false)))
        );
    }

    /**
     * Picks the next task type to run.
     *
     * @return a task type drawn at random with probability proportional to its priority, or {@code null}
     *         if no registered type is un-paused, has enqueued tasks and has a positive priority
     */
    WorkerTaskType pollTaskType() {
        List<Pair<WorkerTaskType, Float>> candidateMetrics = taskRegistry.getTaskRegistryMap().entrySet().stream()
                .filter(e -> !isPaused(e.getKey()))
                .map(e -> new CandidateTask(e.getKey(), e.getValue(), taskQueueMetrics.getTaskStatistics(e.getKey())))
                .filter(ut -> ut.taskStatistics.getEnqueueed() > 0)
                .map(ut -> Pair.of(ut.type, ut.task.getPriority(ut.taskStatistics)))
                // Non-positive (and NaN) priorities are never eligible.
                .filter(p -> p.getValue() > 0.0F)
                .collect(Collectors.toList());
        return pickWeighted(candidateMetrics);
    }

    /**
     * Draws one key with probability proportional to its weight.
     *
     * <p>Sums are accumulated in {@code double} to avoid mixing precisions. The last entry also serves as
     * the fallback, so rounding in the running sum can never leave the draw without a result.
     *
     * @param weighted candidates with strictly positive weights
     * @return the chosen key, or {@code null} if {@code weighted} is empty
     */
    private WorkerTaskType pickWeighted(List<Pair<WorkerTaskType, Float>> weighted) {
        if (weighted.isEmpty()) {
            return null;
        }
        int last = weighted.size() - 1;
        if (last == 0) {
            return weighted.get(0).getKey();
        }

        double total = weighted.stream().mapToDouble(Pair::getValue).sum();
        double boundary = random.nextDouble() * total;
        double accumulator = 0;
        for (int i = 0; i < last; i++) {
            Pair<WorkerTaskType, Float> pair = weighted.get(i);
            accumulator += pair.getValue();
            if (accumulator > boundary) {
                return pair.getKey();
            }
        }
        return weighted.get(last).getKey();
    }

    /**
     * Pauses a single task type.
     *
     * @throws IllegalArgumentException if {@code type} has no pause flag (unknown type or init() not called)
     */
    void pause(WorkerTaskType type) {
        pauseFlag(type).set(true);
    }

    /** Pauses every task type that has a pause flag. */
    void pause() {
        taskPauseState.values().forEach(p -> p.set(true));
    }

    /**
     * Resumes a single task type.
     *
     * @throws IllegalArgumentException if {@code type} has no pause flag (unknown type or init() not called)
     */
    void resume(WorkerTaskType type) {
        pauseFlag(type).set(false);
    }

    /** Resumes every task type that has a pause flag. */
    void resume() {
        taskPauseState.values().forEach(p -> p.set(false));
    }

    /**
     * Reports whether a task type is paused.
     *
     * @return {@code true} only if the type has a pause flag that is set. A type without a flag (before
     *         init() or not registered at init() time) is reported as not paused, matching the initial
     *         state init() assigns, instead of throwing.
     */
    public boolean isPaused(WorkerTaskType taskType) {
        AtomicBoolean flag = taskPauseState.get(taskType);
        return flag != null && flag.get();
    }

    /** Returns a snapshot of the task types whose pause flag is currently set. */
    public Set<WorkerTaskType> getPausedTasks() {
        return taskPauseState.entrySet().stream()
                .filter(e -> e.getValue().get())
                .map(Map.Entry::getKey)
                .collect(Collectors.toSet());
    }

    /** Looks up the pause flag for {@code type}, failing with a descriptive error instead of an NPE. */
    private AtomicBoolean pauseFlag(WorkerTaskType type) {
        AtomicBoolean flag = taskPauseState.get(type);
        if (flag == null) {
            throw new IllegalArgumentException(
                    "No pause state for task type " + type + " (init() not called or type not registered)");
        }
        return flag;
    }

    /** Groups a task type with its registered task and current queue statistics for candidate filtering. */
    private static class CandidateTask {
        private final WorkerTaskType type;
        private final Task task;
        private final TaskQueueMetrics.TaskStatistics taskStatistics;

        public CandidateTask(WorkerTaskType type, Task task,
                TaskQueueMetrics.TaskStatistics taskStatistics)
        {
            this.type = type;
            this.task = task;
            this.taskStatistics = taskStatistics;
        }
    }

}
