package ru.yandex.infra.auth.tasks;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import org.apache.http.HttpStatus;
import org.slf4j.Logger;

import ru.yandex.infra.auth.Role;
import ru.yandex.infra.auth.RoleSubject;
import ru.yandex.infra.auth.idm.api.IdmApiService;
import ru.yandex.infra.auth.idm.api.IdmApiServiceError;
import ru.yandex.infra.auth.nanny.NannyRole;
import ru.yandex.infra.auth.nanny.NannyService;
import ru.yandex.infra.auth.yp.YpObjectsTreeGetterError;
import ru.yandex.infra.auth.yp.YpService;
import ru.yandex.infra.auth.yp.YpServiceReadOnlyImpl;
import ru.yandex.infra.controller.RepeatedTask;
import ru.yandex.infra.controller.concurrent.LeaderService;
import ru.yandex.infra.controller.metrics.GaugeRegistry;
import ru.yandex.infra.controller.metrics.GolovanableGauge;
import ru.yandex.infra.controller.metrics.NamespacedGaugeRegistry;
import ru.yandex.infra.controller.util.ExceptionUtils;
import ru.yandex.infra.controller.util.ExitUtils;
import ru.yandex.misc.lang.StringUtils;

import static java.util.Collections.emptySet;
import static org.slf4j.LoggerFactory.getLogger;
import static ru.yandex.infra.auth.yp.YpGroupsHelper.SYSTEM_IDM_LABEL_VALUE;
import static ru.yandex.infra.auth.yp.YpGroupsHelper.SYSTEM_LABEL_KEY;

/**
 * Keeps the IDM role-node tree in sync with the roles derived from YP (and, optionally, Nanny).
 *
 * <p>On {@link #start()} the updater performs a one-time {@link #init()} (YP group member sync, super-user
 * grants, initial load of IDM role nodes) and then schedules {@link #mainLoop()} on a dedicated
 * single-thread executor at {@code roleUpdateRate}. Each cycle computes the difference between the roles
 * reported by {@link YpService#getRoles()} and the cached IDM role nodes, adds missing nodes in batches,
 * removes stale ones, and optionally cleans up stage ACLs and YP groups for removed nodes.
 *
 * <p>Progress is reported through a {@link Watchdog} and a set of gauges registered in the supplied
 * {@link GaugeRegistry}.
 */
public final class RolesUpdater {
    private static final Logger LOG = getLogger(RolesUpdater.class);
    private static final Map<String, String> IDM_LABELS = ImmutableMap.of(SYSTEM_LABEL_KEY, SYSTEM_IDM_LABEL_VALUE);
    private static final String WATCHDOG_ID = "roleUpdater";

    static final String METRIC_YP_OBJECTS_LOAD_TIME_MS = "yp_objects_load_time_ms";
    static final String METRIC_STARTUP_TIME_MS = "startup_time_ms";
    static final String METRIC_TOTAL_ROLES = "total_roles";
    static final String METRIC_NEW_ROLES_COUNT = "new_roles_count";
    static final String METRIC_GARBAGE_ROLES_COUNT = "garbage_roles_count";
    static final String METRIC_MOVED_ROLES_COUNT = "moved_roles_count";
    static final String METRIC_ROLES_ADDED = "roles_added";
    static final String METRIC_ROLES_REMOVED = "roles_removed";
    static final String METRIC_ROLES_RELOCATED = "roles_relocated";
    static final String METRIC_QUEUE_SIZE_TO_ADD_NODE = "queue_size_to_add_node";
    static final String METRIC_QUEUE_SIZE_TO_REMOVE_NODE = "queue_size_to_remove_node";

