package ru.yandex.kikimr.client.kv.noderesolver;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;

import com.google.common.net.HostAndPort;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;

import ru.yandex.kikimr.client.UnableToConnectException;
import ru.yandex.kikimr.client.discovery.NodeDiscovery;
import ru.yandex.kikimr.client.kv.transport.NodeFactory;
import ru.yandex.misc.actor.ActorRunner;
import ru.yandex.solomon.util.host.HostUtils;

/**
 * @author senyasdr
 */
/**
 * {@link KikmirV2NodeResolver} backed by a {@link NodeDiscovery}.
 *
 * <p>Keeps a snapshot ({@link State}) of the currently known nodes, one {@link Node} per node id
 * created through the {@link NodeFactory}, plus a cache of the node last assigned to each
 * (path, partition) pair. The snapshot is rebuilt on an {@link ActorRunner} whenever the
 * discovery's change future completes; a published snapshot is never mutated, it is only replaced.
 *
 * @author senyasdr
 */
public class KikmirV2NodeResolverImpl implements KikmirV2NodeResolver {

    /** Last node assigned to each (path, partition); entries may reference nodes closed since. */
    private final ConcurrentMap<PathAndPartition, Node> pathAndPartitionNodeMap;
    private final NodeDiscovery discovery;
    /** Current snapshot of known nodes; replaced wholesale by {@link #actualizeNodes()}. */
    private volatile State state;
    private final ActorRunner parentActor;
    private final NodeFactory factory;

    private volatile boolean closed;
    /** The discovery change future this resolver is currently subscribed to. */
    private volatile CompletionStage<Void> onChange;

    /**
     * Creates a resolver, builds nodes for every address the discovery currently knows and
     * subscribes to discovery changes.
     *
     * @param discovery source of node id to address mappings and change notifications
     * @param executor  executor the internal actor runs {@link #act()} on
     * @param factory   factory used to create a {@link Node} for each discovered address
     */
    public KikmirV2NodeResolverImpl(NodeDiscovery discovery, Executor executor,
                                    NodeFactory factory) {
        this.discovery = discovery;
        this.factory = factory;
        this.pathAndPartitionNodeMap = new ConcurrentHashMap<>();

        // Obtain the change future BEFORE reading the addresses (the same order act() uses), so a
        // change that lands while the initial snapshot is built still triggers a re-sync.
        // NOTE(review): assumes discovery.onChange() completes on the next change after the call.
        var initialOnChange = discovery.onChange();

        var addresses = discovery.nodeIdToAddresses();
        var nodeIdToNode = new Int2ObjectOpenHashMap<Node>(addresses.size());
        for (Int2ObjectMap.Entry<HostAndPort> entry : addresses.int2ObjectEntrySet()) {
            int nodeId = entry.getIntKey();
            nodeIdToNode.put(nodeId, factory.createNode(nodeId, entry.getValue()));
        }
        this.state = new State(nodeIdToNode, List.copyOf(nodeIdToNode.values()));

        this.parentActor = new ActorRunner(this::act, executor);
        this.onChange = initialOnChange;
        // Registered last: the callback may schedule act() immediately, so every field must
        // already be assigned.
        initialOnChange.whenComplete((ignore, e) -> parentActor.schedule());
    }

    /**
     * Returns the node cached for {@code (path, partitionId)}, or a random known node when there
     * is no cached node or the cached one has been closed (in which case the mapping is dropped).
     *
     * @throws UnableToConnectException if a random node is needed but no nodes are known
     */
    @Override
    public Node getNodeIdByPathAndPartitionIdOrRandom(String path, long partitionId) {
        var key = new PathAndPartition(partitionId, path);
        Node cached = pathAndPartitionNodeMap.get(key);
        if (cached == null) {
            return getRandomNode();
        }
        if (cached.isClosed()) {
            // Remove only if the mapping still points at the node we inspected, so a fresh
            // mapping installed concurrently by updateNodeId() is not lost.
            pathAndPartitionNodeMap.remove(key, cached);
            return getRandomNode();
        }
        return cached;
    }

    /** Returns the cached node for {@code (path, partitionId)} without any validation, or null. */
    Node getNodeIdByPathAndPartitionId(String path, long partitionId) {
        return pathAndPartitionNodeMap.get(new PathAndPartition(partitionId, path));
    }

