package ru.yandex.mail.pglocal.junit_jupiter;

import lombok.val;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.api.extension.ParameterResolutionException;
import org.junit.jupiter.api.extension.ParameterResolver;
import org.junit.platform.commons.support.AnnotationSupport;
import ru.yandex.mail.pglocal.Database;
import ru.yandex.mail.pglocal.MigrationSource;
import ru.yandex.mail.pglocal.MigrationSource.ResourceFolder;

import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

/**
 * JUnit Jupiter extension that supplies {@link Database} instances created on the server obtained
 * from {@link ServerManager}.
 *
 * <p>The extension is used in two ways:
 * <ul>
 *   <li>a test class annotated with {@link InitDb} gets its database created (or reused) in
 *       {@link #beforeAll(ExtensionContext)};</li>
 *   <li>a parameter of type {@link Database} annotated with {@link InitDb} is resolved to the
 *       described database.</li>
 * </ul>
 *
 * <p>Databases are cached for the lifetime of the JVM in a static map keyed by database name. The
 * first request for a given name creates the database; every later request with the same name
 * returns the cached instance, even if its {@code migration} setting differs.
 */
public class PgLocalExtension implements ParameterResolver, BeforeAllCallback {
    /** Databases created so far, keyed by database name; shared by all extension instances. */
    private static final Map<String, Database> databases = new ConcurrentHashMap<>();

    /**
     * Set once the JVM shutdown hook that stops the server has been registered. {@code beforeAll}
     * runs once per test class, so without this guard every test class would add another hook.
     */
    private static final AtomicBoolean shutdownHookRegistered = new AtomicBoolean(false);

    /**
     * Runs {@code action} with the {@link InitDb} annotation found on the context's test class, if
     * there is one.
     */
    private static void ifTestAnnotatedWithInitDb(ExtensionContext context, Consumer<InitDb> action) {
        val testClass = context.getRequiredTestClass();
        AnnotationSupport.findAnnotation(testClass, InitDb.class).ifPresent(action);
    }

    /**
     * Returns the database described by {@code annotation}, creating it on first use.
     *
     * <p>The database name is {@link InitDb#name()} when non-empty; otherwise it is the unique ID of
     * {@code extensionContext}. Unnamed databases are therefore tied to whichever context requested
     * them: a class-level annotation and an unnamed parameter annotation resolve to different
     * databases.
     *
     * @param annotation       describes the database (name and optional migration resource folder)
     * @param extensionContext context used to derive a name when the annotation specifies none
     * @return the cached or newly created database
     */
    private static Database runDatabase(InitDb annotation, ExtensionContext extensionContext) {
        val migrationSource = annotation.migration().isEmpty()
            ? Optional.<MigrationSource>empty()
            : Optional.of((MigrationSource) new ResourceFolder(annotation.migration()));
        val dbName = annotation.name().isEmpty()
            ? extensionContext.getUniqueId()
            : annotation.name();

        return databases.computeIfAbsent(dbName, name -> {
            val server = ServerManager.getServer();
            return server.createDatabase(name, server.getOptions().getUser(), migrationSource);
        });
    }

    /**
     * Registers, at most once per JVM, a shutdown hook that stops the server via
     * {@link ServerManager#stopServer()}.
     */
    private static void registerShutdownHookOnce() {
        if (!shutdownHookRegistered.compareAndSet(false, true)) {
            return;
        }
        Runtime.getRuntime().addShutdownHook(new Thread(() -> {
            try {
                ServerManager.stopServer();
            } catch (Throwable ignored) {
                // Best effort: the JVM is exiting anyway, and a failure here must not
                // interfere with the remaining shutdown hooks.
            }
        }, "pglocal-server-shutdown"));
    }

    /**
     * Supports parameters annotated with {@link InitDb} whose declared type is exactly
     * {@link Database}.
     */
    @Override
    public boolean supportsParameter(ParameterContext parameterContext,
                                     ExtensionContext extensionContext) throws ParameterResolutionException {
        return parameterContext.isAnnotated(InitDb.class)
            && parameterContext.getParameter().getType() == Database.class;
    }

    /**
     * Resolves an {@link InitDb}-annotated {@link Database} parameter to its (possibly cached)
     * database.
     *
     * @throws ParameterResolutionException if the parameter is not annotated with {@link InitDb}
     */
    @Override
    public Object resolveParameter(ParameterContext parameterContext,
                                   ExtensionContext extensionContext) throws ParameterResolutionException {
        // Use ParameterContext's lookup (as supportsParameter does) rather than
        // Parameter#getAnnotation: JUnit documents that the latter can misreport annotations on
        // inner-class constructor parameters due to a javac bug before JDK 9.
        val annotation = parameterContext.findAnnotation(InitDb.class)
            .orElseThrow(() -> new ParameterResolutionException(
                "Parameter is not annotated with @InitDb: " + parameterContext.getParameter()));
        return runDatabase(annotation, extensionContext);
    }

    /**
     * Starts the server, makes sure it is stopped at JVM shutdown, and creates the database of a
     * test class annotated with {@link InitDb}.
     */
    @Override
    public void beforeAll(ExtensionContext context) {
        ServerManager.startServer();
        registerShutdownHookOnce();
        ifTestAnnotatedWithInitDb(context, annotation -> runDatabase(annotation, context));
    }
}
