package ru.yandex.partner.libs.rbac.userrole;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.jooq.DSLContext;
import org.jooq.InsertValuesStep2;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Repository;
import org.springframework.transaction.annotation.Transactional;

import ru.yandex.partner.core.role.Role;
import ru.yandex.partner.dbschema.partner.tables.records.UserRoleRecord;
import ru.yandex.partner.libs.rbac.role.RoleService;
import ru.yandex.partner.libs.rbac.role.RoleSet;

import static com.google.common.base.Preconditions.checkNotNull;
import static ru.yandex.partner.dbschema.partner.Tables.USER_ROLE;

/**
 * jOOQ-backed {@link UserRoleRepository} that keeps user-to-role links in the {@code USER_ROLE} table.
 */
@Repository
public class UserRoleRepositoryImpl implements UserRoleRepository {
    private static final Logger LOGGER = LoggerFactory.getLogger(UserRoleRepositoryImpl.class);

    private final DSLContext dslContext;
    private final RoleService roleService;

    @Autowired
    public UserRoleRepositoryImpl(DSLContext dslContext, RoleService roleService) {
        this.dslContext = dslContext;
        this.roleService = roleService;
    }

    /**
     * Loads the roles of all given users with a single query.
     *
     * <p>Every requested id is present as a key in the result. Users that have no
     * {@code USER_ROLE} rows are mapped to an empty set. Role ids unknown to {@link RoleSet}
     * are skipped and logged as a warning.
     *
     * @param userIds ids of the users to look up
     * @return map from user id to the roles of that user
     */
    @Override
    public Map<Long, Set<Role>> fetchUserRolesMap(Collection<Long> userIds) {
        HashMap<Long, Set<Role>> rolesMap = Maps.newHashMapWithExpectedSize(userIds.size());
        if (userIds.isEmpty()) {
            // Nothing to look up: avoid a pointless database round trip.
            return rolesMap;
        }

        // Pre-populate so that users without any role rows still appear in the result.
        for (var userId : userIds) {
            rolesMap.put(userId, Set.of());
        }

        this.dslContext.select(USER_ROLE.USER_ID, USER_ROLE.ROLE_ID)
                .from(USER_ROLE)
                .where(USER_ROLE.USER_ID.in(userIds))
                .fetchGroups(USER_ROLE.USER_ID, USER_ROLE.ROLE_ID)
                .forEach((userId, roleIds) ->
                        rolesMap.put(userId.longValue(), mapRoleIdsToRoles(userId, roleIds)));

        return rolesMap;
    }

    /**
     * Loads the roles of a single user.
     *
     * @param userId id of the user
     * @return the user's roles; an empty set when the user has none
     */
    @Override
    public Set<Role> fetchUserRoles(long userId) {
        return fetchUserRolesMap(List.of(userId)).get(userId);
    }

    /**
     * Replaces all roles of a user with the given set, inside one transaction.
     *
     * <p>Existing {@code USER_ROLE} rows of the user are deleted first, and then one row is
     * inserted per role. An empty {@code roles} set therefore removes every role of the user.
     *
     * @param userId id of the user whose roles are replaced
     * @param roles  the complete new set of roles; must not be {@code null}
     * @throws NullPointerException if {@code roles} is {@code null}
     */
    @Transactional
    @Override
    public void updateUserRoleIds(long userId, Set<Role> roles) {
        checkNotNull(roles);

        // NOTE(review): presumably throws when any role is unknown — confirm in RoleService.
        this.roleService.checkRolesExists(roles);

        this.dslContext.deleteFrom(USER_ROLE).where(USER_ROLE.USER_ID.eq(userId)).execute();

        if (roles.isEmpty()) {
            // An INSERT without any VALUES rows is not an "insert nothing" no-op, so skip it.
            // Otherwise clearing all roles would fail and roll back the DELETE above.
            return;
        }

        InsertValuesStep2<UserRoleRecord, Long, Long> insertStatement =
                this.dslContext.insertInto(USER_ROLE, USER_ROLE.USER_ID, USER_ROLE.ROLE_ID);
        for (Role role : roles) {
            insertStatement = insertStatement.values(userId, role.getRoleId());
        }
        insertStatement.execute();
    }

    /**
     * Resolves raw role ids into {@link Role} objects via {@link RoleSet}.
     *
     * @param userId  owner of the roles, used only for the warning message
     * @param roleIds role ids read from {@code USER_ROLE}
     * @return the resolved roles; ids that cannot be resolved are omitted
     */
    private Set<Role> mapRoleIdsToRoles(Long userId, List<Long> roleIds) {
        HashSet<Role> roles = Sets.newHashSetWithExpectedSize(roleIds.size());

        for (Long roleId : roleIds) {
            Optional<Role> role = RoleSet.getRoleById(roleId);
            role.ifPresentOrElse(
                    roles::add,
                    () -> LOGGER.warn("User {} has non-existent role {}", userId, roleId));
        }

        return roles;
    }
}
