package ru.yandex.crypta.lab.job;

import java.net.SocketTimeoutException;
import java.time.Duration;
import java.util.Set;

import org.quartz.JobExecutionContext;
import org.quartz.JobExecutionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.bolts.collection.Cf;
import ru.yandex.crypta.clients.sandbox.SandboxClient;
import ru.yandex.crypta.clients.sandbox.entity.Id;
import ru.yandex.crypta.clients.sandbox.entity.Status;
import ru.yandex.crypta.clients.sandbox.entity.Tag;
import ru.yandex.crypta.clients.sandbox.entity.Task;
import ru.yandex.crypta.clients.sandbox.entity.TaskDraft;

/**
 * Base class for jobs that delegate their work to a Sandbox task.
 *
 * <p>On the first execution a task is created from the draft supplied by
 * {@link #taskDraft(JobExecutionContext)}, its id is stored in the job data map under
 * {@link #SANDBOX_TASK_ID}, and the task is started. Subsequent executions look the task up by
 * that id and react to its current status (see {@link #executeSafe(JobExecutionContext)}).
 */
public abstract class SandboxJob extends RetryingJob {

    /** Argument name used by {@link #getEnvironment(JobExecutionContext)}. */
    public static final String ENVIRONMENT = "environment";
    /** Job data map key under which the id of the created sandbox task is stored. */
    public static final String SANDBOX_TASK_ID = "sandbox_task_id";

    // TODO: somehow this has to stay in sync with sandbox api
    // https://a.yandex-team.ru/arc/trunk/arcadia/sandbox/common/types/task.py?rev=3790997
    private static final Set<String> SUCCESS_STATUSES =
            Cf.set("SUCCESS");
    private static final Set<String> EXECUTING_STATUSES =
            Cf.set("ENQUEUING", "ENQUEUED", "ASSIGNED", "PREPARING", "EXECUTING", "FINISHING", "STOPPING");
    private static final Set<String> RESTARTABLE_STATUSES =
            Cf.set("DRAFT", "WAIT", "EXCEPTION", "TIMEOUT", "STOPPED");
    private static final Set<String> UNRECOVERABLE_STATUSES =
            Cf.set("FAILURE", "DELETED", "NO_RES", "EXPIRED");

    protected static final Logger LOG = LoggerFactory.getLogger(SandboxJob.class);
    protected final SandboxClient sandbox;

    /**
     * @param sandbox client used for every sandbox request issued by this job
     */
    protected SandboxJob(SandboxClient sandbox) {
        this.sandbox = sandbox;
    }

    /** Tells whether the string value of {@code status} is a member of {@code set}. */
    private static boolean isOf(Status status, Set<String> set) {
        return set.contains(status.getValue());
    }

    /**
     * Tells whether a {@link SocketTimeoutException} occurs anywhere in the cause chain of
     * {@code error}.
     *
     * <p>The whole chain is inspected rather than only the direct cause, because the timeout may
     * be wrapped more than once before it surfaces as a runtime exception. The slow pointer
     * advances at half speed so that a cyclic cause chain terminates the walk instead of
     * looping forever.
     */
    private static boolean isCausedBySocketTimeout(Throwable error) {
        Throwable slow = error;
        boolean advanceSlow = false;
        for (Throwable cause = error.getCause(); cause != null; cause = cause.getCause()) {
            if (cause instanceof SocketTimeoutException) {
                return true;
            }
            if (cause == slow) {
                // cyclic cause chain: everything reachable has been inspected already
                return false;
            }
            if (advanceSlow) {
                slow = slow.getCause();
            }
            advanceSlow = !advanceSlow;
        }
        return false;
    }