    // Gauge values: written by init()/the sync thread, read by the metrics registry.
    private volatile Long metricLastYpObjectsLoadTimeMilliseconds;
    private volatile Long metricStartupTimeMilliseconds;
    private volatile Integer metricTotalRoles = null;
    private volatile Integer metricNewRolesCount = null;
    private volatile Integer metricGarbageRolesCount = null;
    private volatile Integer metricMovedRolesCount = null;
    private final AtomicLong metricRolesAdded = new AtomicLong();
    private final AtomicLong metricRolesRemoved = new AtomicLong();
    private final AtomicLong metricRolesRelocated = new AtomicLong();
    private final AtomicInteger metricQueueSizeToAddNode = new AtomicInteger();
    private final AtomicInteger metricQueueSizeToRemoveNode = new AtomicInteger();

    private final Map<String, String> deployLabels;
    private final ScheduledExecutorService executor;
    private final IdmApiService idmApiService;
    private final YpService ypService;
    private final LeaderService leaderService;
    private final Duration roleUpdateRate;
    private final List<RoleSubject> superUserSubjects;
    private final Watchdog watchdog;
    private final boolean ypGroupsGCEnabled;
    private final boolean ypStageAclCleanupEnabled;
    private final String systemName;
    private final NannyService nannyService;
    private final RepeatedTask mainLoopTask;
    private final boolean processNannyServicesWithoutProject;

    private boolean ypGroupsSynced = false;
    // Local cache of the role nodes known to IDM; loaded in init() and updated after every sync cycle.
    private Set<Role> currentRoleNodes;

    /**
     * Creates the updater, registers it with the watchdog and registers all gauges.
     *
     * @param idmApiService                     client used to read/add/remove IDM role nodes
     * @param ypService                         source of the desired roles and owner of YP groups/stage ACLs
     * @param nannyService                      Nanny integration used for nanny-service roles
     * @param leaderService                     leadership/processing-permission gate
     * @param roleUpdateRate                    interval between sync cycles; {@link Duration#ZERO} disables the updater
     * @param superUserSubjects                 subjects granted on {@link #init()}
     * @param watchdogTimeout                   watchdog timeout (the main loop task timeout is twice this value)
     * @param watchdog                          watchdog touched while the updater makes progress
     * @param ypGroupsGCEnabled                 whether YP groups of removed role nodes are deleted
     * @param ypStageAclCleanupEnabled          whether removed roles are cleaned from stage ACLs
     * @param processNannyServicesWithoutProject whether nanny roles without a project are eligible for YP group GC
     * @param systemName                        system label value and YP group name component
     * @param gaugeRegistry                     registry receiving the updater metrics
     */
    public RolesUpdater(IdmApiService idmApiService,
            YpService ypService,
            NannyService nannyService,
            LeaderService leaderService,
            Duration roleUpdateRate,
            List<RoleSubject> superUserSubjects,
            Duration watchdogTimeout,
            Watchdog watchdog,
            boolean ypGroupsGCEnabled,
            boolean ypStageAclCleanupEnabled,
            boolean processNannyServicesWithoutProject,
            String systemName,
            GaugeRegistry gaugeRegistry) {
        this.idmApiService = idmApiService;
        this.ypService = ypService;
        this.leaderService = leaderService;
        this.roleUpdateRate = roleUpdateRate;
        this.executor = Executors.newSingleThreadScheduledExecutor(runnable -> new Thread(runnable, "role_nodes_sync"));
        this.superUserSubjects = superUserSubjects;
        this.watchdog = watchdog;
        this.ypGroupsGCEnabled = ypGroupsGCEnabled;
        this.ypStageAclCleanupEnabled = ypStageAclCleanupEnabled;
        this.systemName = systemName;
        this.deployLabels = ImmutableMap.of(SYSTEM_LABEL_KEY, systemName);
        this.nannyService = nannyService;
        this.processNannyServicesWithoutProject = processNannyServicesWithoutProject;

        watchdog.register(WATCHDOG_ID, watchdogTimeout.toMillis());

        gaugeRegistry.add(METRIC_TOTAL_ROLES, new GolovanableGauge<>(() -> metricTotalRoles, "axxx"));

        GaugeRegistry nodesRegistry = new NamespacedGaugeRegistry(gaugeRegistry, "nodes");

        nodesRegistry.add(METRIC_YP_OBJECTS_LOAD_TIME_MS, new GolovanableGauge<>(() -> metricLastYpObjectsLoadTimeMilliseconds, "axxx"));
        nodesRegistry.add(METRIC_STARTUP_TIME_MS, new GolovanableGauge<>(() -> metricStartupTimeMilliseconds, "axxx"));
        nodesRegistry.add(METRIC_NEW_ROLES_COUNT, new GolovanableGauge<>(() -> metricNewRolesCount, "axxx"));
        nodesRegistry.add(METRIC_GARBAGE_ROLES_COUNT, new GolovanableGauge<>(() -> metricGarbageRolesCount, "axxx"));
        nodesRegistry.add(METRIC_MOVED_ROLES_COUNT, new GolovanableGauge<>(() -> metricMovedRolesCount, "axxx"));
        nodesRegistry.add(METRIC_ROLES_ADDED, new GolovanableGauge<>(metricRolesAdded::get, "dmmm"));
        nodesRegistry.add(METRIC_ROLES_REMOVED, new GolovanableGauge<>(metricRolesRemoved::get, "dmmm"));
        nodesRegistry.add(METRIC_ROLES_RELOCATED, new GolovanableGauge<>(metricRolesRelocated::get, "dmmm"));
        nodesRegistry.add(METRIC_QUEUE_SIZE_TO_ADD_NODE, new GolovanableGauge<>(metricQueueSizeToAddNode::get, "axxx"));
        nodesRegistry.add(METRIC_QUEUE_SIZE_TO_REMOVE_NODE, new GolovanableGauge<>(metricQueueSizeToRemoveNode::get, "axxx"));

        mainLoopTask = new RepeatedTask(this::mainLoop,
                roleUpdateRate,
                watchdogTimeout.multipliedBy(2),
                executor,
                Optional.of(nodesRegistry),
                LOG,
                true);
    }

