package ru.yandex.mail.cerberus.worker.executer;

import io.micronaut.context.annotation.Value;
import io.micronaut.context.event.ApplicationEventListener;
import io.micronaut.context.event.StartupEvent;
import reactor.core.scheduler.Schedulers;
import lombok.AllArgsConstructor;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import ru.yandex.mail.micronaut.common.Async;
import ru.yandex.mail.cerberus.IdempotencyKey;
import ru.yandex.mail.micronaut.common.context.ContextManager;
import ru.yandex.mail.micronaut.common.JsonMapper;
import ru.yandex.mail.cerberus.dao.tx.TxManager;
import ru.yandex.mail.micronaut.common.qualifier.Master;
import ru.yandex.mail.cerberus.worker.api.TaskProcessor;
import ru.yandex.mail.cerberus.worker.WorkerConfiguration;
import ru.yandex.mail.cerberus.worker.api.TaskExecutionContext;
import ru.yandex.mail.cerberus.worker.api.exception.TaskTimeoutException;
import ru.yandex.mail.cerberus.dao.task.TaskRepository;
import ru.yandex.mail.cerberus.dao.task.TaskStatus;
import ru.yandex.mail.cerberus.worker.TaskRegistry;
import ru.yandex.mail.cerberus.worker.exception.TaskProcessorNotFoundException;

import javax.annotation.PreDestroy;
import javax.inject.Inject;
import javax.inject.Named;
import javax.inject.Singleton;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static com.ea.async.Async.await;
import static ru.yandex.mail.micronaut.common.Async.done;

/**
 * {@link TaskExecutionContext} handed to a task processor for one task execution.
 *
 * <p>Each call to {@link #setInterruptionPoint()} aborts the task with {@link TaskTimeoutException} once
 * {@code timeout} has elapsed since {@code startTime}. It aborts with {@link InterruptedException} once the
 * worker executor has been shut down.
 */
class WorkerTaskExecutionContext implements TaskExecutionContext {
    private final Instant startTime;
    private final Duration timeout;
    private final ExecutorService workerExecutor;

    /**
     * @param startTime      moment the task execution started
     * @param timeout        maximum allowed execution time, measured from {@code startTime}
     * @param workerExecutor executor the task runs on; also exposed through {@link #getExecutor()}
     */
    public WorkerTaskExecutionContext(Instant startTime, Duration timeout, ExecutorService workerExecutor) {
        this.startTime = startTime;
        this.timeout = timeout;
        this.workerExecutor = workerExecutor;
    }

    @Override
    @SneakyThrows
    public void setInterruptionPoint() {
        Instant deadline = startTime.plus(timeout);
        if (Instant.now().isAfter(deadline)) {
            throw new TaskTimeoutException();
        }

        if (workerExecutor.isShutdown()) {
            // This method declares no checked exceptions; @SneakyThrows lets InterruptedException propagate anyway.
            throw new InterruptedException();
        }
    }

    @Override
    public ExecutorService getExecutor() {
        return workerExecutor;
    }
}

/**
 * Background task worker.
 *
 * <p>On application startup it starts a self-rescheduling "heartbeat" loop on the worker executor. Each
 * heartbeat acquires at most as many tasks as there are free execution slots for this host. It runs each one
 * through its registered {@link TaskProcessor} and records the outcome (success, failure or timeout). When
 * finishing succeeds, tasks configured with a repetition rate are scheduled again.
 */
@Slf4j
@Singleton
public class Worker implements ApplicationEventListener<StartupEvent> {
    private static final String TASK_TYPE_MDC_KEY = "taskType";
    private static final String TASK_KEY_MDC_KEY = "taskKey";
    static final String WORKER_EXECUTOR_NAME = "cerberus-worker-executor";

    private final TxManager txManager;
    private final TaskRepository taskRepository;
    private final TaskRegistry taskRegistry;
    private final WorkerConfiguration configuration;
    private final JsonMapper jsonMapper;
    private final ScheduledExecutorService workerExecutor;
    private final String hostname;
    // Tasks submitted to the executor whose execution has not finished yet; used to compute free slots.
    private final AtomicInteger executingTasksCount;

    /**
     * Renders the full stack trace of {@code exception} as a string, to be stored as the task error.
     */
    @SneakyThrows
    private static String getError(Throwable exception) {
        try (val strWriter = new StringWriter(100);
             val writer = new PrintWriter(strWriter)) {
            exception.printStackTrace(writer);
            return strWriter.toString();
        }
    }

    /**
     * Creates the execution context for a task starting now, bounded by {@code timeout}.
     */
    private TaskExecutionContext newTaskContext(Duration timeout) {
        return new WorkerTaskExecutionContext(Instant.now(), timeout, workerExecutor);
    }

