package ru.yandex.qloud.kikimr.transport;

import NACLibProto.Aclib;
import com.google.common.base.Strings;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.Uninterruptibles;
import com.google.protobuf.InvalidProtocolBufferException;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import ru.yandex.kikimr.proto.FlatSchemeOp;
import ru.yandex.kikimr.proto.Msgbus;
import ru.yandex.kikimr.proto.TxProxy;
import ru.yandex.qloud.kikimr.scheme.TableStatistics;
import ru.yandex.qloud.kikimr.transport.acl.EAccessRights;
import ru.yandex.qloud.kikimr.transport.acl.EAccessType;
import ru.yandex.qloud.kikimr.transport.acl.EDiffType;
import ru.yandex.qloud.kikimr.transport.acl.EInheritanceType;
import org.springframework.beans.factory.annotation.Value;


import javax.annotation.Nonnull;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import static com.google.common.base.MoreObjects.firstNonNull;

/**
 * Reads and modifies the KiKiMR scheme below {@code Root/qloud}: lists nodes and tables,
 * fetches table statistics and manages per-node ACLs.
 *
 * <p>The recursive node listing, per-node table statistics and per-node table lists are
 * served from Guava caches with write-based expiry; {@link #invalidateSchemeCache()} drops
 * the node listing so the next call reloads it.
 *
 * @author violin
 */
@Component
public class KikimrScheme {
    private static final Logger LOG = LoggerFactory.getLogger(KikimrScheme.class);

    /** When enabled, every node found by a full scheme load is logged at INFO level. */
    @Value("${env.KIKIMR_ACL_UPDATE_VERBOSE_LOGGING:false}")
    private boolean verboseLogging;

    private static final String ROOT_PATH = "Root/qloud";

    /** Logins that are always kept in the {@link #ROOT_PATH} ACL; they are never removed from it. */
    private static final List<String> ROOT_USERS = Lists.newArrayList("robot-qloud-kikimr", "robot-yadb-qld-pshr");

    /** Number of attempts for a single describe request while walking the scheme tree. */
    private static final int MAX_SCHEME_QUERY_ATTEMPTS = 10;

    /** Upper bound for polling an in-progress scheme operation (measured with a monotonic clock). */
    private static final long MAX_OPERATION_WAIT_NANOS = TimeUnit.MINUTES.toNanos(10);

    /** Delay between two status polls of an in-progress scheme operation. */
    private static final long OPERATION_POLL_INTERVAL_MILLIS = 50;

    /**
     * Single-entry cache holding the complete node listing below {@link #ROOT_PATH}.
     * The only key ever used is this instance.
     */
    private final LoadingCache<Object, Set<SchemeNode>> schemeCache = CacheBuilder.newBuilder()
            .expireAfterWrite(30, TimeUnit.MINUTES)
            .build(new CacheLoader<Object, Set<SchemeNode>>() {
                @Override
                public Set<SchemeNode> load(Object key) throws Exception {
                    return loadKikimrSchemeNodes();
                }
            });

    /** Table statistics keyed by full node path. */
    private final LoadingCache<String, TableStatistics> nodeToStatisticsCache = CacheBuilder.newBuilder()
            .expireAfterWrite(20, TimeUnit.MINUTES)
            .build(new CacheLoader<String, TableStatistics>() {
                @Override
                public TableStatistics load(String key) throws Exception {
                    return loadTableStatistics(key);
                }
            });

    /** Names of the direct children of a node, keyed by full node path. */
    private final LoadingCache<String, List<String>> nodeToTables = CacheBuilder.newBuilder()
            .expireAfterWrite(20, TimeUnit.MINUTES)
            .build(new CacheLoader<String, List<String>>() {
                @Override
                public List<String> load(String key) throws Exception {
                    return loadTables(key);
                }
            });

    private final KikimrRpc kikimrRpc;

    public KikimrScheme(@Autowired KikimrRpc kikimrRpc) {
        this.kikimrRpc = kikimrRpc;
    }

    /** Drops the cached node listing; the next listing call walks the scheme again. */
    public void invalidateSchemeCache() {
        schemeCache.invalidate(this);
    }

    /** Returns full paths of every node (of any type) below {@link #ROOT_PATH}. */
    public Set<String> listAllQloudNodesAndTables() {
        return getNodeNamesFromCache((node) -> true);
    }