    /**
     * Initializes the updater and starts the periodic sync task, unless the update rate is zero,
     * in which case only a warning is logged.
     */
    public void start() {
        if (roleUpdateRate.isZero()) {
            LOG.warn("Roles updater is disabled by config option 'main.roles_update_rate' = 0");
        } else {
            init();
            mainLoopTask.start();
        }
    }

    /**
     * Stops the periodic sync task and initiates shutdown of its executor (without waiting for termination).
     */
    public void shutdown() {
        mainLoopTask.stop();
        executor.shutdown();
    }

    /**
     * One-time initialization.
     *
     * <p>Syncs YP group members to all slave clusters (only once per instance). If the leader service
     * does not yet allow processing, grants the configured super-user subjects (failures are logged and
     * ignored), loads the current IDM role nodes into the local cache and then allows processing.
     * Any other failure requests process exit with {@link ExitUtils#INIT_FAILURE}. The elapsed time is
     * always published as the {@code startup_time_ms} metric.
     */
    public void init() {
        long startTimeMillis = System.currentTimeMillis();
        try {
            if (!ypGroupsSynced) {
                watchdog.touch(WATCHDOG_ID);
                syncYpGroups();
                ypGroupsSynced = true;
            }

            if (!leaderService.isProcessingAllowed()) {

                watchdog.touch(WATCHDOG_ID);
                for (RoleSubject superUserSubject : superUserSubjects) {
                    try {
                        idmApiService.addRoleSubject(superUserSubject);
                    } catch (IdmApiServiceError ex) {
                        // Best effort: a failed grant must not block startup, but keep the cause for diagnostics.
                        LOG.warn("Failed to grant: {}", superUserSubject, ex);
                    }
                }

                // NOTE(review): watchdog is suspended for the IDM load, presumably because it can exceed the
                // watchdog timeout; mainLoop() touches it again on every cycle.
                watchdog.suspend(WATCHDOG_ID);
                LOG.info("Starting loading of IDM role nodes");
                long start = System.currentTimeMillis();
                currentRoleNodes = idmApiService.getRoleNodes();
                final long rolesLoadTimeMs = System.currentTimeMillis() - start;
                LOG.info("Loaded {} IDM role nodes in {} ms ({} nodes/s)",
                        currentRoleNodes.size(),
                        rolesLoadTimeMs,
                        // +1 avoids division by zero for very fast loads
                        1000L * currentRoleNodes.size() / (rolesLoadTimeMs + 1)
                );
                leaderService.allowProcessing();
            }
        } catch (Throwable exception) {
            LOG.error("Failed to initialize roles updater. Forcing shutdown.", exception);
            ExitUtils.gracefulExit(ExitUtils.INIT_FAILURE);
        } finally {
            metricStartupTimeMilliseconds = System.currentTimeMillis() - startTimeMillis;
        }
    }

