package ru.yandex.mail.junit_extensions.program_runner;

import lombok.SneakyThrows;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import one.util.streamex.StreamEx;
import org.awaitility.Duration;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.api.extension.ParameterResolutionException;
import org.junit.jupiter.api.extension.ParameterResolver;
import org.junit.platform.commons.support.AnnotationSupport;
import ru.yandex.mail.junit_extensions.program_runner.ProgramOptions.JavaProgramOptions;
import ru.yandex.mail.junit_extensions.program_runner.ProgramOptions.NativeProgramOptions;

import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

import static java.lang.ProcessBuilder.Redirect.INHERIT;
import static org.awaitility.Awaitility.await;

/**
 * JUnit 5 extension that starts the external programs declared on a test class
 * with {@link RunProgram} before any test of that class runs, and resolves
 * {@link Process} parameters annotated with {@link BindProcess} to the started processes.
 *
 * <p>The test class must carry {@link RegisterProgramsRegistry}; the referenced registry is
 * instantiated through its no-arg constructor and asked for the options of every requested
 * program id.
 *
 * <p>Started processes are kept in a static map keyed by (registry class, program id), so a
 * program with a given key is started at most once and is then reused by later test classes.
 * Processes are never stopped by the extension itself during the test run; each one gets a
 * JVM shutdown hook that destroys it.
 */
@Slf4j
public class ProgramRunnerExtension implements BeforeAllCallback, ParameterResolver {
    /** Map key identifying a started program: the registry class that describes it plus its id. */
    @Value
    private static class ProgramKey {
        Class<? extends ProgramRegistry> registryClass;
        String programId;
    }

    /** Health-check timeout used when the program options do not specify a start timeout. */
    private static final Duration DEFAULT_AWAIT_TIMEOUT = new Duration(30, TimeUnit.SECONDS);

    /** Absolute path of the {@code java} launcher of the JVM that runs the tests. */
    private static final String JAVA_PATH = asString(Paths.get(System.getProperty("java.home"), "bin", "java"));

    /** All programs started so far; shared by every instance of the extension. */
    private static final Map<ProgramKey, Process> RUNNING_PROGRAMS = new ConcurrentHashMap<>();

    /** Returns the absolute form of {@code path} as a string. */
    private static String asString(Path path) {
        return path.toAbsolutePath().toString();
    }

    /**
     * Fails fast when a file required to build the command line is missing.
     *
     * @throws RuntimeException if {@code path} does not exist
     */
    private static void validatePath(Path path) {
        if (!path.toFile().exists()) {
            throw new RuntimeException("Path not found: " + path);
        }
    }

    /**
     * Blocks until the program's health check reports {@link HealthStatus#UP}.
     *
     * <p>If the wait does not complete normally (the timeout expires or the health check
     * throws), the process is destroyed before the failure propagates, so a program that
     * never became healthy is not left running for the rest of the test session.
     *
     * @param process the freshly started process to destroy on failure
     * @param options the program options supplying the health check and the optional start timeout
     */
    private static void awaitHealthCheck(Process process, ProgramOptions options) {
        boolean healthy = false;
        try {
            val healthCheck = options.getHealthCheck();
            // NOTE(review): only value.getSeconds() is forwarded, so presumably any sub-second
            // part of the configured start timeout is dropped - confirm this is intended.
            val timeout = options.getStartTimeout()
                .map(value -> new Duration(value.getSeconds(), TimeUnit.SECONDS))
                .orElse(DEFAULT_AWAIT_TIMEOUT);

            await()
                .atMost(timeout)
                .until(healthCheck::check, HealthStatus.UP::equals);
            healthy = true;
        } finally {
            // await() signals a timeout by throwing, so the cleanup has to live in finally;
            // a check placed after await() would never run in the failure case.
            if (!healthy) {
                log.warn("Program did not become healthy, destroying its process");
                process.destroy();
            }
        }
    }

    /**
     * Starts {@code command} as a child process whose stdout and stderr both go to this JVM's
     * standard output, and registers a shutdown hook that destroys the process on JVM exit.
     *
     * @param command the executable followed by its arguments
     * @return the started process
     */
    @SneakyThrows
    private static Process runProcess(List<String> command) {
        log.info("Run command: {}", command);

        // With redirectErrorStream(true) stderr is merged into stdout and follows the stdout
        // redirect, so a separate redirectError(...) would be ignored and is not set.
        val process = new ProcessBuilder(command)
            .redirectErrorStream(true)
            .redirectOutput(INHERIT)
            .start();

        Runtime.getRuntime().addShutdownHook(new Thread(process::destroy));
        return process;
    }