    /** Returns full paths of every table node below {@link #ROOT_PATH}. */
    public Set<String> listAllQloudTables() {
        return getNodeNamesFromCache((node) -> node.getType() == FlatSchemeOp.EPathType.EPathTypeTable);
    }

    /** Returns full paths of every non-table node below {@link #ROOT_PATH}. */
    public Set<String> listQloudNodesWithoutTables() {
        return getNodeNamesFromCache((node) -> node.getType() != FlatSchemeOp.EPathType.EPathTypeTable);
    }

    /** Sets {@link EAccessRights#GENERIC_FULL} ACL entries on {@code kikimrNode} for {@code userLogins}. */
    public void setAdminACL(@Nonnull String kikimrNode, @Nonnull Collection<String> userLogins) {
        setACL(kikimrNode, userLogins, EAccessRights.GENERIC_FULL);
    }

    /** Sets {@link EAccessRights#GENERIC_READ} ACL entries on {@code kikimrNode} for {@code userLogins}. */
    public void setUserACL(@Nonnull String kikimrNode, @Nonnull Collection<String> userLogins) {
        setACL(kikimrNode, userLogins, EAccessRights.GENERIC_READ);
    }

    /**
     * Returns cached row-count/storage statistics for a table node.
     * Loader failures (e.g. a node without statistics) surface as Guava unchecked exceptions.
     */
    public TableStatistics getTableStatistics(@Nonnull String kikimrNode) {
        return nodeToStatisticsCache.getUnchecked(kikimrNode);
    }

    /**
     * Returns logins holding an ALLOW entry with exactly {@code accessRights} in the ACL of
     * {@code kikimrNode}. The SID suffix starting at the first {@code @} is stripped.
     * DENY entries are skipped because a deny entry never grants access.
     */
    public List<String> getUserLoginsWithAccess(String kikimrNode, EAccessRights accessRights) {
        return loadTACEForNode(kikimrNode).stream()
                .filter(tace -> tace.getAccessType() == EAccessType.ALLOW.getValue())
                .filter(tace -> tace.getAccessRight() == accessRights.getValue())
                .map(tace -> StringUtils.substringBefore(tace.getSID(), "@"))
                .collect(Collectors.toList());
    }

    /** Returns cached names of the direct children of {@code kikimrNode}. */
    public List<String> getAvailableTablesSimpleNames(String kikimrNode) {
        return nodeToTables.getUnchecked(kikimrNode);
    }

    /**
     * Submits an ACL modification for {@code kikimrNode}, waits for it to finish and fails
     * unless the final status is {@link ResponseStatus#MSTATUS_OK}.
     *
     * @throws IllegalArgumentException if the path has no parent or no name component
     * @throws KikimrException if the operation does not end with {@code MSTATUS_OK}
     */
    private void setACL(@Nonnull String kikimrNode, @Nonnull Collection<String> userLogins, @Nonnull EAccessRights accessRights) {
        int separator = kikimrNode.lastIndexOf('/');
        // The path is split into working dir + node name; both parts must be non-empty.
        if (separator <= 0 || separator == kikimrNode.length() - 1) {
            throw new IllegalArgumentException("kikimrNode not valid: " + kikimrNode);
        }

        Msgbus.TSchemeOperation aclOperation = createModifyACLOperation(kikimrNode, userLogins, accessRights);
        Msgbus.TResponse operationResponse = kikimrRpc.getRpcWithDeadline().schemeOperation(aclOperation);

        Msgbus.TResponse finalResponse = waitForCompletion(operationResponse);

        ResponseStatus finalStatus = ResponseStatus.byIntStatus(finalResponse.getStatus());
        if (finalStatus != ResponseStatus.MSTATUS_OK) {
            throw new KikimrException(
                    String.format("operation status not ok : %s (%d): %s",
                            finalStatus,
                            finalResponse.getStatus(),
                            (finalResponse.hasErrorReason() ? finalResponse.getErrorReason() : finalResponse)
                    )
            );
        }
    }