    /**
     * Single sync cycle executed by the repeated task.
     *
     * <p>Requests process exit if leadership was lost or {@link #init()} did not allow processing;
     * otherwise touches the watchdog and synchronizes role nodes.
     *
     * @throws RuntimeException wrapping any exception raised during the cycle
     */
    public void mainLoop() {
        try {
            if (!leaderService.isLeader()) {
                LOG.error("Leadership lock lost, performing forced shutdown");
                ExitUtils.gracefulExit(ExitUtils.LEADERSHIP_LOST);
                // Never mutate IDM/YP without the leadership lock, even if the exit request returns.
                return;
            }

            if (!leaderService.isProcessingAllowed()) {
                LOG.error("Roles updater was not initialized. Forcing shutdown.");
                ExitUtils.gracefulExit(ExitUtils.INIT_FAILURE);
                // The role node cache is not loaded; processing would operate on invalid state.
                return;
            }

            // Start watchdog
            watchdog.touch(WATCHDOG_ID);
            updateRoleNodes();
        } catch (Exception exception) {
            throw new RuntimeException("Role nodes sync cycle failed", exception);
        }
    }

    /**
     * Indexes roles by their unique id, skipping roles whose unique id is blank.
     */
    private Map<String, Role> getMapOfRolesByUuid(Collection<Role> roles) {
        return roles.stream()
                .filter(role -> StringUtils.isNotBlank(role.getUniqueId()))
                .collect(Collectors.toMap(Role::getUniqueId, role -> role));
    }

