package ru.yandex.partner.core.entity.queue.service;

import java.io.IOException;
import java.util.Collection;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import ru.yandex.partner.core.entity.queue.exceptions.TaskQueueException;
import ru.yandex.partner.core.entity.queue.multistate.TaskMultistateGraph;
import ru.yandex.partner.core.entity.queue.repository.TaskRepository;
import ru.yandex.partner.core.multistate.Multistate;
import ru.yandex.partner.core.multistate.queue.TaskMultistate;
import ru.yandex.partner.core.multistate.queue.TaskStateFlag;
import ru.yandex.partner.core.queue.Task;
import ru.yandex.partner.core.queue.TaskData;
import ru.yandex.partner.core.queue.TaskExecutionResult;
import ru.yandex.partner.core.queue.TaskFactory;
import ru.yandex.partner.core.queue.TaskPayload;
import ru.yandex.partner.defaultconfiguration.PartnerLocalDateTime;
import ru.yandex.partner.libs.annotation.PartnerTransactional;

import static ru.yandex.partner.libs.multistate.MultistatePredicates.empty;
import static ru.yandex.partner.libs.multistate.MultistatePredicates.has;

/**
 * Task queue facade: persists new tasks, hands free tasks out to workers and records their
 * execution results, moving each {@link TaskData} between the multistates of
 * {@link TaskMultistateGraph}.
 */
@Service
public class TaskQueueService {

    private final TaskRepository taskRepository;
    private final TaskFacade taskFacade;
    private final TaskMultistateGraph taskMultistateGraph;
    private final ObjectMapper objectMapper;

    // Multistates resolved once at construction time from the graph (see extractSingleState).
    private final TaskMultistate stateNew;
    private final TaskMultistate stateWorking;
    // Not used by this class directly, but resolving it here still validates at startup that the
    // graph contains exactly one CANCELED state.
    private final TaskMultistate stateCanceled;
    private final TaskMultistate stateFinishedWithError;
    private final TaskMultistate stateFinishedWithSuccess;
    private final TaskMultistate stateNeedRestart;

    /**
     * Creates the service and resolves the fixed set of task multistates it works with.
     *
     * @param taskRepository      storage for task rows
     * @param taskFacade          resolves task factories and builds {@link Task} instances from {@link TaskData}
     * @param taskMultistateGraph graph used to look up the multistates matching each flag predicate
     * @param objectMapper        used to serialize task execution results
     * @throws TaskQueueException if any of the flag predicates does not match exactly one multistate
     */
    @Autowired
    public TaskQueueService(TaskRepository taskRepository, TaskFacade taskFacade,
                            TaskMultistateGraph taskMultistateGraph, ObjectMapper objectMapper) {
        this.taskRepository = taskRepository;
        this.taskFacade = taskFacade;
        this.taskMultistateGraph = taskMultistateGraph;
        this.objectMapper = objectMapper;

        // taskMultistateGraph must be assigned above before these lookups run.
        this.stateNew = extractSingleState(empty());
        this.stateWorking = extractSingleState(has(TaskStateFlag.WORKING));
        this.stateCanceled = extractSingleState(has(TaskStateFlag.CANCELED));
        this.stateFinishedWithError = extractSingleState(has(TaskStateFlag.FINISHED_WITH_ERROR));
        this.stateFinishedWithSuccess = extractSingleState(has(TaskStateFlag.FINISHED_WITH_SUCCESS));
        this.stateNeedRestart = extractSingleState(has(TaskStateFlag.NEED_RESTART));
    }

    /**
     * Persists a single new task built from {@code taskPayload} in the initial (no flags) state.
     *
     * @param taskPayload payload describing the task type, group and parameters
     */
    public void enqueue(TaskPayload taskPayload) {
        taskRepository.insertTask(toTaskData(taskPayload));
    }

    /**
     * Persists a batch of new tasks, each in the initial (no flags) state.
     * An empty collection is a no-op.
     *
     * @param taskPayloads payloads describing the tasks to enqueue
     */
    public void enqueue(Collection<TaskPayload> taskPayloads) {
        if (taskPayloads.isEmpty()) {
            // Nothing to insert: do not issue an empty batch insert to the repository.
            return;
        }
        taskRepository.insertTasks(
                taskPayloads.stream()
                        .map(this::toTaskData)
                        .collect(Collectors.toList())
        );
    }