    /** Returns names of cached scheme nodes matching {@code predicate}, loading the listing if needed. */
    private Set<String> getNodeNamesFromCache(Predicate<SchemeNode> predicate) {
        return firstNonNull(schemeCache.getUnchecked(this), Collections.<SchemeNode>emptySet())
                .stream()
                .filter(predicate)
                .map(SchemeNode::getName)
                .collect(Collectors.toSet());
    }

    /** Walks the whole tree below {@link #ROOT_PATH} and returns every node found (cache loader). */
    private Set<SchemeNode> loadKikimrSchemeNodes() {
        ImmutableSet.Builder<SchemeNode> resultBuilder = ImmutableSet.builder();
        LOG.debug("starting to load nodes...");
        long startNanos = System.nanoTime();
        listRecursive(ROOT_PATH, resultBuilder::add);
        Set<SchemeNode> result = resultBuilder.build();
        LOG.debug("loaded {} nodes in {} ms", result.size(), TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos));
        if (verboseLogging) {
            result.forEach(node -> {
                // NOTE(review): presumably throttles log output for large trees — confirm it is still needed.
                Uninterruptibles.sleepUninterruptibly(5, TimeUnit.MILLISECONDS);
                LOG.info("Found : {}", node.getName());
            });
        }
        return result;
    }

    /**
     * Describes {@code node} and passes each child to {@code nodeConsumer}, recursing into
     * children of type {@code EPathTypeDir}. Each describe request is attempted up to
     * {@link #MAX_SCHEME_QUERY_ATTEMPTS} times; {@link PathNotFoundException} is never retried.
     */
    private void listRecursive(String node, Consumer<SchemeNode> nodeConsumer) {
        Msgbus.TResponse schemeResponse = null;

        for (int attempt = 0; attempt < MAX_SCHEME_QUERY_ATTEMPTS; attempt++) {
            try {
                // Short pause before every request to avoid hammering the scheme service.
                Uninterruptibles.sleepUninterruptibly(20, TimeUnit.MILLISECONDS);
                schemeResponse = executeSchemeRequest(node);
                break;
            } catch (PathNotFoundException pnfe) {
                throw pnfe;
            } catch (RuntimeException e) {
                if (attempt < MAX_SCHEME_QUERY_ATTEMPTS - 1) {
                    LOG.warn("exception by scheme request for node " + node, e);
                } else {
                    LOG.error("exception by scheme request for node " + node, e);
                    throw e;
                }
            }
        }
        if (schemeResponse == null) {
            throw new IllegalStateException("scheme response is null for node " + node);
        }

        for (FlatSchemeOp.TDirEntry dirEntry : schemeResponse.getPathDescription().getChildrenList()) {
            String nextNodeName = node + "/" + dirEntry.getName();
            nodeConsumer.accept(new SchemeNode(nextNodeName, dirEntry.getPathType()));
            if (dirEntry.getPathType() == FlatSchemeOp.EPathType.EPathTypeDir) {
                listRecursive(nextNodeName, nodeConsumer);
            }
        }
    }

    /**
     * Sends a describe request for {@code node} using the root user token.
     *
     * @throws PathNotFoundException if the error reason contains {@code "Path not found"}
     * @throws KikimrException on any other error code/reason, or when no path description is returned
     */
    private Msgbus.TResponse executeSchemeRequest(String node) {
        Msgbus.TSchemeDescribe request = Msgbus.TSchemeDescribe.newBuilder()
                .setPath(node)
                .setSecurityToken(kikimrRpc.getRootUserToken())
                .build();
        Msgbus.TResponse schemeResponse = kikimrRpc.getRpcWithDeadline().schemeDescribe(request);
        if (schemeResponse.hasErrorReason() && schemeResponse.getErrorReason().contains("Path not found")) {
            throw new PathNotFoundException(schemeResponse.getErrorCode(), schemeResponse.getErrorReason());
        }
        if (schemeResponse.hasErrorCode() || schemeResponse.hasErrorReason()) {
            throw new KikimrException(schemeResponse.getErrorCode(), schemeResponse.getErrorReason());
        }
        if (!schemeResponse.hasPathDescription()) {
            throw new KikimrException(String.format("there is no path description for %s", node));
        }
        return schemeResponse;
    }

    /**
     * Builds statistics from the node's table stats (row count) and tablet metrics (storage).
     *
     * @throws IllegalArgumentException if either part is missing from the description
     */
    @Nonnull
    private TableStatistics loadTableStatistics(String kikimrNode) {
        Msgbus.TResponse schemeResponse = executeSchemeRequest(kikimrNode);
        if (schemeResponse.hasPathDescription()) {
            FlatSchemeOp.TPathDescription pathDescription = schemeResponse.getPathDescription();
            if (pathDescription.hasTableStats() && pathDescription.hasTabletMetrics()) {
                return new TableStatistics(
                        pathDescription.getTableStats().getRowCount(),
                        pathDescription.getTabletMetrics().getStorage()
                );
            }
        }
        throw new IllegalArgumentException(String.format("no table statistics for %s", kikimrNode));
    }

    /** Returns the names of the direct children of {@code kikimrNode}. */
    private List<String> loadTables(String kikimrNode) {
        Msgbus.TResponse schemeResponse = executeSchemeRequest(kikimrNode);
        if (schemeResponse.hasPathDescription()) {
            return schemeResponse.getPathDescription().getChildrenList().stream()
                    .map(FlatSchemeOp.TDirEntry::getName).collect(Collectors.toList());
        }
        return Collections.emptyList();
    }

    /**
     * Builds a ModifyACL scheme operation for {@code kikimrNode}.
     *
     * <p>For {@link #ROOT_PATH} only the difference to the current root ACL is applied, and
     * {@link #ROOT_USERS} are always kept. For any other node the ACL is cleared and
     * re-populated with {@code userLogins}.
     */
    private Msgbus.TSchemeOperation createModifyACLOperation(@Nonnull String kikimrNode, @Nonnull Collection<String> userLogins, @Nonnull EAccessRights accessRights) {
        int separator = kikimrNode.lastIndexOf('/');
        String aclParent = kikimrNode.substring(0, separator);
        String aclNode = kikimrNode.substring(separator + 1);

        Aclib.TDiffACL.Builder aclDiffBuilder = Aclib.TDiffACL.newBuilder();

        if (ROOT_PATH.equals(kikimrNode)) {
            // Incremental update: never clear the root ACL, so service robots cannot lose access.
            // NOTE(review): removals are built with the requested accessRights, while loadRootUsers
            // only accepts GENERIC_FULL entries — confirm non-admin updates on root are never issued.
            Set<String> currentUsers = loadRootUsers();
            Set<String> newUsers = Sets.newHashSet(userLogins);
            newUsers.addAll(ROOT_USERS);

            Sets.difference(currentUsers, newUsers).forEach(
                    (user) -> aclDiffBuilder.addDiffACE(createDiffACE(user, EDiffType.REMOVE, accessRights))
            );
            Sets.difference(newUsers, currentUsers).forEach(
                    (user) -> aclDiffBuilder.addDiffACE(createDiffACE(user, EDiffType.ADD, accessRights))
            );
        } else {
            aclDiffBuilder.addDiffACE(Aclib.TDiffACE.newBuilder().setDiffType(EDiffType.CLEAR.getValue()).build());
            for (String user : userLogins) {
                aclDiffBuilder.addDiffACE(createDiffACE(user, EDiffType.ADD, accessRights));
            }
        }

        FlatSchemeOp.TModifyACL tModifyACL = FlatSchemeOp.TModifyACL.newBuilder()
                .setName(aclNode)
                .setDiffACL(aclDiffBuilder.build().toByteString())
                .build();

        FlatSchemeOp.TModifyScheme modifyScheme = FlatSchemeOp.TModifyScheme.newBuilder()
                .setWorkingDir(aclParent)
                .setOperationType(FlatSchemeOp.EOperationType.ESchemeOpModifyACL)
                .setModifyACL(tModifyACL)
                .build();

        TxProxy.TTransaction transaction = TxProxy.TTransaction.newBuilder()
                .setModifyScheme(modifyScheme)
                .build();

        return Msgbus.TSchemeOperation.newBuilder()
                .setTransaction(transaction)
                .setSecurityToken(kikimrRpc.getRootUserToken())
                .build();
    }

    /**
     * Builds a diff entry for an ALLOW ACE for {@code userLogin + "@staff"}, inherited by both
     * objects and containers.
     */
    private Aclib.TDiffACE createDiffACE(String userLogin, EDiffType diffType, EAccessRights accessRights) {
        Aclib.TACE tace = Aclib.TACE.newBuilder()
                .setAccessType(EAccessType.ALLOW.getValue())
                .setAccessRight(accessRights.getValue())
                .setSID(userLogin + "@staff")
                .setInheritanceType(EInheritanceType.INHERIT_OBJECT.getValue() | EInheritanceType.INHERIT_CONTAINER.getValue())
                .build();
        return Aclib.TDiffACE.newBuilder()
                .setDiffType(diffType.getValue())
                .setACE(tace)
                .build();
    }

    /**
     * Returns logins (SID without the {@code @...} suffix) present in the {@link #ROOT_PATH} ACL.
     * DENY entries for {@code allstaff*} SIDs are skipped.
     *
     * @throws IllegalStateException if any other entry is not {@link EAccessRights#GENERIC_FULL}
     */
    private Set<String> loadRootUsers() {
        ImmutableSet.Builder<String> resultBuilder = ImmutableSet.builder();
        for (Aclib.TACE tace : loadTACEForNode(ROOT_PATH)) {
            if (EAccessType.DENY.getValue() == tace.getAccessType() && Strings.nullToEmpty(tace.getSID()).startsWith("allstaff")) {
                continue;
            }
            if (EAccessRights.GENERIC_FULL.getValue() != tace.getAccessRight()) {
                throw new IllegalStateException("unexpected acl for /" + ROOT_PATH + " : " + tace);
            }
            resultBuilder.add(StringUtils.substringBefore(tace.getSID(), "@"));
        }
        return resultBuilder.build();
    }

    /** Describes {@code kikimrNode} and parses the ACL bytes of its own path entry. */
    private List<Aclib.TACE> loadTACEForNode(final String kikimrNode) {
        Msgbus.TResponse scheme = executeSchemeRequest(kikimrNode);
        try {
            return Aclib.TACL.parseFrom(scheme.getPathDescription().getSelf().getACL()).getACEList();
        } catch (InvalidProtocolBufferException e) {
            throw new IllegalStateException("loadTACEForNode failed", e);
        }
    }

    /**
     * Polls the status of the operation started by {@code initialResponse} while it reports
     * {@code MSTATUS_INPROGRESS}, for at most {@link #MAX_OPERATION_WAIT_NANOS}.
     *
     * @return the last response received; still in progress if the wait timed out
     */
    private Msgbus.TResponse waitForCompletion(Msgbus.TResponse initialResponse) {
        Msgbus.TSchemeOperationStatus operationStatusRequest = Msgbus.TSchemeOperationStatus.newBuilder()
                .setFlatTxId(initialResponse.getFlatTxId())
                .build();
        // nanoTime is monotonic, so wall-clock adjustments cannot stretch or cut the wait.
        long startNanos = System.nanoTime();
        Msgbus.TResponse currentResponse = initialResponse;
        while (ResponseStatus.byIntStatus(currentResponse.getStatus()) == ResponseStatus.MSTATUS_INPROGRESS) {
            if (System.nanoTime() - startNanos >= MAX_OPERATION_WAIT_NANOS) {
                LOG.warn("scheme operation {} still in progress after {} ms, giving up waiting",
                        initialResponse.getFlatTxId(), TimeUnit.NANOSECONDS.toMillis(MAX_OPERATION_WAIT_NANOS));
                break;
            }
            Uninterruptibles.sleepUninterruptibly(OPERATION_POLL_INTERVAL_MILLIS, TimeUnit.MILLISECONDS);
            currentResponse = kikimrRpc.getRpcWithDeadline().schemeOperationStatus(operationStatusRequest);
        }
        return currentResponse;
    }

    /** Immutable (path, path type) pair produced by the recursive scheme walk. */
    private static class SchemeNode {
        private final String name;
        private final FlatSchemeOp.EPathType type;

        SchemeNode(String name, FlatSchemeOp.EPathType type) {
            this.name = name;
            this.type = type;
        }

        public String getName() {
            return name;
        }

        public FlatSchemeOp.EPathType getType() {
            return type;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof SchemeNode)) {
                return false;
            }
            SchemeNode that = (SchemeNode) o;
            return Objects.equals(name, that.name) && type == that.type;
        }

        @Override
        public int hashCode() {
            return Objects.hash(name, type);
        }
    }
}