    /**
     * Computes the difference between YP roles and cached IDM role nodes and applies it to IDM.
     *
     * <p>Roles present in both the "to remove" and "to add" sets with the same unique id are treated as
     * moved (e.g. into another project): they are not removed separately but passed along with the
     * additions so IDM can relocate them. The local cache and metrics are updated with what actually
     * succeeded.
     */
    private void updateRoleNodes() throws YpObjectsTreeGetterError {

        long startTimeMillis = System.currentTimeMillis();
        Set<Role> roleNodes = ypService.getRoles();
        metricLastYpObjectsLoadTimeMilliseconds = System.currentTimeMillis() - startTimeMillis;

        Set<Role> idmNodes = currentRoleNodes;
        //Ignoring both: new and absent nanny role nodes if nanny roles processing is disabled by config
        if (!nannyService.isNodesSyncEnabled()) {
            idmNodes = currentRoleNodes.stream()
                    .filter(role -> !role.isNannyRole())
                    .collect(Collectors.toSet());
            roleNodes = roleNodes.stream()
                    .filter(role -> !role.isNannyRole())
                    .collect(Collectors.toSet());
        }

        HashSet<Role> roleNodesToRemove = new HashSet<>(Sets.difference(idmNodes, roleNodes));
        HashSet<Role> roleNodesToAdd = new HashSet<>(Sets.difference(roleNodes, idmNodes));

        metricNewRolesCount = roleNodesToAdd.size();
        metricGarbageRolesCount = roleNodesToRemove.size();

        appendRoleSubjectsIntoNewNannyRoles(roleNodesToAdd, roleNodesToRemove);

        Map<String, Role> rolesToAddByUuid = getMapOfRolesByUuid(roleNodesToAdd);
        Map<String, Role> rolesWithChangedProjectByUuid = roleNodesToRemove.stream()
                .filter(role -> rolesToAddByUuid.containsKey(role.getUniqueId()))
                .collect(Collectors.toMap(Role::getUniqueId, role -> role));
        metricMovedRolesCount = rolesWithChangedProjectByUuid.size();
        //We should remove roles in one transaction with adding the same roles at different tree location (other project)
        rolesWithChangedProjectByUuid.values().forEach(roleNodesToRemove::remove);

        if (currentRoleNodes.isEmpty()) {
            // Add ROOT role
            roleNodesToAdd.add(Role.empty());
        }

        Set<Role> removedRoles = new HashSet<>();
        Set<Role> addedRoles = new HashSet<>();

        if (!roleNodesToAdd.isEmpty()) {
            LOG.info("Role nodes to add to IDM: {}", roleNodesToAdd.size());
            addIdmRoleNodes(roleNodesToAdd, rolesWithChangedProjectByUuid, addedRoles, removedRoles);
        }

        if (!roleNodesToRemove.isEmpty()) {
            LOG.info("Role nodes to remove from IDM: {}", roleNodesToRemove.size());
            Set<Role> removed = removeIdmRoleNodes(roleNodesToRemove);
            removedRoles.addAll(removed);
        }

        currentRoleNodes.removeAll(removedRoles);
        metricRolesRemoved.addAndGet(removedRoles.size());

        currentRoleNodes.addAll(addedRoles);
        metricRolesAdded.addAndGet(addedRoles.size());

        // Must run after the cache update: stages still present in currentRoleNodes are the ones that moved.
        cleanStageAclsAfterMovingIntoAnotherProject(removedRoles);
        removeYpGroupsForRemovedRoleNodes(removedRoles, addedRoles);

        metricTotalRoles = currentRoleNodes.size();
    }

    /**
     * Enriches {@code roleNodesToAdd} with data from Nanny for nanny-service root nodes.
     *
     * <p>For services whose root node appears in both sets (same unique id, i.e. moved between projects)
     * an asynchronous auth_attrs sync from YP into Nanny is scheduled. For genuinely new services all of
     * the service's nodes are first removed from {@code roleNodesToAdd}; the roles with subjects are then
     * fetched from Nanny, synced into YP and re-added together with the service root node. If any step
     * fails, the service's nodes stay removed for the current cycle and are retried on the next one.
     *
     * @param roleNodesToAdd    mutable set of nodes to add; modified in place
     * @param roleNodesToRemove nodes scheduled for removal (read only)
     */
    private void appendRoleSubjectsIntoNewNannyRoles(Set<Role> roleNodesToAdd, Collection<Role> roleNodesToRemove) {
        //We should download role subjects (users/groups) for all new nanny service roles (skip services, moved into another project)
        Map<String, Role> removedNannyRoleNodesByUuid = roleNodesToRemove.stream()
                .filter(Role::isNannyServiceRootNode)
                .collect(Collectors.toMap(Role::getUniqueId, role -> role));

        List<Role> newServices = new ArrayList<>();
        List<Role> oldServicesMovedIntoAnotherProject = new ArrayList<>();

        roleNodesToAdd.stream()
                .filter(Role::isNannyServiceRootNode)
                .forEach(role -> {
                    if (removedNannyRoleNodesByUuid.containsKey(role.getUniqueId())) { //removed and created at the same time
                        oldServicesMovedIntoAnotherProject.add(role);
                    } else {
                        newServices.add(role);
                    }
                });

        oldServicesMovedIntoAnotherProject.forEach(role -> {
            String previousProjectId = removedNannyRoleNodesByUuid.get(role.getUniqueId()).getProjectId().orElse(null);
            String serviceId = role.getNannyServiceId().orElseThrow();
            String commitMessage = String.format("project_id was changed from '%s' into '%s'",
                    previousProjectId,
                    role.getProjectId().orElse(null));
            LOG.info("Nanny service '{}' {}. Scheduling auth_attrs sync from YP into Nanny.", serviceId, commitMessage);
            nannyService.syncNannyServiceAuthAttrsAsync(serviceId, commitMessage);
        });

        newServices.forEach(nannyServiceNode -> {
            String serviceId = nannyServiceNode.getNannyServiceId().orElseThrow();
            try {
                roleNodesToAdd.removeIf(role -> role.getNannyServiceId().orElse("").equals(serviceId));

                LOG.info("Importing initial roles for nanny service '{}' from Nanny into YP", serviceId);
                String projectId = nannyServiceNode.getProjectId().orElse(null);
                Set<NannyRole> rolesFromNannyAuthAttrs = nannyService.getRolesWithSubjects(projectId, serviceId, nannyServiceNode.getUniqueId());
                rolesFromNannyAuthAttrs.forEach(nannyService::syncNannyServiceRoleIntoYP);

                roleNodesToAdd.add(new Role(Role.getRolePathForNannyService(projectId, serviceId), "", "", nannyServiceNode.getUniqueId()));
                roleNodesToAdd.addAll(rolesFromNannyAuthAttrs);

                if (projectId != null) {
                    String commitMessage = String.format("Initial sync of project '%s' roles", projectId);
                    LOG.info("Nanny service '{}' {} from YP into Nanny", serviceId, commitMessage);
                    nannyService.syncNannyServiceAuthAttrsAsync(serviceId, commitMessage);
                }
            } catch (Exception exception) {
                //Will try to add roles on the next sync cycle again.
                //Skip them on current cycle.
                if (ExceptionUtils.tryExtractHttpErrorCode(exception).orElse(-1) == HttpStatus.SC_NOT_FOUND) {
                    LOG.error("Initial nanny service '{}' sync failed: nanny_service object is present in YP, but missed in Nanny", serviceId);
                } else {
                    LOG.error("Initial nanny service '{}' sync failed", serviceId, exception);
                }
            }
        });
    }

