package ru.yandex.infra.auth.servlets;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.bolts.collection.Try;
import ru.yandex.infra.auth.Role;
import ru.yandex.infra.auth.yp.YpClients;
import ru.yandex.infra.auth.yp.YpService;
import ru.yandex.infra.auth.yp.YpServiceReadOnlyImpl;
import ru.yandex.infra.controller.dto.StageMeta;
import ru.yandex.infra.controller.util.YsonUtils;
import ru.yandex.infra.controller.yp.Paths;
import ru.yandex.infra.controller.yp.SelectedObjects;
import ru.yandex.inside.yt.kosher.impl.ytree.YTreeProtoUtils;
import ru.yandex.inside.yt.kosher.impl.ytree.builder.YTreeBuilder;
import ru.yandex.yp.YpRawObjectService;
import ru.yandex.yp.client.api.AccessControl;
import ru.yandex.yp.model.YpObjectType;
import ru.yandex.yp.model.YpObjectUpdate;
import ru.yandex.yp.model.YpSetUpdate;
import ru.yandex.yp.model.YpTypedId;

import static java.util.Collections.emptyMap;
import static ru.yandex.infra.controller.util.YpUtils.CommonSelectors.META;

/**
 * Servlet that audits and, on POST, repairs the ACLs of YP stages.
 *
 * <p>For every ACE subject of a stage that starts with the system's YP group prefix, the requested
 * {@link CheckType}s are run. Every failed check is written to the response as one line. A GET request
 * only reports. A POST request with {@code fix=true} additionally rewrites the ACL of every affected
 * stage, dropping each ACE that contains at least one invalid subject.
 *
 * <p>Request parameters:
 * <ul>
 *   <li>{@code stages}: optional comma-separated list of stage ids to process; all stages when absent;</li>
 *   <li>{@code checks}: optional comma-separated list of checks ({@code role}, {@code group},
 *       {@code stage}); all checks when absent;</li>
 *   <li>{@code fix}: POST only; must be {@code true} for any modification to happen.</li>
 * </ul>
 */
public class AclServlet extends HttpServlet {
    private static final Logger LOG = LoggerFactory.getLogger(AclServlet.class);
    private final YpService ypService;

    /**
     * @param ypService source of the roles, the system name and the master-cluster YP clients used by the checks
     */
    public AclServlet(YpService ypService) {
        this.ypService = ypService;
    }

    /** Kinds of checks that can be selected with the {@code checks} request parameter. */
    enum CheckType {
        /** The subject cannot be mapped to a role, or the role is not among the roles loaded from YP. */
        ABSENT_ROLE("role"),
        /** The subject is not among the group ids listed on the master cluster. */
        MISSED_GROUP("group"),
        /** The first two levels of the subject's role do not match the stage's project id and stage id. */
        WRONG_PROJECT_OR_STAGE("stage");

        /** Value that selects this check in the {@code checks} request parameter. */
        private final String parameterName;

        CheckType(String parameterName) {
            this.parameterName = parameterName;
        }
    }

    /**
     * Checks the ACLs and removes invalid ACEs, but only when the request carries {@code fix=true}.
     * Otherwise nothing is checked and an explanatory line is written to the response.
     */
    @Override
    public void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException {
        useTextPlainResponse(resp);
        if (!"true".equals(req.getParameter("fix"))) {
            resp.getWriter().println("Fixing acls is not allowed by request parameter 'fix' = false");
            return;
        }

        checkAcls(req, resp, true);
    }

    /** Checks the ACLs and reports invalid subjects without modifying anything. */
    @Override
    public void doGet(HttpServletRequest req, HttpServletResponse resp) {
        useTextPlainResponse(resp);
        checkAcls(req, resp, false);
    }

    /**
     * Marks the response as UTF-8 plain text.
     *
     * <p>The report echoes request parameters (e.g. an unknown {@code checks} value) verbatim, so a
     * browser must never sniff it and render it as HTML.
     */
    private static void useTextPlainResponse(HttpServletResponse resp) {
        resp.setContentType("text/plain;charset=UTF-8");
        resp.setHeader("X-Content-Type-Options", "nosniff");
    }

    /**
     * Parses the optional {@code checks} parameter.
     *
     * @return all checks when the parameter is absent; the selected checks (iterated in declaration
     *         order, so report lines come out in a stable order); or an empty set after writing an error
     *         line to {@code writer} when some value is unknown
     */
    private Set<CheckType> getCheckTypesFromArguments(HttpServletRequest req, PrintWriter writer) {
        Map<String, CheckType> checkTypes = Arrays.stream(CheckType.values())
                .collect(Collectors.toMap(type -> type.parameterName, Function.identity()));
        String checksParameter = req.getParameter("checks");
        if (checksParameter == null) {
            return EnumSet.allOf(CheckType.class);
        }

        Set<CheckType> checks = EnumSet.noneOf(CheckType.class);
        for (String check : checksParameter.split(",")) {
            CheckType checkType = checkTypes.get(check.trim());
            if (checkType == null) {
                writer.println(String.format("Wrong check type '%s'. Supported types: %s", check, checkTypes.keySet()));
                return Collections.emptySet();
            }
            checks.add(checkType);
        }
        return checks;
    }

