package ru.yandex.webmaster3.worker.queue;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.joda.JodaModule;
import com.fasterxml.jackson.module.paramnames.ParameterNamesModule;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Required;
import org.springframework.context.SmartLifecycle;

import ru.yandex.webmaster3.core.http.WebmasterJsonModule;
import ru.yandex.webmaster3.core.worker.task.TaskResult;
import ru.yandex.webmaster3.core.worker.task.WorkerTaskData;
import ru.yandex.webmaster3.core.worker.task.WorkerTaskPriority;
import ru.yandex.webmaster3.core.worker.task.WorkerTaskType;
import ru.yandex.webmaster3.worker.Task;
import ru.yandex.webmaster3.worker.TaskRegistry;

/**
 * @author aherman
 */
public class TaskQueueService implements SmartLifecycle {
    private static final Logger log = LoggerFactory.getLogger(TaskQueueService.class);

    public static final String MDC_TASK_ID_KEY = "taskQueue_taskId";

    private int poolSize = 8;
    private ExecutorService executorService;

    private TaskQueue taskQueue;
    @Autowired
    private TaskRegistry taskRegistry;
    private TaskQueueMetrics taskQueueMetrics;
    private TaskScheduler taskScheduler;
    private volatile boolean isRunning = false;

    private static final ObjectMapper OM = new ObjectMapper()
            .disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
            .registerModules(new JodaModule(), new WebmasterJsonModule(false))
            .registerModules(new ParameterNamesModule());

    public void init() {
        executorService = Executors.newFixedThreadPool(
                poolSize,
                new ThreadFactoryBuilder()
                        .setDaemon(true)
                        .setNameFormat("task-worker-%d")
                        .build()
        );
    }

    @Override
    public void start() {
        log.info("Start workers...");
        for (int i = 0; i < poolSize; i++) {
            executorService.execute(new WorkerRunnable());
        }
        log.info("Workers started");
        isRunning = true;
    }

    @Override
    public boolean isAutoStartup() {
        return true;
    }

    @Override
    public void stop(Runnable callback) {
        log.info("Stop workers...");
        taskQueue.beginShutdown();
        executorService.shutdownNow();
        callback.run();
    }

    @Override
    public int getPhase() {
        return Integer.MAX_VALUE;
    }

    @Override
    public boolean isRunning() {
        return isRunning;
    }

    @Override
    public void stop() {
    }

