package ru.yandex.webmaster3.worker;

import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import com.datastax.driver.core.utils.UUIDs;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.datatype.joda.JodaModule;
import com.fasterxml.jackson.module.paramnames.ParameterNamesModule;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.slf4j.MDC;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.SmartLifecycle;
import org.springframework.scheduling.Trigger;
import org.springframework.scheduling.concurrent.ConcurrentTaskScheduler;
import org.springframework.scheduling.concurrent.CustomizableThreadFactory;
import org.springframework.stereotype.Service;

import ru.yandex.webmaster3.core.util.W3Collectors;
import ru.yandex.webmaster3.core.util.json.JsonMapping;
import ru.yandex.webmaster3.core.worker.WorkerApi;
import ru.yandex.webmaster3.core.worker.task.PeriodicTaskInfo;
import ru.yandex.webmaster3.core.worker.task.PeriodicTaskType;
import ru.yandex.webmaster3.core.worker.task.TaskResult;
import ru.yandex.webmaster3.storage.util.yt.YtCypressService;
import ru.yandex.webmaster3.storage.util.yt.YtException;
import ru.yandex.webmaster3.storage.util.yt.lock.CypressProvider;
import ru.yandex.webmaster3.worker.queue.TaskQueueService;

/**
 * Runs every {@link PeriodicTask} bean found in the application context according to the task's trigger.
 * <p>
 * Tasks with lock mode {@code ONE_PER_SERVER} run on every worker. Tasks with lock mode {@code ONE_PER_CLUSTER}
 * run inside {@code CypressProvider.tryRunExclusive} on a per-task path and record the host and start time of
 * the run in that node. A run is skipped if the previous one started too recently or if the previous (other)
 * host reports that it is still running the task.
 * Individual task types can be paused and resumed at runtime; runs of a paused type are skipped.
 *
 * @author aherman
 */
@Slf4j
@Service
@RequiredArgsConstructor(onConstructor_ = {@Autowired})
public class SimpleZkScheduler implements ApplicationContextAware, SmartLifecycle {
    /** A cluster-locked run is skipped if the previously recorded run started less than this long ago. */
    private static final Duration MIN_DURATION_BETWEEN_TASK_RUNS = Duration.standardSeconds(30L);

    private ApplicationContext applicationContext;

    private final CypressProvider cypressProvider;
    private final WorkerApi workerApi;

    private ConcurrentTaskScheduler concurrentTaskScheduler;
    private ScheduledExecutorService scheduledExecutorService;

    private volatile boolean started = false;

    private int availableThreads = 64;
    @Value("${application.hostname}:${application.http.port}")
    private String hostname;
    @Value("${webmaster3.worker.simpleZkScheduler.enabled}")
    private boolean enabled = true;
    /** Pause flag per task type; every {@link PeriodicTaskType} value has an entry. */
    private final Map<PeriodicTaskType, AtomicBoolean> taskPauseState = Arrays.stream(PeriodicTaskType.values())
            .map(ptt -> new AbstractMap.SimpleEntry<>(ptt, new AtomicBoolean(false)))
            .collect(W3Collectors.toEnumMap(PeriodicTaskType.class));
    /** Number of task runs currently in progress on this worker (used to stagger lock acquisition). */
    private final AtomicInteger runningTasks = new AtomicInteger(0);

    // Serializes TaskRunInfo into the task's Cypress node.
    // NOTE(review): the node is read back with JsonMapping, not with this mapper, so both must agree on the
    // format. WRITE_DATE_KEYS_AS_TIMESTAMPS only affects dates used as Map keys, not the startDate value.
    private static final ObjectMapper OM = new ObjectMapper()
            .configure(SerializationFeature.WRITE_DATE_KEYS_AS_TIMESTAMPS, false)
            .registerModule(new ParameterNamesModule())
            .registerModule(new JodaModule());