    /**
     * Builds the stage filter from the optional {@code stages} parameter.
     *
     * <p>When the parameter is absent, every stage is accepted. When it is present, only the listed
     * (trimmed, non-empty) ids are accepted. A present but empty list such as {@code stages=,} matches
     * nothing rather than silently widening a fixing request to every stage.
     */
    private static Predicate<String> getStageFilterFromArguments(HttpServletRequest req) {
        String stagesParameter = req.getParameter("stages");
        if (stagesParameter == null) {
            return stageId -> true;
        }
        Set<String> stageIds = Arrays.stream(stagesParameter.split(","))
                .map(String::trim)
                .filter(id -> !id.isEmpty())
                .collect(Collectors.toSet());
        return stageIds::contains;
    }

    /** Loads the meta of every stage that was read successfully and whose id passes {@code stageFilter}. */
    private static List<StageMeta> loadStageMetas(YpClients clients, Predicate<String> stageFilter)
            throws InterruptedException, ExecutionException {
        return clients.getStageRepository().selectObjects(META, emptyMap())
                .thenApply(SelectedObjects::getObjects)
                .thenApply(objects -> objects.values().stream()
                        .filter(Try::isSuccess)
                        .map(t -> t.get().getMeta())
                        .filter(meta -> stageFilter.test(meta.getId()))
                        .collect(Collectors.toList()))
                .get();
    }

    /**
     * Runs the selected checks over the selected stages and writes a report to the response.
     *
     * <p>Each failed check is written as one line, and a summary line ends a successful run.
     * <ul>
     *   <li>An unknown {@code checks} value yields status 400.</li>
     *   <li>Any failure while loading data or updating a stage yields status 500 (if the response is not
     *       committed yet) plus a line with the error and the progress made before it.</li>
     * </ul>
     *
     * @param isAllowedToRemove when {@code true}, the ACL of every stage with invalid ACEs is rewritten
     *                          without those ACEs; when {@code false}, nothing is modified
     */
    private void checkAcls(HttpServletRequest req, HttpServletResponse resp, boolean isAllowedToRemove) {
        final long start = System.currentTimeMillis();
        // Declared outside the try block so that a failure can report how far processing got; this
        // matters for POST requests, where some stages may already have been rewritten.
        int stagesWithErrorsInAcl = 0;
        int fixedStages = 0;

        try {
            Predicate<String> stageFilter = getStageFilterFromArguments(req);

            PrintWriter writer = resp.getWriter();

            Set<CheckType> checks = getCheckTypesFromArguments(req, writer);
            if (checks.isEmpty()) {
                // The offending value has already been written by getCheckTypesFromArguments.
                resp.setStatus(HttpServletResponse.SC_BAD_REQUEST);
                return;
            }

            YpClients clients = ypService.getMasterClusterClients();

            LOG.info("Loading all role nodes from yp...");
            Set<Role> allRoles = ypService.getRoles();
            LOG.info("Loaded {} roles (nodes)", allRoles.size());

            Set<String> allGroups = clients.getGroupsClient().listAllIds().get();
            LOG.info("Loaded {} group names from master", allGroups.size());

            List<StageMeta> listOfStageMeta = loadStageMetas(clients, stageFilter);
            LOG.info("Loaded {} stages", listOfStageMeta.size());

            String systemName = ypService.getSystemName();
            // Only subjects starting with this prefix are checked; all other subjects are kept as they are.
            String ypGroupPrefix = systemName + YpServiceReadOnlyImpl.YP_GROUP_NAME_DELIMITER;

            for (StageMeta meta : listOfStageMeta) {
                String stageId = meta.getId();
                String projectId = meta.getProjectId();
                List<AccessControl.TAccessControlEntry> acl = meta.getAcl().getEntries();

                // An ACE is kept only if none of its checked subjects is invalid. Every subject is
                // evaluated (no short-circuit) so that each invalid one is reported.
                List<AccessControl.TAccessControlEntry> validAces = acl.stream().filter(ace -> {
                    List<String> wrongSubjects = ace.getSubjectsList().stream()
                            .filter(subject -> subject.startsWith(ypGroupPrefix))
                            .filter(subject -> !isValidRoleSubject(writer, systemName, checks, allRoles, allGroups,
                                    stageId, projectId, subject))
                            .collect(Collectors.toList());
                    return wrongSubjects.isEmpty();
                }).collect(Collectors.toList());

                if (validAces.size() == acl.size()) {
                    continue;
                }
                stagesWithErrorsInAcl++;

                if (isAllowedToRemove) {
                    LOG.info("Removing wrong ACEs for stage: {}", stageId);
                    YpRawObjectService ypClient = clients.getYpTransactionClient().getYpRawObjectService();
                    updateStageAces(ypClient, stageId, validAces).get();
                    fixedStages++;
                }
            }

            writer.println(String.format("Checked %d stages, %d with wrong acl's, %d fixed",
                    listOfStageMeta.size(), stagesWithErrorsInAcl, fixedStages));
        } catch (InterruptedException exception) {
            // Restore the interrupt flag so the container thread still observes the interruption.
            Thread.currentThread().interrupt();
            LOG.error("Interrupted while processing ACLs", exception);
            reportFailure(resp, exception, stagesWithErrorsInAcl, fixedStages);
        } catch (Exception exception) {
            LOG.error("Failed to fix ACLs", exception);
            reportFailure(resp, exception, stagesWithErrorsInAcl, fixedStages);
        } finally {
            LOG.info("Processing {} {} request took: {} ms", req.getMethod(), req.getServletPath(), System.currentTimeMillis() - start);
        }
    }