    /**
     * Creates the sandbox task on the first run and checks its status on later runs.
     *
     * <ul>
     *   <li>no stored task id: create and start a task, reschedule, {@code CHECK_AGAIN_LATER};</li>
     *   <li>success status: {@code SUCCESS};</li>
     *   <li>restartable status: start the same task again, decrement the retries left,
     *       reschedule, {@code RECOVERABLE_FAILURE};</li>
     *   <li>unrecoverable status: drop the stored id, create and start a new task,
     *       {@code UNRECOVERABLE_FAILURE};</li>
     *   <li>executing or unknown status: reschedule, {@code CHECK_AGAIN_LATER};</li>
     *   <li>runtime exception with a {@link SocketTimeoutException} in its cause chain:
     *       reschedule, {@code CHECK_AGAIN_LATER}; any other runtime exception is rethrown.</li>
     * </ul>
     *
     * @param context execution context whose job data map holds the task id
     * @return the outcome of this execution as described above
     * @throws JobExecutionException as declared by the overridden method
     */
    @Override
    protected ExecuteResult executeSafe(JobExecutionContext context) throws JobExecutionException {
        try {
            var jobData = context.getJobDetail().getJobDataMap();
            if (jobData.containsKey(SANDBOX_TASK_ID)) {
                var id = new Id(jobData.getLong(SANDBOX_TASK_ID));
                var task = sandbox().get(id);
                var status = task.getStatus();
                if (isOf(status, SUCCESS_STATUSES)) {
                    LOG.info("Task {} is complete", task.getId().getValue());
                    return ExecuteResult.SUCCESS;
                } else if (isOf(status, RESTARTABLE_STATUSES)) {
                    // same task is started again; this attempt counts against the retry budget
                    sandbox().start(id);
                    decrementRetriesLeft(context);
                    reschedule(context, checkAgainInterval());
                    return ExecuteResult.RECOVERABLE_FAILURE;
                } else if (isOf(status, UNRECOVERABLE_STATUSES)) {
                    // forget about the task and start from the scratch
                    // NOTE(review): a replacement task is created and started here, yet no
                    // reschedule is requested and UNRECOVERABLE_FAILURE is returned; confirm
                    // that the base class handling of this result is what is intended.
                    jobData.remove(SANDBOX_TASK_ID);
                    var newTask = createTask(context);
                    startTask(newTask);
                    return ExecuteResult.UNRECOVERABLE_FAILURE;
                } else if (isOf(status, EXECUTING_STATUSES)) {
                    // do not decrement retries
                    LOG.info("Task {} is executing", task.getId().getValue());
                    reschedule(context, checkAgainInterval());
                    return ExecuteResult.CHECK_AGAIN_LATER;
                } else {
                    // NOTE(review): logs the Status object itself, so the output depends on
                    // Status.toString() - confirm it renders the status value.
                    LOG.error("Unhandled status: [{}]", status);
                    // reschedule to make it last forever so it is visible
                    reschedule(context, checkAgainInterval());
                    return ExecuteResult.CHECK_AGAIN_LATER;
                }
            } else {
                var task = createTask(context);
                startTask(task);
                reschedule(context, checkAgainInterval());
                return ExecuteResult.CHECK_AGAIN_LATER;
            }
        } catch (RuntimeException e) {
            if (isCausedBySocketTimeout(e)) {
                // a timed out request is treated as transient: simply look again later
                LOG.warn("Request to sandbox timed out: {}", e.getMessage());
                reschedule(context, checkAgainInterval());
                return ExecuteResult.CHECK_AGAIN_LATER;
            }
            throw e;
        }
    }

    /**
     * Delay passed to {@code reschedule} before the task is checked again.
     *
     * @return ten minutes; subclasses may override to use another interval
     */
    protected Duration checkAgainInterval() {
        return Duration.ofMinutes(10);
    }

    /** Asks sandbox to start {@code task} and logs its id. */
    private void startTask(Task task) {
        sandbox().start(task.getId());
        LOG.info("Started sandbox task {}", task.getId().getValue());
    }

    /**
     * Creates a sandbox task from the subclass-provided draft and stores its id in the job data
     * map under {@link #SANDBOX_TASK_ID}. The task is not started here.
     */
    private Task createTask(JobExecutionContext context) {
        TaskDraft taskDraft = taskDraft(context);
        Task task = sandbox().create(taskDraft);
        context.getJobDetail().getJobDataMap().put(SANDBOX_TASK_ID, task.getId().getValue());
        LOG.info("Created sandbox task {}", task.getId().getValue());
        return task;
    }

    /**
     * Builds the draft of the sandbox task this job should run.
     *
     * @param context execution context of the current run
     * @return the draft handed to the sandbox client for task creation
     */
    protected abstract TaskDraft taskDraft(JobExecutionContext context);

    /**
     * Reads the {@link #ENVIRONMENT} argument of the job via {@code getArgument}.
     *
     * @param context execution context of the current run
     * @return the value returned by {@code getArgument} for {@link #ENVIRONMENT}
     */
    protected String getEnvironment(JobExecutionContext context) {
        return getArgument(context, ENVIRONMENT);
    }

    /** @return the sandbox client given to the constructor */
    protected SandboxClient sandbox() {
        return sandbox;
    }

    /** @return a new, empty task draft */
    protected TaskDraft createTaskDraft() {
        return new TaskDraft();
    }

    /**
     * @param name tag name
     * @return a new tag with the given name
     */
    protected Tag tag(String name) {
        return new Tag(name);
    }

}
