package ru.yandex.mail.cerberus.dao.user;

import lombok.val;
import one.util.streamex.EntryStream;
import org.jdbi.v3.sqlobject.config.KeyColumn;
import org.jdbi.v3.sqlobject.config.ValueColumn;
import org.jdbi.v3.sqlobject.customizer.Bind;
import org.jdbi.v3.sqlobject.statement.GetGeneratedKeys;
import org.jdbi.v3.sqlobject.statement.SqlBatch;
import ru.yandex.mail.cerberus.GroupId;
import ru.yandex.mail.cerberus.GroupType;
import ru.yandex.mail.cerberus.asyncdb.util.OneToMany;
import ru.yandex.mail.micronaut.common.IterableStream;
import ru.yandex.mail.cerberus.RoleId;
import ru.yandex.mail.cerberus.Uid;
import ru.yandex.mail.cerberus.asyncdb.CrudRepository;
import ru.yandex.mail.cerberus.asyncdb.annotations.ConfigureCrudRepository;

import java.util.Collection;
import java.util.Map;
import java.util.Set;

import static java.util.Collections.singleton;
import static java.util.Collections.singletonList;

/**
 * Read-write repository for users stored in {@code cerberus.users}.
 *
 * <p>Besides the CRUD operations inherited from {@link CrudRepository} (keyed by {@link Uid}) and the
 * read operations inherited from {@link RoUserRepository}, this repository manages two link tables:
 * <ul>
 *   <li>{@code cerberus.user_groups} - group membership, tagged with a {@link GroupType};</li>
 *   <li>{@code cerberus.user_roles} - role assignment.</li>
 * </ul>
 *
 * <p>The abstract {@code @SqlBatch} methods take "parallel" iterables. Per Jdbi's {@code @SqlBatch}
 * contract, iterable arguments are advanced together and each step produces one statement of the batch.
 * Non-iterable arguments (such as {@code groupType}) are reused for every step. The {@code default}
 * overloads build these parallel iterables either from a map (keys = users, values = ids) or from a
 * single user.
 */
@ConfigureCrudRepository(table = "cerberus.users")
public interface UserRepository extends RoUserRepository, CrudRepository<Uid, UserEntity> {
    /**
     * Batch-inserts group memberships into {@code cerberus.user_groups}.
     *
     * <p>For each batch step, every id in the current inner iterable of {@code groupIds} becomes one
     * {@code (uid, group_id, group_type)} row. {@code ON CONFLICT DO NOTHING} silently skips rows that hit
     * a constraint conflict, e.g. an already-existing membership (presumably enforced by a unique key on
     * the table - verify against the schema).
     *
     * @param uid       users to add, paired positionally with {@code groupIds}
     * @param groupType group type written to every inserted row
     * @param groupIds  for each user, the ids of the groups to add that user to
     */
    @SqlBatch("INSERT INTO cerberus.user_groups (uid, group_id, group_type)\n"
            + "SELECT :uid, id, :groupType\n"
            + "FROM unnest(ARRAY[:groupIds]) as ids (id)\n"
            + "ON CONFLICT DO NOTHING")
    void addToGroup(Iterable<Uid> uid, GroupType groupType, @Bind Iterable<? extends Iterable<GroupId>> groupIds);

    /**
     * Adds several users to groups of one type in a single batch.
     *
     * @param groupType      group type written to every inserted row
     * @param usersGroupsIds for each user, the ids of the groups to add that user to
     */
    default void addToGroup(GroupType groupType, Map<Uid, ? extends Iterable<GroupId>> usersGroupsIds) {
        // Keys and values are both derived from the same map, so they line up positionally
        // as long as the map is not modified while the batch is being built.
        val uids = new IterableStream<>(usersGroupsIds, EntryStream::keys);
        val groupIds = new IterableStream<>(usersGroupsIds, EntryStream::values);
        addToGroup(uids, groupType, groupIds);
    }

    /**
     * Adds one user to several groups of the given type.
     *
     * @param uid       the user to add
     * @param groupType group type written to every inserted row
     * @param groupIds  ids of the groups to add the user to
     */
    default void addToGroup(Uid uid, GroupType groupType, Collection<GroupId> groupIds) {
        addToGroup(singletonList(uid), groupType, singletonList(groupIds));
    }

    /**
     * Adds one user to one group of the given type.
     *
     * @param uid       the user to add
     * @param groupId   the group to add the user to
     * @param groupType group type written to the inserted row
     */
    default void addToGroup(Uid uid, GroupId groupId, GroupType groupType) {
        addToGroup(uid, groupType, singletonList(groupId));
    }

    /**
     * Batch-deletes group memberships of the given type from {@code cerberus.user_groups}.
     *
     * @param uid       users to remove, paired positionally with {@code groupIds}
     * @param groupType only memberships with this group type are deleted
     * @param groupIds  for each user, the ids of the groups to remove that user from
     * @return the memberships actually deleted (taken from {@code RETURNING}), keyed by
     *         {@code group_id} with the removed {@code uid}s as values; non-existent memberships
     *         do not appear
     */
    @ValueColumn("uid")
    @KeyColumn("group_id")
    // GetGeneratedKeys.value() is a String[]: list each column on its own instead of one
    // comma-joined name, which a driver appending its own RETURNING would quote as a single column.
    @GetGeneratedKeys({"uid", "group_id"})
    @SqlBatch("DELETE FROM cerberus.user_groups\n"
            + "WHERE uid = :uid AND group_type = :groupType AND group_id = ANY(:groupIds)\n"
            + "RETURNING uid, group_id")
    OneToMany<GroupId, Uid> removeFromGroup(Iterable<Uid> uid, GroupType groupType, @Bind Iterable<? extends Iterable<GroupId>> groupIds);