    /**
     * Inserts the next run of the task if its type is configured with a repetition rate.
     * The new run gets a random idempotency key and carries {@code newContext}.
     */
    private <Context> void tryReschedule(TaskRepository.GenericTaskInfo taskInfo, Optional<Context> newContext) {
        val type = taskInfo.getType();

        taskRegistry.findTaskConfiguration(type).ifPresent(config -> {
            config.getRepetitionRate().ifPresent(repetitionRate -> {
                val newTask = taskRepository.insertTask(IdempotencyKey.random(), type, repetitionRate, config.getTimeout(),
                    taskInfo.getRequestId(), taskInfo.getInitiatorUid(), Optional.empty(), newContext);

                log.info("Next execution scheduled for '{}' task with idempotency key '{}' at {}",
                    type, newTask.getIdempotencyKey(), newTask.getSchedule());
            });
        });
    }

    /**
     * Records the task outcome in a transaction and, if the task was still present, tries to reschedule it.
     * Errors are logged, not propagated: the returned future completes normally even if finishing failed.
     */
    private <Context> CompletableFuture<Void> finishTask(TaskRepository.GenericTaskInfo taskInfo, TaskStatus status,
                                                         Optional<String> error, Optional<Context> newContext) {
        val key = taskInfo.getIdempotencyKey();

        return txManager.runAsync(() -> {
            try {
                if (!taskRepository.finishTask(key, status, error)) {
                    log.error("Finishing {} task failed: task not found (possibly finished by the recycler)", key);
                } else {
                    log.info("Finish task {} with status {}", key, status);
                    tryReschedule(taskInfo, newContext);
                }
            } catch (Exception e) {
                log.error("Finishing task failed", e);
            }
        });
    }

    /**
     * Marks the task as {@link TaskStatus#SUCCESS}, passing the processor-produced context to the next run.
     */
    private <Context> CompletableFuture<Void> finishTask(TaskRepository.GenericTaskInfo taskInfo, Optional<Context> newContext) {
        return finishTask(taskInfo, TaskStatus.SUCCESS, Optional.empty(), newContext);
    }

    /**
     * Marks the task as {@link TaskStatus#TIMEOUT} (for {@link TaskTimeoutException}) or {@link TaskStatus#FAILED},
     * storing the stack trace as the error.
     *
     * <p>NOTE(review): on this path the context is taken from {@code taskInfo.getContext()} (the stored form, which
     * {@code executeTaskImpl} deserializes with {@code jsonMapper.fromJson}). The success path instead passes the
     * processor's deserialized context. Confirm that {@code insertTask} handles both forms.
     */
    private CompletableFuture<Void> finishTask(TaskRepository.GenericTaskInfo taskInfo, Throwable exception) {
        val status = exception instanceof TaskTimeoutException ? TaskStatus.TIMEOUT : TaskStatus.FAILED;
        return finishTask(taskInfo, status, Optional.of(getError(exception)), taskInfo.getContext());
    }

    /**
     * Runs one acquired task with the request id and task MDC values set, then records its outcome.
     *
     * <p>The execution slot for this task is reserved by {@link #heartbeat()} at submission time and released
     * here, in {@code finally}.
     *
     * @return the future of the transaction that records the task outcome
     */
    private CompletableFuture<Void> executeTask(TaskRepository.GenericTaskInfo taskInfo) {
        try {
            ContextManager.setRequestId(taskInfo.getRequestId());
            ContextManager.setMdcValues(Map.of(
                TASK_TYPE_MDC_KEY, taskInfo.getType().toString(),
                TASK_KEY_MDC_KEY, taskInfo.getIdempotencyKey().toString()
            ));

            val newContext = await(executeTaskImpl(taskInfo));
            return finishTask(taskInfo, newContext);
        } catch (Throwable e) {
            e = Async.unwrap(e);
            log.error("Task processing error", e);
            return finishTask(taskInfo, e);
        } finally {
            // Release the slot reserved in heartbeat() when this task was submitted.
            executingTasksCount.decrementAndGet();
            ContextManager.resetRequestId();
            ContextManager.resetMdcValues(TASK_TYPE_MDC_KEY, TASK_KEY_MDC_KEY);
        }
    }

    /**
     * Runs the processor on the worker executor and adapts its (possibly empty) result into a future of
     * {@code Optional}.
     */
    @SuppressWarnings("unchecked")
    private <Context> CompletableFuture<Optional<Context>> executeProcess(TaskProcessor<Context> processor,
                                                                          @SuppressWarnings("rawtypes") Optional context,
                                                                          TaskExecutionContext executionContext) {
        return processor.process(context, executionContext)
            .subscribeOn(Schedulers.fromExecutor(workerExecutor))
            .map(Optional::of)
            .defaultIfEmpty(Optional.empty())
            .toFuture();
    }