    /**
     * Launches a Java program with the current JVM's {@code java} binary.
     *
     * <p>The command line is: java, optional {@code @<jvm options file>}, the JVM options,
     * {@code -classpath <jar>}, and the main class.
     *
     * @throws RuntimeException if the jar or the JVM options file does not exist
     */
    private static Process startJavaProcess(JavaProgramOptions options) {
        options.getJvmOptionsFile().ifPresent(ProgramRunnerExtension::validatePath);
        validatePath(options.getJarPath());

        // Arguments are positional and may legitimately repeat (e.g. several "--add-opens"
        // flags), so the argument list must not be de-duplicated.
        val command = StreamEx.of(JAVA_PATH)
            .append(options.getJvmOptionsFile()
                .map(ProgramRunnerExtension::asString)
                .map(path -> '@' + path)
                .stream())
            .append(options.getJvmOptions())
            .append("-classpath")
            .append(asString(options.getJarPath()))
            .append(options.getMainClass())
            .toImmutableList();

        return runProcess(command);
    }

    /**
     * Launches a native binary followed by its command-line options.
     *
     * @throws RuntimeException if the binary does not exist
     */
    private static Process startNativeProcess(NativeProgramOptions options) {
        validatePath(options.getBinaryPath());

        val command = StreamEx.of(asString(options.getBinaryPath()))
            .append(options.getCmdOptions())
            .toImmutableList();
        return runProcess(command);
    }

    /**
     * Dispatches to the launcher matching the concrete options type.
     *
     * @throws IllegalArgumentException if {@code options} is neither Java nor native program options
     */
    private static Process startProcess(ProgramOptions options) {
        if (options instanceof JavaProgramOptions) {
            return startJavaProcess((JavaProgramOptions) options);
        }
        if (options instanceof NativeProgramOptions) {
            return startNativeProcess((NativeProgramOptions) options);
        }
        throw new IllegalArgumentException("Unsupported program options: " + options);
    }

    /**
     * Starts the program unless one is already registered for the same registry class and id,
     * and waits for it to become healthy.
     *
     * <p>The start happens inside {@code computeIfAbsent}, so it runs at most once per key; if
     * starting or the health check throws, nothing is recorded and the exception propagates.
     */
    private static <R extends ProgramRegistry> void runProgram(R registry, String id, ProgramOptions options) {
        val key = new ProgramKey(registry.getClass(), id);
        RUNNING_PROGRAMS.computeIfAbsent(key, unused -> {
            val process = startProcess(options);
            awaitHealthCheck(process, options);
            return process;
        });
    }

    /**
     * Reads the registry class from the {@link RegisterProgramsRegistry} annotation of the test class.
     *
     * @throws IllegalStateException if the test class is not annotated
     */
    private static Class<? extends ProgramRegistry> findRegistryClass(ExtensionContext context) {
        val testClass = context.getRequiredTestClass();
        val registryAnnotation = AnnotationSupport.findAnnotation(testClass, RegisterProgramsRegistry.class)
            .orElseThrow(() -> new IllegalStateException("@RegisterProgramsRegistry annotation not found"));
        return registryAnnotation.value();
    }

    /**
     * Starts every program requested by the {@link RunProgram} annotations of the test class.
     *
     * @throws IllegalStateException    if the test class has no {@link RegisterProgramsRegistry}
     * @throws IllegalArgumentException if the registry has no options for a requested program id
     */
    @Override
    public void beforeAll(ExtensionContext context) throws Exception {
        val testClass = context.getRequiredTestClass();
        val registryClass = findRegistryClass(context);
        val registry = registryClass.getDeclaredConstructor().newInstance();

        StreamEx.of(AnnotationSupport.findRepeatableAnnotations(testClass, RunProgram.class))
            .map(RunProgram::name)
            .mapToEntry(id -> {
                return registry.findProgramOptions(id)
                    .orElseThrow(() -> new IllegalArgumentException("Options for " + id + " program not found"));
            })
            .forKeyValue((id, options) -> {
                runProgram(registry, id, options);
            });
    }

    /** Supports parameters of the exact type {@link Process} that are annotated with {@link BindProcess}. */
    @Override
    public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) {
        return parameterContext.isAnnotated(BindProcess.class)
            && parameterContext.getParameter().getType().equals(Process.class);
    }

    /**
     * Resolves a {@link BindProcess} parameter to the process started for the named program
     * under the registry of the current test class.
     *
     * @throws ParameterResolutionException if no such program has been started
     */
    @Override
    public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException {
        val registryClass = findRegistryClass(extensionContext);

        val programId = parameterContext.findAnnotation(BindProcess.class)
            .orElseThrow(() -> new IllegalStateException("@BindProcess annotation not found"))
            .value();

        val key = new ProgramKey(registryClass, programId);
        val process = RUNNING_PROGRAMS.get(key);
        if (process == null) {
            throw new ParameterResolutionException("Running program " + programId + " not found");
        }

        return process;
    }
}