    /**
     * Removes several users from groups of one type in a single batch.
     *
     * @param groupType   only memberships with this group type are deleted
     * @param usersGroups for each user, the ids of the groups to remove that user from
     * @return for every group that lost at least one member, the set of removed users
     */
    default Map<GroupId, Set<Uid>> removeFromGroup(GroupType groupType, Map<Uid, ? extends Iterable<GroupId>> usersGroups) {
        // Keys and values come from the same map, so they line up positionally.
        val uids = new IterableStream<>(usersGroups, EntryStream::keys);
        val groupIds = new IterableStream<>(usersGroups, EntryStream::values);
        return removeFromGroup(uids, groupType, groupIds).getSetMapping();
    }

    /**
     * Removes one user from several groups of the given type.
     *
     * @param uid       the user to remove
     * @param groupType only memberships with this group type are deleted
     * @param groupIds  ids of the groups to remove the user from
     * @return ids of the groups the user was actually removed from
     */
    default Set<GroupId> removeFromGroup(Uid uid, GroupType groupType, Iterable<GroupId> groupIds) {
        return removeFromGroup(singletonList(uid), groupType, singletonList(groupIds))
            .getMapping()
            .keySet();
    }

    /**
     * Removes one user from one group of the given type.
     *
     * @param uid       the user to remove
     * @param groupId   the group to remove the user from
     * @param groupType only a membership with this group type is deleted
     * @return {@code true} if a membership row was deleted
     */
    default boolean removeFromGroup(Uid uid, GroupId groupId, GroupType groupType) {
        return removeFromGroup(uid, groupType, singletonList(groupId))
            .contains(groupId);
    }

    /**
     * Batch-inserts role assignments into {@code cerberus.user_roles}.
     *
     * <p>For each batch step, every id in the current inner iterable of {@code roleIds} becomes one
     * {@code (uid, role_id)} row. Rows that hit a constraint conflict (e.g. an already-assigned role) are
     * skipped by {@code ON CONFLICT DO NOTHING}.
     *
     * @param uid     users to assign roles to, paired positionally with {@code roleIds}
     * @param roleIds for each user, the ids of the roles to assign
     */
    @SqlBatch("INSERT INTO cerberus.user_roles (uid, role_id)\n"
            + "SELECT :uid, id\n"
            + "FROM unnest(ARRAY[:roleIds]) as ids (id)\n"
            + "ON CONFLICT DO NOTHING")
    void attachRoles(Iterable<Uid> uid, @Bind Iterable<? extends Iterable<RoleId>> roleIds);

    /**
     * Assigns roles to several users in a single batch.
     *
     * <p>Accepts any iterable of role ids per user (a {@code Map<Uid, Set<RoleId>>} still works),
     * matching the map overloads of {@code addToGroup}/{@code removeFromGroup}.
     *
     * @param usersRoles for each user, the ids of the roles to assign
     */
    default void attachRoles(Map<Uid, ? extends Iterable<RoleId>> usersRoles) {
        // Keys and values come from the same map, so they line up positionally.
        val uids = new IterableStream<>(usersRoles, EntryStream::keys);
        val roles = new IterableStream<>(usersRoles, EntryStream::values);
        attachRoles(uids, roles);
    }

    /**
     * Assigns several roles to one user.
     *
     * @param uid     the user to assign roles to
     * @param roleIds ids of the roles to assign
     */
    default void attachRoles(Uid uid, Set<RoleId> roleIds) {
        attachRoles(singletonList(uid), singletonList(roleIds));
    }

    /**
     * Assigns one role to one user.
     *
     * @param uid    the user to assign the role to
     * @param roleId id of the role to assign
     */
    default void attachRole(Uid uid, RoleId roleId) {
        attachRoles(uid, singleton(roleId));
    }

    /**
     * Batch-deletes role assignments from {@code cerberus.user_roles}.
     *
     * @param uid     users to detach roles from, paired positionally with {@code roleIds}
     * @param roleIds for each user, the ids of the roles to detach
     */
    @SqlBatch("DELETE FROM cerberus.user_roles\n"
            + "WHERE uid = :uid AND role_id = ANY(:roleIds)")
    void detachRoles(Iterable<Uid> uid, @Bind Iterable<? extends Iterable<RoleId>> roleIds);

    /**
     * Detaches roles from several users in a single batch.
     *
     * <p>Accepts any iterable of role ids per user (a {@code Map<Uid, Set<RoleId>>} still works).
     *
     * @param usersRoles for each user, the ids of the roles to detach
     */
    default void detachRoles(Map<Uid, ? extends Iterable<RoleId>> usersRoles) {
        // Keys and values come from the same map, so they line up positionally.
        val uids = new IterableStream<>(usersRoles, EntryStream::keys);
        val roles = new IterableStream<>(usersRoles, EntryStream::values);
        detachRoles(uids, roles);
    }

    /**
     * Detaches several roles from one user.
     *
     * @param uid     the user to detach roles from
     * @param roleIds ids of the roles to detach
     */
    default void detachRoles(Uid uid, Set<RoleId> roleIds) {
        detachRoles(singletonList(uid), singletonList(roleIds));
    }

    /**
     * Detaches one role from one user.
     *
     * @param uid    the user to detach the role from
     * @param roleId id of the role to detach
     */
    default void detachRole(Uid uid, RoleId roleId) {
        detachRoles(uid, singleton(roleId));
    }
}
