package ru.yandex.mail.cerberus.asyncdb;

import lombok.val;
import one.util.streamex.StreamEx;
import org.jdbi.v3.sqlobject.customizer.BindMap;
import org.jdbi.v3.sqlobject.customizer.Define;
import org.jdbi.v3.sqlobject.statement.SqlQuery;
import ru.yandex.mail.micronaut.common.Page;
import ru.yandex.mail.micronaut.common.Pageable;
import ru.yandex.mail.cerberus.asyncdb.annotations.BindId;
import ru.yandex.mail.cerberus.asyncdb.annotations.BindIdList;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

import static java.util.function.Predicate.not;

/**
 * Read-only queries shared by repositories of entities {@code E} identified by {@code ID}.
 *
 * <p>The SQL statements are templates: attributes such as {@code <table>}, {@code <id>},
 * {@code <entityColumnsList>}, {@code <idBindings>} and {@code <ids>} are not defined in this
 * file and are presumably supplied by the repository infrastructure and by the
 * {@link BindId}/{@link BindIdList} annotations.
 *
 * @param <ID> entity identifier type
 * @param <E>  entity type
 */
public interface RoCrudRepository<ID, E> extends Repository {
    /**
     * Wraps fetched items into a {@link Page}.
     *
     * <p>The next page id is the id of the last item when the page is full (at least
     * {@code pageSize} items were fetched); otherwise it is empty.
     *
     * @param values      items of the current page, in the order they were fetched
     * @param pageSize    requested page size
     * @param idExtractor extracts the id of an item
     * @return page holding {@code values} and the optional id to request the next page with
     */
    private static <ID, E> Page<ID, E> makePage(List<E> values, long pageSize, Function<E, ID> idExtractor) {
        // The emptiness check also protects values.get(size - 1) when pageSize is not positive.
        final boolean lastPage = values.isEmpty() || values.size() < pageSize;
        final Optional<ID> nextPageId = lastPage
            ? Optional.empty()
            : Optional.of(idExtractor.apply(values.get(values.size() - 1)));
        return new Page<>(values, nextPageId);
    }

    /**
     * Looks up a single entity by id.
     *
     * @param id entity id
     * @return the entity, or empty if no row matches
     */
    @SqlQuery("SELECT <entityColumnsList> FROM <table> WHERE (<id>) = (<idBindings>)")
    Optional<E> find(@BindId ID id);

    /**
     * Fetches the entities whose id is among {@code ids}.
     *
     * @param ids ids to look up
     * @return matching entities; the query specifies no ordering
     */
    @SqlQuery("SELECT <entityColumnsList> FROM <table> WHERE (<id>) IN (<ids>)")
    List<E> findAll(@BindIdList Iterable<ID> ids);

    /**
     * Fetches every entity of the table.
     *
     * @return all entities; the query specifies no ordering
     */
    @SqlQuery("SELECT <entityColumnsList> FROM <table>")
    List<E> findAll();

    /**
     * Checks whether an entity with the given id exists.
     *
     * @param id entity id
     * @return {@code true} if a matching row exists
     */
    @SqlQuery("SELECT EXISTS (SELECT <id> FROM <table> WHERE (<id>) = (<idBindings>))")
    boolean exists(@BindId ID id);

    /**
     * Checks whether the number of rows matching {@code ids} equals {@code size}.
     *
     * <p>The comparison is done on the row count, so {@code size} has to be the number of
     * <em>distinct</em> ids in {@code ids} for the result to mean "all of them exist"
     * (assuming ids are unique within the table). Prefer {@link #existsAll(Collection)},
     * which takes care of that.
     *
     * @param ids  ids to look up
     * @param size expected number of matching rows
     * @return {@code true} if exactly {@code size} rows match
     */
    @SqlQuery("SELECT COUNT(*) = :size FROM <table> WHERE (<id>) IN (<ids>)")
    boolean existsAll(@BindIdList Iterable<ID> ids, int size);

    /**
     * Checks whether every id of the collection exists.
     *
     * <p>Duplicate ids are collapsed before querying: the underlying query compares a row count,
     * so counting duplicates would report {@code false} even if every id is present.
     *
     * @param ids ids to look up; may contain duplicates
     * @return {@code true} if all ids exist; {@code true} for an empty collection
     */
    default boolean existsAll(Collection<ID> ids) {
        if (ids.isEmpty()) {
            // Vacuously true; also avoids issuing a query with an empty IN list.
            return true;
        }
        final List<ID> distinctIds = StreamEx.of(ids).distinct().toList();
        return existsAll(distinctIds, distinctIds.size());
    }

    /**
     * Checks whether at least one of the given ids exists.
     *
     * @param ids ids to look up
     * @return {@code true} if any matching row exists
     */
    @SqlQuery("SELECT EXISTS (SELECT <id> FROM <table> WHERE (<id>) IN (<ids>))")
    boolean existsAny(@BindIdList Iterable<ID> ids);

    /**
     * Counts all rows of the table.
     *
     * @return total number of rows
     */
    @SqlQuery("SELECT COUNT(*) FROM <table>")
    long count();