    /**
     * Resolves the processor for the task type, deserializes the stored context and executes the processor,
     * timing the run with the task's meter.
     *
     * @throws TaskProcessorNotFoundException if no processor is registered for the task type
     */
    private CompletableFuture<Optional<?>> executeTaskImpl(TaskRepository.GenericTaskInfo taskInfo) {
        val type = taskInfo.getType();

        val taskRecord = taskRegistry.findTaskRecord(type)
            .orElseThrow(() -> new TaskProcessorNotFoundException(type));

        val processor = (TaskProcessor<?>) taskRecord.getProcessor();
        val context =  (Optional<?>) taskInfo.getContext()
            .map(value -> jsonMapper.fromJson(value, processor.contextType()));

        log.info("Start task execution");
        val sample = taskRecord.getMeter().start();

        try {
            val timeout = taskRecord.getConfiguration().getTimeout();
            val newContext = await(executeProcess(processor, context, newTaskContext(timeout)));
            log.info("Task successfully complete");
            return done(newContext);
        } finally {
            sample.stop();
        }
    }

    /**
     * One iteration of the worker loop: acquires as many tasks as there are free slots and submits them.
     *
     * <p>This method always schedules the next heartbeat, including when it fails. It runs as a task on a
     * {@link ScheduledExecutorService}, so an escaping exception would be swallowed and the loop would stop.
     */
    private void heartbeat() {
        try {
            val tasksCount = executingTasksCount.get();
            val freeExecutionSlotsCount = configuration.getSimultaneousTasksCount() - tasksCount;
            log.info("Heartbeat. Tasks in process: {}, free execution slots: {}",
                tasksCount, freeExecutionSlotsCount);

            if (freeExecutionSlotsCount > 0) {
                log.info("Check available tasks");
                txManager.executeAsync(() -> {
                    return taskRepository.acquireTasks(freeExecutionSlotsCount, hostname);
                })
                .whenComplete((tasksInfo, e) -> {
                    try {
                        if (e != null) {
                            log.error("Failed to acquire new tasks", e);
                        } else {
                            if (tasksInfo.isEmpty()) {
                                log.info("No tasks ready for execution");
                            }

                            tasksInfo.forEach(info -> {
                                // Reserve the slot before submitting. The task may wait in the executor queue,
                                // and the next heartbeat must already count it as occupied.
                                executingTasksCount.incrementAndGet();
                                try {
                                    workerExecutor.submit(() -> executeTask(info));
                                } catch (Exception ex) {
                                    executingTasksCount.decrementAndGet();
                                    log.error("Can't submit {} task with key {}", info.getType(), info.getIdempotencyKey(), ex);
                                }
                            });
                        }
                    } finally {
                        scheduleNextHeartbeat();
                    }
                });
                // The completion callback above is responsible for scheduling the next heartbeat.
                return;
            }

            log.info("Insufficient execution slots to execute a task");
        } catch (Exception failure) {
            log.error("Heartbeat failed", failure);
        }
        scheduleNextHeartbeat();
    }

    /**
     * Schedules the next {@link #heartbeat()} after the configured worker rate. Does nothing once the
     * executor has been shut down.
     */
    private void scheduleNextHeartbeat() {
        if (workerExecutor.isShutdown()) {
            log.info("Worker executor is shut down, heartbeat loop stopped");
            return;
        }

        try {
            val rate = configuration.getWorkerRate();
            workerExecutor.schedule(this::heartbeat, rate.toMillis(), TimeUnit.MILLISECONDS);
        } catch (Exception e) {
            log.error("Can't schedule next heartbeat", e);
        }
    }

    /**
     * @param workerExecutor must actually be a {@link ScheduledExecutorService}; it is cast unconditionally
     */
    @Inject
    public Worker(@Master TxManager txManager, TaskRepository taskRepository, TaskRegistry taskRegistry,
                  WorkerConfiguration configuration, JsonMapper jsonMapper,
                  @Named(WORKER_EXECUTOR_NAME) ExecutorService workerExecutor,
                  @Value("${micronaut.server.hostname:UNKNOWN}") String hostname) {
        this.txManager = txManager;
        this.taskRepository = taskRepository;
        this.taskRegistry = taskRegistry;
        this.configuration = configuration;
        this.jsonMapper = jsonMapper;
        this.workerExecutor = (ScheduledExecutorService) workerExecutor;
        this.hostname = hostname;
        this.executingTasksCount = new AtomicInteger(0);
    }

    /**
     * Starts the heartbeat loop once the application has started.
     */
    @Override
    public void onApplicationEvent(StartupEvent event) {
        log.info("Worker start");
        workerExecutor.submit(this::heartbeat);
    }

    /**
     * Shuts the worker executor down. This does not wait for running tasks; it stops the heartbeat loop and
     * makes interruption points in running tasks throw.
     */
    @PreDestroy
    public void stop() {
        workerExecutor.shutdown();
    }
}