    /**
     * Creates the thread pool (daemon threads named with the "worker-" prefix) and the task scheduler on top
     * of it. Must be called before {@link #start()}.
     */
    public void init() {
        CustomizableThreadFactory customizableThreadFactory = new CustomizableThreadFactory("worker-");
        customizableThreadFactory.setDaemon(true);

        ScheduledThreadPoolExecutor executor =
                new ScheduledThreadPoolExecutor(availableThreads, customizableThreadFactory);
        // Trigger-based scheduling queues every next run as a one-shot delayed task. By default such tasks
        // still fire after shutdown(), so stop() would neither cancel pending runs nor terminate in time.
        executor.setExecuteExistingDelayedTasksAfterShutdownPolicy(false);
        scheduledExecutorService = executor;
        concurrentTaskScheduler = new ConcurrentTaskScheduler(scheduledExecutorService);
    }

    /**
     * Registers every {@link PeriodicTask} bean that has a non-null trigger with the scheduler.
     * When the scheduler is disabled, runners are still created (so an unsupported lock mode still throws
     * from {@link #createRunner}) but nothing is scheduled.
     */
    @Override
    public void start() {
        if (!enabled) {
            log.warn("Scheduler started in DEV mode, all periodic tasks are disabled");
        }
        Map<String, PeriodicTask> tasks = applicationContext.getBeansOfType(PeriodicTask.class);
        for (Map.Entry<String, PeriodicTask> entry : tasks.entrySet()) {
            String taskName = entry.getKey();
            PeriodicTask task = entry.getValue();

            TaskSchedule schedule = task.getSchedule();
            Trigger trigger = schedule.getTrigger();
            if (trigger != null) {
                log.debug("Found task: {} trigger={} lock={}", taskName, schedule.getTrigger(), task.getLockMode());
                TaskRunner runner = createRunner(task, taskName);
                if (enabled) {
                    log.info("Register task: {} trigger={} lock={}", taskName, schedule.getTrigger(), task.getLockMode());
                    concurrentTaskScheduler.schedule(runner, trigger);
                } else {
                    log.info("Ignore task: {} trigger={}", taskName, schedule.getTrigger());
                }
            } else {
                log.info("Ignore task: {} trigger=NEVER", taskName);
            }
        }

        this.started = true;
    }

    /**
     * Creates the runner matching the task's lock mode.
     *
     * @throws IllegalArgumentException if the lock mode is neither ONE_PER_CLUSTER nor ONE_PER_SERVER
     */
    @NotNull
    private TaskRunner createRunner(PeriodicTask task, String taskName) {
        switch (task.getLockMode()) {
            case ONE_PER_CLUSTER:
                return new CypressLockedTaskRunner(taskName, hostname, task, cypressProvider);

            case ONE_PER_SERVER:
                return new TaskRunner(taskName, hostname, task);

            default:
                throw new IllegalArgumentException("Unsupported lock mode: " + taskName + " " + task.getLockMode());
        }
    }

    /**
     * Shuts the pool down and waits up to 5 seconds for running tasks to finish; a warning is logged if they
     * do not. Threads are daemons, so tasks still running afterwards do not keep the JVM alive.
     */
    @Override
    public void stop() {
        scheduledExecutorService.shutdown();
        try {
            if (!scheduledExecutorService.awaitTermination(5, TimeUnit.SECONDS)) {
                log.warn("Unable to shutdown tasks properly");
            }
        } catch (InterruptedException e) {
            // Restore the interrupt so the thread driving the shutdown can still observe it.
            Thread.currentThread().interrupt();
            log.warn("Unable to shutdown tasks properly", e);
        }
        this.started = false;
    }

    @Override
    public boolean isRunning() {
        return started;
    }

    /** Marks the task type as paused; its scheduled runs are skipped until {@link #resume} is called. */
    public void pause(PeriodicTaskType type) {
        taskPauseState.get(type).set(true);
    }

    /** Clears the pause flag of the task type. */
    public void resume(PeriodicTaskType type) {
        taskPauseState.get(type).set(false);
    }

    /** Returns whether runs of the given task type are currently being skipped. */
    public boolean isPaused(PeriodicTaskType taskType) {
        return taskPauseState.get(taskType).get();
    }

    /** Returns a snapshot of the currently paused task types. */
    public Set<PeriodicTaskType> getPausedTasks() {
        return taskPauseState.entrySet().stream()
                .filter(e -> e.getValue().get())
                .map(Map.Entry::getKey)
                .collect(Collectors.toSet());
    }