    /**
     * Fetches the first page: at most {@code pageSize} entities ordered by id.
     *
     * @param pageSize maximum number of entities to return
     * @return entities ordered by id
     */
    @SqlQuery("SELECT <entityColumnsList> FROM <table>\n"
            + "ORDER BY (<id>)\n"
            + "LIMIT :pageSize")
    List<E> findPageItems(int pageSize);

    /**
     * Fetches a follow-up page: at most {@code pageSize} entities whose id is strictly greater
     * than {@code pageId}, ordered by id.
     *
     * @param pageId   id after which the page starts (exclusive)
     * @param pageSize maximum number of entities to return
     * @return entities ordered by id
     */
    @SqlQuery("SELECT <entityColumnsList> FROM <table>\n"
            + "WHERE (<id>) > (<idBindings>)\n"
            + "ORDER BY (<id>)\n"
            + "LIMIT :pageSize")
    List<E> findPageItems(@BindId ID pageId, int pageSize);

    /**
     * Fetches the first page of entities satisfying {@code condition}, ordered by id.
     *
     * <p>{@code condition} is spliced into the statement text via {@link Define}, so it must
     * never be built from untrusted input; pass values through {@code bindings} instead.
     *
     * @param pageSize  maximum number of entities to return
     * @param condition SQL boolean expression placed into the {@code WHERE} clause
     * @param bindings  named arguments referenced by {@code condition}
     * @return entities ordered by id
     */
    @SqlQuery("SELECT <entityColumnsList> FROM <table>\n"
            + "WHERE <condition>\n"
            + "ORDER BY (<id>)\n"
            + "LIMIT :pageSize")
    List<E> findPageItems(int pageSize, @Define String condition, @BindMap Map<String, ?> bindings);

    /**
     * Fetches a follow-up page of entities satisfying {@code condition}: at most
     * {@code pageSize} entities whose id is strictly greater than {@code pageId}, ordered by id.
     *
     * <p>{@code condition} is spliced into the statement text via {@link Define}, so it must
     * never be built from untrusted input; pass values through {@code bindings} instead.
     *
     * @param pageId    id after which the page starts (exclusive)
     * @param pageSize  maximum number of entities to return
     * @param condition SQL boolean expression combined with the id filter using {@code AND}
     * @param bindings  named arguments referenced by {@code condition}
     * @return entities ordered by id
     */
    @SqlQuery("SELECT <entityColumnsList> FROM <table>\n"
            + "WHERE ((<id>) > (<idBindings>)) AND (<condition>)\n"
            + "ORDER BY (<id>)\n"
            + "LIMIT :pageSize")
    List<E> findPageItems(@BindId ID pageId, int pageSize, @Define String condition, @BindMap Map<String, ?> bindings);

    /**
     * Loads one page of entities ordered by id.
     *
     * <p>If {@code pageable} carries a page id, the page starts after that id; otherwise the
     * first page is returned.
     *
     * @param pageable    page size and optional id to continue from
     * @param idExtractor extracts the id of an entity; used to compute the next page id
     * @return the requested page
     */
    default Page<ID, E> findPage(Pageable<ID> pageable, Function<E, ID> idExtractor) {
        val pageSize = pageable.getPageSize();
        final var values = pageable.getPageId()
            .map(id -> findPageItems(id, pageSize))
            .orElseGet(() -> findPageItems(pageSize));
        return makePage(values, pageSize, idExtractor);
    }

    /**
     * Loads one page of entities matching {@code condition}, ordered by id.
     *
     * <p>If {@code pageable} carries a page id, the page starts after that id; otherwise the
     * first page is returned.
     *
     * @param pageable    page size and optional id to continue from
     * @param idExtractor extracts the id of an entity; used to compute the next page id
     * @param condition   provides the SQL filter and its bindings
     * @return the requested page
     */
    default Page<ID, E> findPage(Pageable<ID> pageable, Function<E, ID> idExtractor, Condition condition) {
        val sql = condition.getSql();
        val bindings = condition.getBindings();
        val pageSize = pageable.getPageSize();

        final var values = pageable.getPageId()
            .map(id -> findPageItems(id, pageSize, sql, bindings))
            .orElseGet(() -> findPageItems(pageSize, sql, bindings));
        return makePage(values, pageSize, idExtractor);
    }

    /**
     * Returns the subset of {@code ids} that is present in the table.
     *
     * @param ids ids to look up
     * @return ids for which a row exists
     */
    @SqlQuery("SELECT <id> FROM <table>\n"
            + "WHERE (<id>) IN (<ids>)")
    Set<ID> findExistingIds(@BindIdList Iterable<ID> ids);

    /**
     * Returns the subset of {@code ids} that is absent from the table.
     *
     * <p>{@code ids} is traversed twice (once to build the query, once to filter), so it must
     * support repeated iteration.
     *
     * @param ids ids to look up
     * @return immutable set of ids for which no row exists
     */
    default Set<ID> findMissingIds(Iterable<ID> ids) {
        val existing = findExistingIds(ids);
        return StreamEx.of(ids.spliterator())
            .filter(not(existing::contains))
            .toImmutableSet();
    }
}