    /**
     * Removes roles from the ACLs of stages that are still known after this cycle.
     *
     * <p>A removed role whose stage id is still present in the cache means the stage now lives elsewhere
     * in the tree (e.g. another project), so the stale roles are dropped from its ACL.
     * No-op when the cleanup is disabled or nothing was removed.
     */
    private void cleanStageAclsAfterMovingIntoAnotherProject(Collection<Role> removedNodes) {
        if (!ypStageAclCleanupEnabled || removedNodes.isEmpty()) {
            return;
        }

        watchdog.touch(WATCHDOG_ID);

        Set<String> allStages = currentRoleNodes.stream()
                .map(role -> role.getStageId().orElse(null))
                .filter(Objects::nonNull)
                .collect(Collectors.toSet());

        Map<String, Set<Role>> removedRolesByStageId = removedNodes.stream()
                .filter(role -> role.getStageId().map(allStages::contains).orElse(false))
                .collect(Collectors.groupingBy(role -> role.getStageId().orElseThrow(), Collectors.toSet()));

        removedRolesByStageId.forEach(ypService::removeRolesFromStageAcl);

        watchdog.touch(WATCHDOG_ID);
    }

    /**
     * Deletes YP groups that belonged to removed role nodes.
     *
     * <p>Groups that also correspond to an added role in the same cycle are kept. Nanny roles without a
     * project are skipped unless {@code processNannyServicesWithoutProject} is enabled. No-op when YP
     * group GC is disabled or nothing was removed.
     */
    private void removeYpGroupsForRemovedRoleNodes(Collection<Role> removedRoles, Collection<Role> addedRoles) {

        if (!ypGroupsGCEnabled || removedRoles.isEmpty()) {
            return;
        }

        watchdog.touch(WATCHDOG_ID);

        Set<String> groupIdsToAdd = addedRoles.stream()
                .map(role -> YpServiceReadOnlyImpl.getYpGroupName(role, systemName))
                .filter(Objects::nonNull)
                .collect(Collectors.toSet());

        Set<String> groupIdsToRemove = removedRoles.stream()
                .filter(role -> !(role.isNannyRole() && role.getProjectId().isEmpty() && !processNannyServicesWithoutProject))
                .map(role -> YpServiceReadOnlyImpl.getYpGroupName(role, systemName))
                .filter(groupId -> groupId != null && !groupIdsToAdd.contains(groupId))
                .collect(Collectors.toSet());

        LOG.info("Removing YP groups for removed role nodes: {}", removedRoles);
        groupIdsToRemove.forEach(groupId -> {
            ypService.removeGroup(groupId);
            watchdog.touch(WATCHDOG_ID);
        });

        watchdog.touch(WATCHDOG_ID);
    }