    /**
     * Scheduled unit of work for one task: skips paused tasks, tags logs with a per-run id via MDC, and logs
     * the start, result and duration of every run.
     */
    private class TaskRunner implements Runnable {
        protected final String taskName;
        protected final String hostname;

        protected final PeriodicTask task;

        private TaskRunner(String taskName, String hostname, PeriodicTask task) {
            this.taskName = taskName;
            this.hostname = hostname;
            this.task = task;
        }

        @Override
        public void run() {
            if (isPaused(task.getType())) {
                log.info("Task {} is paused. Ignoring", taskName);
                return;
            }
            PeriodicTask.Result taskResult = new PeriodicTask.Result(TaskResult.FAIL);
            long startTime = System.nanoTime();
            UUID runId = UUIDs.timeBased();
            try (MDC.MDCCloseable ign = MDC.putCloseable(TaskQueueService.MDC_TASK_ID_KEY, runId.toString())) {
                int runningNow = runningTasks.getAndIncrement();
                try {
                    // Delay proportionally to this worker's current load, so that a less busy worker
                    // gets a chance to grab the lock first.
                    Thread.sleep(runningNow * 10L);
                    log.info("Start task: {}", taskName);
                    PeriodicTask.Result result = runTask(runId);
                    // A null result would otherwise throw from the "Finish task" log below; report it as FAIL.
                    if (result != null) {
                        taskResult = result;
                    }
                } catch (Exception e) {
                    if (e instanceof InterruptedException) {
                        Thread.currentThread().interrupt();
                    }
                    log.error("Unable to run task: {}", taskName, e);
                } finally {
                    runningTasks.decrementAndGet();
                }
                long duration = System.nanoTime() - startTime;
                log.info("Finish task: {} result={} timeMs={}", taskName, taskResult.getTaskResult(),
                        TimeUnit.NANOSECONDS.toMillis(duration));
            }
        }

        /** Executes the task directly, without any cross-worker locking. */
        protected PeriodicTask.Result runTask(UUID runId) throws Exception {
            return task.execute(runId);
        }

    }

    /**
     * Runner for {@code ONE_PER_CLUSTER} tasks: executes the task inside {@code tryRunExclusive} on the task's
     * Cypress path and stores a {@link TaskRunInfo} (host + start time) in that node before executing.
     */
    private class CypressLockedTaskRunner extends TaskRunner {
        private final CypressProvider cypressProvider;
        private final String path;

        private CypressLockedTaskRunner(String taskName, String hostname, PeriodicTask task, CypressProvider cypressProvider) {
            super(taskName, hostname, task);
            this.cypressProvider = cypressProvider;
            this.path = getZkLockPath(taskName);
        }

        /**
         * Runs the task under the exclusive Cypress lock, unless {@link #isTaskStillRunning} says the previous
         * run is too recent or still in progress.
         *
         * @return the task's result; a FAIL result if the run was skipped, the lock was not obtained, or the
         *         task failed
         */
        @Override
        public PeriodicTask.Result runTask(UUID runId) {
            AtomicReference<PeriodicTask.Result> reference = new AtomicReference<>();
            // Set once this worker is inside the exclusive section and has decided to run the task.
            AtomicBoolean claimed = new AtomicBoolean(false);
            final TaskRunInfo taskRunInfo = new TaskRunInfo(hostname, DateTime.now());
            // Pre-check outside the exclusive section; it is repeated inside before actually running.
            if (!isTaskStillRunning(null, taskRunInfo)) {
                cypressProvider.tryRunExclusive(path, java.time.Duration.ofSeconds(5), (cypService) -> {
                    try {
                        if (isTaskStillRunning(cypService, taskRunInfo)) {
                            return;
                        }
                        claimed.set(true);
                        try {
                            cypressProvider.create(cypService).forPath(nodePath());
                        } catch (YtException ignored) {
                            // NOTE(review): presumably the node already exists; setData below still runs.
                        }
                        final String nodeData = OM.writeValueAsString(taskRunInfo);
                        cypressProvider.setData(cypService).forPath(nodePath(), nodeData);
                        reference.set(task.execute(runId));
                    } catch (Exception e) {
                        if (e instanceof InterruptedException) {
                            Thread.currentThread().interrupt();
                        }
                        log.error("Unable to run task: {}", taskName, e);
                    }

                });

                PeriodicTask.Result result = reference.get();
                if (result != null) {
                    return result;
                }
                if (claimed.get()) {
                    // This worker ran (or tried to run) the task itself and it failed, which is already logged;
                    // reporting it as "already started" somewhere else would be misleading.
                    return new PeriodicTask.Result(TaskResult.FAIL);
                }
            }
            logTaskWorker();
            return new PeriodicTask.Result(TaskResult.FAIL);
        }