    /**
     * Remembers {@code newNodeId} as the node for {@code (path, partitionId)}. If that node id is
     * not known yet, asks the discovery for a forced update and re-syncs the nodes afterwards
     * (the mapping itself is not recorded in that case).
     */
    @Override
    public void updateNodeId(String path, long partitionId, int newNodeId) {
        Node node = state.nodeIdToNode().get(newNodeId);
        if (node != null) {
            pathAndPartitionNodeMap.put(new PathAndPartition(partitionId, path), node);
        } else {
            discovery.forceUpdate().thenAccept((v) -> parentActor.schedule());
        }
    }

    /**
     * Drops the cached mapping for {@code (path, partitionId)} and returns a random known node.
     *
     * @throws UnableToConnectException if no nodes are known
     */
    @Override
    public Node invalidateNodeAndGetNew(String path, long partitionId) {
        pathAndPartitionNodeMap.remove(new PathAndPartition(partitionId, path));
        return getRandomNode();
    }

    /**
     * Returns a uniformly random node from the current snapshot.
     *
     * @throws UnableToConnectException if no nodes are known
     */
    public Node getRandomNode() {
        // Read the volatile snapshot exactly once: re-reading it could observe different lists
        // for the emptiness check, the random bound and the lookup.
        List<Node> nodes = state.nodes();
        if (nodes.isEmpty()) {
            throw new UnableToConnectException("Can't get node from empty list");
        }
        return nodes.get(ThreadLocalRandom.current().nextInt(nodes.size()));
    }

    /**
     * Returns the first known node whose host equals this machine's FQDN.
     *
     * @throws UnableToConnectException if no such node is known
     */
    @Override
    public Node getLocalNode() {
        String localHost = HostUtils.getFqdn();
        return state.nodes().stream()
                .filter(n -> localHost.equals(n.hostAndPort().getHost())).findFirst()
                .orElseThrow(() -> new UnableToConnectException("Can not find host with FQDN " + localHost));
    }

    /** Returns the internal actor's {@code awaitAct()} future (presumably: pending act() done). */
    CompletableFuture<Void> waitForUpdateFuture() {
        return parentActor.awaitAct();
    }

    /** Actor body: re-subscribes to discovery changes and re-syncs nodes, unless closed. */
    private void act() {
        if (closed) {
            return;
        }
        subscribeOnChangeIfNecessary();
        actualizeNodes();
    }

    /** Subscribes to the discovery's current change future if it differs from the last one. */
    private void subscribeOnChangeIfNecessary() {
        var actualOnChange = discovery.onChange();
        if (onChange != actualOnChange) {
            onChange = actualOnChange;
            onChange.whenComplete((ignore, e) -> parentActor.schedule());
        }
    }

    /**
     * Rebuilds the node snapshot from the discovery's current addresses. Nodes whose id vanished
     * or whose address changed are replaced/dropped; new ids get fresh nodes. Replaced and
     * dropped nodes are closed only AFTER the new snapshot is published, so readers of the
     * snapshot are not handed nodes that are already closed.
     */
    private void actualizeNodes() {
        var addresses = discovery.nodeIdToAddresses();
        Int2ObjectOpenHashMap<Node> current = state.nodeIdToNode();
        Int2ObjectOpenHashMap<Node> updated = current.clone();
        List<Node> retired = new ArrayList<>();

        for (Int2ObjectMap.Entry<Node> entry : current.int2ObjectEntrySet()) {
            int nodeId = entry.getIntKey();
            HostAndPort address = addresses.get(nodeId);
            if (address == null) {
                retired.add(updated.remove(nodeId));
            } else if (!address.equals(entry.getValue().hostAndPort())) {
                retired.add(updated.put(nodeId, factory.createNode(nodeId, address)));
            }
        }

        for (Int2ObjectMap.Entry<HostAndPort> entry : addresses.int2ObjectEntrySet()) {
            int nodeId = entry.getIntKey();
            if (!updated.containsKey(nodeId)) {
                updated.put(nodeId, factory.createNode(nodeId, entry.getValue()));
            }
        }

        this.state = new State(updated, List.copyOf(updated.values()));

        for (Node node : retired) {
            node.close();
        }
    }

    /**
     * Marks the resolver closed; subsequent actor runs become no-ops.
     * NOTE(review): nodes created by this resolver are not closed here — confirm who owns them.
     */
    @Override
    public void close() {
        this.closed = true;
        parentActor.schedule();
    }

    /** Immutable-by-convention snapshot: the map and list are never mutated after publication. */
    private record State(Int2ObjectOpenHashMap<Node> nodeIdToNode, List<Node> nodes) {
    }

    /** Cache key for the (path, partition) → node mapping. */
    private record PathAndPartition(long partitionId, String path) {
    }
}