    /**
     * Syncs group members to all slave clusters for both the deploy-system and the IDM labels.
     */
    private void syncYpGroups() {
        LOG.info("YP groups sync started");
        long start = System.currentTimeMillis();
        ypService.syncGroupMembersToAllSlaveClusters(deployLabels);
        ypService.syncGroupMembersToAllSlaveClusters(IDM_LABELS);
        LOG.info("YP groups sync took {} ms", System.currentTimeMillis() - start);
    }

    /**
     * Sends nodes to IDM in batches (see {@link #createBatches(Collection)}).
     *
     * <p>For every successfully sent batch, its roles are recorded in {@code addedRoles}, and the old
     * location of any moved role (looked up by unique id) is recorded in {@code removedRoles}.
     * Failed batches are logged and skipped.
     */
    private void addIdmRoleNodes(Collection<Role> nodesToAdd,
                                 Map<String, Role> rolesWithChangedProjectByUuid,
                                 Set<Role> addedRoles,
                                 Set<Role> removedRoles) {
        if (nodesToAdd.isEmpty()) {
            return;
        }

        List<List<Role>> batches = createBatches(nodesToAdd);
        metricQueueSizeToAddNode.set(batches.size());

        int addedBeforeTransaction = addedRoles.size();
        int removedBeforeTransaction = removedRoles.size();

        for (List<Role> batch : batches) {
            if (sendBatchToIdm(batch, rolesWithChangedProjectByUuid)) {
                addedRoles.addAll(batch);
                batch.stream()
                        .map(role -> rolesWithChangedProjectByUuid.get(role.getUniqueId()))
                        .filter(Objects::nonNull)
                        .forEach(removedRoles::add);
            }
            metricQueueSizeToAddNode.decrementAndGet();
        }

        // Each relocated role contributes one entry to both sets, so subtract it from the "added" count.
        int movedIntoAnotherProject = removedRoles.size() - removedBeforeTransaction;
        int addedCount = addedRoles.size() - addedBeforeTransaction - movedIntoAnotherProject;

        if (addedCount != 0) {
            LOG.info("Added {} new IDM role nodes", addedCount);
        }

        if (movedIntoAnotherProject != 0) {
            metricRolesRelocated.addAndGet(movedIntoAnotherProject);
            LOG.info("Relocated {} IDM role nodes", movedIntoAnotherProject);
        }
    }