    /**
     * Tells the client that processing failed, instead of returning 200 with a truncated report.
     *
     * <p>Sets status 500 when the response is not committed yet, and writes the error together with
     * the progress made before it. Failures while writing are only logged.
     */
    private static void reportFailure(HttpServletResponse resp, Exception exception,
                                      int stagesWithErrorsInAcl, int fixedStages) {
        if (!resp.isCommitted()) {
            resp.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
        }
        try {
            resp.getWriter().println(String.format(
                    "Failed to process ACLs: %s. Before the failure %d stages with wrong acl's were found, %d fixed",
                    exception, stagesWithErrorsInAcl, fixedStages));
        } catch (IOException | IllegalStateException writeError) {
            LOG.warn("Failed to write the error to the response", writeError);
        }
    }

    /**
     * Runs every check in {@code checks} against one YP group subject of a stage ACE.
     *
     * <p>Every failed check is written to {@code writer}. All checks are run even after one fails, so
     * a subject may produce several lines. {@code WRONG_PROJECT_OR_STAGE} is skipped when the subject
     * cannot be mapped to a role.
     *
     * @param ypGroup the ACE subject; already known to start with the system's YP group prefix
     * @return {@code true} if no selected check failed
     */
    private static boolean isValidRoleSubject(PrintWriter writer,
                                              String systemName,
                                              Set<CheckType> checks, Set<Role> allRoles, Set<String> allGroups,
                                              String stageId, String projectId, String ypGroup) {
        // May be null (checked below); presumably when the group name cannot be parsed into a role.
        Role role = YpServiceReadOnlyImpl.getRole(ypGroup, systemName);

        boolean result = true;

        for (CheckType check : checks) {
            switch (check) {
                case ABSENT_ROLE:
                    if (role == null || !allRoles.contains(role)) {
                        addWrongSubjectToOutput(writer, "role was not found", projectId, stageId, ypGroup);
                        result = false;
                    }
                    break;
                case MISSED_GROUP:
                    if (!allGroups.contains(ypGroup)) {
                        addWrongSubjectToOutput(writer, "missed yp group", projectId, stageId, ypGroup);
                        result = false;
                    }
                    break;
                case WRONG_PROJECT_OR_STAGE:
                    if (role == null) {
                        break;
                    }
                    // Level 0 of the role must be the stage's project, level 1 the stage itself.
                    if (!role.getLevelName(0).orElse("").equals(projectId)) {
                        addWrongSubjectToOutput(writer, "wrong project name in ACE", projectId, stageId, ypGroup);
                        result = false;
                    }
                    if (!role.getLevelName(1).orElse("").equals(stageId)) {
                        addWrongSubjectToOutput(writer, "wrong stage name in ACE", projectId, stageId, ypGroup);
                        result = false;
                    }
                    break;
            }
        }

        return result;
    }

    /** Writes one report line describing a failed check for a subject. */
    private static void addWrongSubjectToOutput(PrintWriter writer, String reason, String projectId, String stageId, String ypGroup) {
        writer.println("project = " + projectId + ", stage = " + stageId + ", YP group = " + ypGroup + ", fail reason = " + reason);
    }

    /**
     * Replaces the whole ACL ({@code Paths.ACL}) of the stage with {@code acl}.
     *
     * @return the future of the YP object update
     */
    private CompletableFuture<Void> updateStageAces(YpRawObjectService ypClient,
                                                    String stageId,
                                                    List<AccessControl.TAccessControlEntry> acl) {
        YpObjectUpdate.Builder updateBuilder = YpObjectUpdate.builder(new YpTypedId(stageId, YpObjectType.STAGE));

        YTreeBuilder aclNodeBuilder = new YTreeBuilder();
        aclNodeBuilder.beginList();
        acl.forEach(entry -> aclNodeBuilder.value(YTreeProtoUtils.marshal(entry, true)));
        aclNodeBuilder.endList();

        updateBuilder.addSetUpdate(new YpSetUpdate(Paths.ACL, aclNodeBuilder.build(), YsonUtils::toYsonPayload));
        return ypClient.updateObject(updateBuilder.build());
    }
}