    /**
     * Picks a free task of the given class and marks it as being worked on.
     *
     * <p>The picked task gets its try counter incremented, its start date set to now, its
     * multistate switched to WORKING, and is then locked in the repository for the task's
     * estimated execution time. The method is {@link PartnerTransactional}, so pick and lock
     * presumably commit together (NOTE(review): confirm the annotation's semantics).
     *
     * @param clazz task class to grab a task for
     * @param <T>   task type
     * @return the grabbed task, or {@link Optional#empty()} if the repository had no free task
     */
    @PartnerTransactional
    public <T extends Task<?, ?>> Optional<Task<?, ?>> grabTask(Class<T> clazz) {
        return taskRepository.pickFreeTask(clazz).map(taskData -> {
            taskData.incrementTries();
            taskData.setStartDt(PartnerLocalDateTime.now());
            taskData.setMultistate(stateWorking);

            TaskFactory<?> taskFactory = taskFacade.getFactoryForTaskClass(clazz);

            Task<?, ?> task = taskFacade.fromTaskData(taskData, taskFactory.getTaskClass());

            taskRepository.lockTask(task.getTaskData(), task.getEstimatedTime());

            return task;
        });
    }

    /**
     * Records the outcome of a task execution and persists the updated task.
     *
     * <p>The grab fields are always cleared. Then:
     * <ul>
     *     <li>on success the task moves to FINISHED_WITH_SUCCESS. The serialized result is stored,
     *     the end date is set, and error data and log are cleared;</li>
     *     <li>on failure the serialized result is stored as error data and the error message as
     *     the log. If tries are left ({@code tries < maxTries}) the task moves to NEED_RESTART,
     *     optionally delayed by the factory's {@code tryAfter}. Otherwise it moves to
     *     FINISHED_WITH_ERROR.</li>
     * </ul>
     *
     * @param task   the task that was executed
     * @param result its execution result; a {@code null} result payload is stored as {@code null}
     * @throws TaskQueueException if the result payload cannot be serialized to JSON
     */
    public void finishTask(Task<?, ?> task, TaskExecutionResult result) {
        TaskFactory<?> taskFactory = taskFacade.getFactoryForTask(task);
        TaskData taskData = task.getTaskData();

        String serializedResult = null;
        try {
            if (result.getResult() != null) {
                serializedResult = objectMapper.writeValueAsString(result.getResult());
            }
        } catch (IOException exception) {
            throw new TaskQueueException("Failed to serialize task execution result", exception);
        }

        taskData.setGrabbedAt(null);
        taskData.setGrabbedBy(null);
        taskData.setGrabbedUntil(null);

        if (result.isSuccess()) {
            taskData.setMultistate(stateFinishedWithSuccess);
            taskData.setResult(serializedResult);
            taskData.setEndDt(PartnerLocalDateTime.now());
            taskData.setErrorData(null);
            taskData.setLog(null);
        } else {
            taskData.setErrorData(serializedResult);
            taskData.setLog(result.getErrorMessage());

            if (taskData.getTries() < taskFactory.getMaxTries()) {
                taskData.setMultistate(stateNeedRestart);
                if (!taskFactory.getTryAfter().isZero()) {
                    // Re-populates grabbedUntil (cleared above) with a future moment so the task
                    // is presumably not picked again until tryAfter has elapsed.
                    taskData.setGrabbedUntil(PartnerLocalDateTime.now().plus(taskFactory.getTryAfter()));
                }
            } else {
                // NOTE(review): unlike the success branch, endDt is not set for a terminal error —
                // confirm this is intended.
                taskData.setMultistate(stateFinishedWithError);
            }
        }

        taskRepository.updateTask(taskData);
    }

    /**
     * Returns the single multistate of the graph that matches {@code predicate}.
     *
     * @param predicate flag predicate expected to match exactly one multistate
     * @return the matching multistate
     * @throws TaskQueueException if the predicate matches no multistate or more than one
     */
    private TaskMultistate extractSingleState(Predicate<Multistate<TaskStateFlag>> predicate) {
        Set<Multistate<TaskStateFlag>> taskMultistateSet = taskMultistateGraph.getMultistatesForPredicate(predicate);
        if (taskMultistateSet.size() != 1) {
            // The set may be empty as well as contain several states, so report the actual count.
            throw new TaskQueueException("Unexpected task multistate graph behaviour. Expected exactly one state "
                    + "for a predicate, got " + taskMultistateSet.size());
        }
        return (TaskMultistate) taskMultistateSet.iterator().next();
    }

    /**
     * Builds a fresh, not-yet-persisted {@link TaskData} row for {@code taskPayload}.
     * Id and add date are left {@code null}. User id is 0 and tries start at 0.
     */
    private TaskData toTaskData(TaskPayload taskPayload) {
        return TaskData.newBuilder()
                .withId(null)
                .withMultistate(stateNew)
                .withAddDt(null)
                .withStartDt(null)
                .withEndDt(null)
                .withTypeId(taskPayload.getTypeId())
                .withGroupId(taskPayload.getGroupId())
                .withParams(taskPayload.serializeParams())
                .withUserId(0L)
                .withLog(null)
                .withErrorData(null)
                .withResult(null)
                .withTries(0)
                .withGrabbedBy(null)
                .withGrabbedAt(null)
                .withGrabbedUntil(null)
                .build();
    }
}