    /**
     * Splits nodes into ordered batches for IDM.
     *
     * <p>Nodes are grouped by {@link Role#getLevelIgnoringLeaf()} and processed level by level from 0 up
     * to {@link Role#MAX_DEPLOY_ROLE_DEPTH} (exclusive), so parents are sent before children. Within a
     * level, nodes are visited in their natural ({@link TreeSet}) order and a new batch starts whenever
     * a node does not {@code contain} the first node of the current batch. Nodes at levels outside that
     * range are not batched; a warning with their count is logged.
     *
     * @param nodesToAdd nodes to split
     * @return batches in the order they must be sent
     */
    @VisibleForTesting
    static List<List<Role>> createBatches(Collection<Role> nodesToAdd) {
        List<List<Role>> batches = new ArrayList<>();

        Map<Integer, TreeSet<Role>> sortedRolesByLevel = nodesToAdd.stream()
                .collect(Collectors.groupingBy(Role::getLevelIgnoringLeaf, Collectors.toCollection(TreeSet::new)));

        for (int depth = 0; depth < Role.MAX_DEPLOY_ROLE_DEPTH; depth++) {
            TreeSet<Role> roles = sortedRolesByLevel.get(depth);
            if (roles != null) {
                List<Role> batch = new ArrayList<>();
                Role parentRole = null;

                for (Role role : roles) {
                    if (parentRole == null || !role.contains(parentRole)) {
                        // Start new batch
                        if (!batch.isEmpty()) {
                            batches.add(batch);
                            batch = new ArrayList<>();
                        }
                        parentRole = role;
                    }
                    batch.add(role);
                }

                // Add last batch
                if (!batch.isEmpty()) {
                    batches.add(batch);
                }

            }
        }

        // Such nodes would otherwise be dropped silently and re-appear in the diff on every cycle.
        int skippedRoles = sortedRolesByLevel.entrySet().stream()
                .filter(entry -> entry.getKey() < 0 || entry.getKey() >= Role.MAX_DEPLOY_ROLE_DEPTH)
                .mapToInt(entry -> entry.getValue().size())
                .sum();
        if (skippedRoles > 0) {
            LOG.warn("Skipped {} role nodes with level outside of [0, {})", skippedRoles, Role.MAX_DEPLOY_ROLE_DEPTH);
        }
        return batches;
    }

    /**
     * Sends one batch to IDM.
     *
     * @return {@code true} if IDM accepted the batch, {@code false} if it failed (the error is logged)
     */
    private boolean sendBatchToIdm(List<Role> batchedRoles, Map<String, Role> rolesWithChangedProjectByUuid) {
        try {
            LOG.info("Adding IDM nodes for: {}", batchedRoles);
            idmApiService.addRoleNodes(batchedRoles, rolesWithChangedProjectByUuid);
            watchdog.touch(WATCHDOG_ID);
        } catch (IdmApiServiceError ex) {
            LOG.error("Failed to add IDM nodes:", ex);
            return false;
        }

        return true;
    }

    /**
     * Removes nodes from IDM in sorted order, skipping the root node.
     *
     * <p>If a node is contained in the most recently removed node, it is not sent to IDM and is counted
     * as removed, presumably because removing the parent already removes its subtree (TODO confirm
     * against IDM semantics). Failures are logged and the node is left out of the result.
     *
     * @return the nodes considered removed
     */
    private Set<Role> removeIdmRoleNodes(Collection<Role> nodesToRemove) {
        metricQueueSizeToRemoveNode.set(nodesToRemove.size());

        if (nodesToRemove.isEmpty()) {
            return emptySet();
        }

        Set<Role> removedRoles = new HashSet<>();
        Optional<Role> parentRole = Optional.empty();

        List<Role> sortedRoles = nodesToRemove.stream().sorted().collect(Collectors.toList());
        for (Role role : sortedRoles) {
            // ignore root role
            if (role.isRoot()) {
                metricQueueSizeToRemoveNode.decrementAndGet();
                continue;
            }
            // filter out nodes if we have removed parent one already
            if (parentRole.isEmpty() || !role.contains(parentRole.get())) {
                LOG.info("Remove role node: {}", role.getExtendedDescription());
                try {
                    idmApiService.removeRoleNode(role);
                    removedRoles.add(role);
                    parentRole = Optional.of(role);
                    watchdog.touch(WATCHDOG_ID);
                } catch (IdmApiServiceError ex) {
                    // Trailing throwable keeps the stack trace without changing the message format.
                    LOG.error("Failed to remove {}: {}", role.getExtendedDescription(), ex.getMessage(), ex);
                }
            } else {
                removedRoles.add(role);
                LOG.info("Skip role node: {}", role.getExtendedDescription());
            }
            metricQueueSizeToRemoveNode.decrementAndGet();
        }

        LOG.info("Removed {} IDM role nodes", removedRoles.size());
        return removedRoles;
    }
}