        /**
         * Decides, from the run info stored in the task's node, whether a new run must be skipped.
         *
         * @param cypressService service to read with, or {@code null} to read without one
         * @param taskRunInfo    info describing the run this worker is about to start
         * @return {@code true} if the previous run started less than {@link #MIN_DURATION_BETWEEN_TASK_RUNS}
         *         ago, or was started by another host that reports it is still running the task;
         *         {@code false} if there is no readable node or neither condition holds
         */
        private boolean isTaskStillRunning(YtCypressService cypressService, TaskRunInfo taskRunInfo) {
            String oldData;
            try {
                oldData = cypressService == null ? cypressProvider.getData().forPath(nodePath()) : cypressProvider.getData(cypressService).forPath(nodePath());
            } catch (YtException e) {
                log.info("No info node, task was never run before");
                return false;
            }
            if (oldData != null) {
                TaskRunInfo oldTaskRunInfo = JsonMapping.readValue(oldData, TaskRunInfo.class);
                if (oldTaskRunInfo.getStartDate().plus(MIN_DURATION_BETWEEN_TASK_RUNS).isAfter(DateTime.now())) {
                    log.info("Not enough time between task runs");
                    return true;
                }

                String oldHostName = oldTaskRunInfo.getHostname();
                if (oldHostName.equals(taskRunInfo.getHostname())) {
                    log.info("Same host were previously running task {}", taskName);
                } else {
                    // Ask the host that ran the task last whether it has already finished running it.
                    PeriodicTaskInfo taskInfo = workerApi.getPeriodicTaskInfo("http://" + oldHostName, task.getType());
                    if (taskInfo != null && taskInfo.isRunning()) {
                        log.warn("Host {} is still running task {}", oldHostName, taskName);
                        return true;
                    }
                }
            }
            return false;
        }

        /** Logs the run info currently stored in the task's node, explaining why this worker did not run it. */
        private void logTaskWorker() {
            try {
                final String s = cypressProvider.getData().forPath(nodePath());
                log.info("Task already started: {} {}", taskName, s);
            } catch (YtException exp) {
                log.warn("Task started somewhere, but task run info is gone: {}", taskName);
            }
        }

        /** Path of the node holding the run info; the same path that is passed to tryRunExclusive. */
        private String nodePath() {
            return path;
        }
    }

    /**
     * Host and start time of a cluster-locked task run, stored as JSON in the task's Cypress node.
     * NOTE(review): deserialized via JsonMapping through the private constructor; this presumably relies on
     * parameter-name support there. Confirm before changing the constructor.
     */
    private static class TaskRunInfo {
        private final String hostname;
        private final DateTime startDate;

        private TaskRunInfo(String hostname, DateTime startDate) {
            this.hostname = hostname;
            this.startDate = startDate;
        }

        public String getHostname() {
            return hostname;
        }

        public DateTime getStartDate() {
            return startDate;
        }

    }

    /** Cypress path used both as the lock and as the run-info node of the task. */
    @NotNull
    private static String getZkLockPath(String taskName) {
        return "/scheduler/tasks/" + taskName;
    }

    // NOTE(review): not referenced in this file; run info is stored at getZkLockPath(taskName) instead.
    @NotNull
    private static String getZkInfoPath(String taskName) {
        return "/scheduler/tasks/" + taskName + "/info";
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }

    public void setHostname(String hostname) {
        this.hostname = hostname;
    }
}