    public void destroy() {
        try {
            taskQueue.shutdown();
        } catch (IOException e) {
            log.error("Unable to shutdown queue", e);
        }
        try {
            executorService.awaitTermination(30, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            log.error("Unable to shutdown executor in time", e);
        }
        log.info("Workers stopped");
    }

    public void clear() {
        taskQueue.clear();
    }

    public void clear(WorkerTaskType type) {
        taskQueue.clear(type);
    }

    public void pause() {
        taskScheduler.pause();
    }

    public void pause(WorkerTaskType type) {
        taskScheduler.pause(type);
    }

    public void resume() {
        taskScheduler.resume();
    }

    public void resume(WorkerTaskType type) {
        taskScheduler.resume(type);
    }

    public boolean enqueueTask(WorkerTaskData taskData) {
        String taskIdStr = taskData.getTaskId().toString();
        try (MDC.MDCCloseable mdcCloseable = MDC.putCloseable(MDC_TASK_ID_KEY, taskIdStr)) {
            WorkerTaskType taskType = taskData.getTaskType();
            Task task = taskRegistry.getTaskRegistryMap().get(taskType);
            if (task == null) {
                log.error("Unknown task type: {} {}", taskType, taskIdStr);
                return false;
            }
            return enqueueTask(taskData, task, taskData.getTaskPriority());
        } catch (IOException e) {
            log.error("Unable to add task: type={}", taskData.getTaskType(), e);
            return false;
        }
    }

    public boolean enqueueTask(WorkerTaskType taskType, String taskIdStr, String taskData,
                               WorkerTaskPriority taskPriority) {
        try (MDC.MDCCloseable mdcCloseable = MDC.putCloseable(MDC_TASK_ID_KEY, taskIdStr)) {
            // log.info("Enqueue task: {}", taskIdStr);
            Task task = taskRegistry.getTaskRegistryMap().get(taskType);
            if (task == null) {
                log.error("Unknown task type: {} {}", taskType, taskIdStr);
                return false;
            }

            WorkerTaskData data = (WorkerTaskData) OM.readValue(taskData, task.getDataClass());

            return enqueueTask(data, task, taskPriority);
        } catch (IOException e) {
            log.error("Unable to add task: type={}", taskType, e);
            return false;
        }
    }

    private boolean enqueueTask(WorkerTaskData taskData, Task task, WorkerTaskPriority taskPriority) throws IOException {
        WorkerTaskType taskType = taskData.getTaskType();
        Integer queueSizeLimit = task.getMaxQueueSize();
        if (queueSizeLimit != null && taskQueueMetrics.getTaskStatistics(taskType).getEnqueueed() >= queueSizeLimit) {
            log.error("Queue size exceeded for task {}", taskType);
            return false;
        }

        TaskId taskId = new TaskId(taskType, taskData.getHostId(), taskData.getTaskId());
        taskQueue.enqueueTask(new TaskRunData(taskId, TaskRunType.PRIMARY, taskData, taskPriority));
        return true;
    }

    public void setPoolSize(int poolSize) {
        this.poolSize = poolSize;
    }

    @Required
    public void setTaskQueue(TaskQueue taskQueue) {
        this.taskQueue = taskQueue;
    }

    @Required
    public void setTaskQueueMetrics(TaskQueueMetrics taskQueueMetrics) {
        this.taskQueueMetrics = taskQueueMetrics;
    }

    @Required
    public void setTaskScheduler(TaskScheduler taskScheduler) {
        this.taskScheduler = taskScheduler;
    }

    private class WorkerRunnable implements Runnable {
        @Override
        public void run() {
            taskQueueMetrics.workerStarted(Thread.currentThread().getName());
            log.info("Start worker");
            while (true) {
                TaskRunData taskRunData;
                try {
                    try {
                        for (WorkerTaskType taskType : WorkerTaskType.values()) {
                            TaskQueueMetrics.TaskMetrics tm = taskQueueMetrics.getTaskMetrics(taskType);
                            if (tm != null) {
                                tm.updateStatsWithLock();
                            }
                        }
                        taskRunData = taskQueue.pollTask();
                    } catch (InterruptedException e) {
                        break;
                    }
                    if (Thread.interrupted()) {
                        break;
                    }
                } catch (Exception e) {
                    log.error("Exception waiting for task", e);
                    break;
                }
                if (taskRunData == null) {
                    break;
                }
                TaskResult tr = TaskResult.FAIL;
                long startTime = System.nanoTime();
                try (MDC.MDCCloseable mdcCloseable =
                        MDC.putCloseable(MDC_TASK_ID_KEY, taskRunData.getTaskId().getTaskUUID().toString()))
                {
                    WorkerTaskData workerTaskData = taskRunData.getTaskData();
                    try {
                        log.info("Start task: type={} hostId={} {}",
                                workerTaskData.getTaskType(),
                                workerTaskData.getHostId(),
                                workerTaskData.getShortDescription());
                        Task task = taskRegistry.getTaskRegistryMap().get(taskRunData.getTaskId().getTaskType());
                        Task.Result result = task.runWithTracer(workerTaskData);
                        tr = result.getTaskResult();
                        log.info("Finish task: type={} hostId={} result={} {} ",
                                workerTaskData.getTaskType(),
                                workerTaskData.getHostId(),
                                tr,
                                workerTaskData.getShortDescription());
                        if (result.getTaskResult() == TaskResult.SUCCESS) {
                            List<WorkerTaskData> followingTasks = result.getFollowingTasks();
                            for (WorkerTaskData followingTask : followingTasks) {
                                // необходимо, чтобы последующие задачи выполнялись с тем же приоритетом
                                taskQueue.enqueueTask(TaskRunData.wrap(followingTask, taskRunData.getTaskPriority()));
                            }
                        }
                    } catch (Throwable e) {
                        log.error("Error in task: type={} hostId={} {}", workerTaskData.getTaskType(),
                                workerTaskData.getHostId(), workerTaskData.getShortDescription(), e);
                    }
                }
                long taskRunTime = System.nanoTime() - startTime;
                taskQueueMetrics.taskFinished(Thread.currentThread().getName(), taskRunData.getTaskId(), taskRunTime, tr);
            }
            log.info("Stop worker");
            taskQueueMetrics.workerStopped(Thread.currentThread().getName());
        }
    }
}
