package ru.yandex.grpc.utils;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URI;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiConsumer;

import javax.annotation.ParametersAreNonnullByDefault;

import io.grpc.Attributes;
import io.grpc.EquivalentAddressGroup;
import io.grpc.NameResolver;
import io.grpc.NameResolverRegistry;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.internal.GrpcUtil;
import io.netty.channel.EventLoopGroup;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ru.yandex.monlib.metrics.registry.MetricRegistry;
import ru.yandex.solomon.util.NettyUtils;

import static java.util.concurrent.TimeUnit.MINUTES;
import static java.util.concurrent.TimeUnit.SECONDS;
import static java.util.stream.Collectors.toList;
import static org.junit.Assert.assertEquals;
import static ru.yandex.solomon.util.NettyUtils.clientDatagramChannelForEventLoop;

/**
 * @author Stanislav Kashirin
 */
/**
 * Compares the results of {@link NettyDnsNameResolver} with gRPC's default registry resolver
 * for the same target URIs.
 *
 * <p>Both resolvers perform real DNS lookups, so these tests depend on the network/DNS
 * environment the test runs in.
 *
 * @author Stanislav Kashirin
 */
public class NettyDnsNameResolverTest {

    private static final Logger logger = LoggerFactory.getLogger(NettyDnsNameResolverTest.class);

    private EventLoopGroup elg;
    private DnsNameResolver nettyResolver;

    private NameResolver.Args args;
    private NameResolver.Factory factory;

    @Before
    public void setUp() throws Exception {
        elg = NettyUtils.createEventLoopGroup("test-elg", 4);
        nettyResolver = new DnsNameResolverBuilder(elg.next())
            .channelType(clientDatagramChannelForEventLoop(elg))
            .build();
        args = NameResolver.Args.newBuilder()
            .setDefaultPort(42)
            .setProxyDetector(GrpcUtil.DEFAULT_PROXY_DETECTOR)
            .setSynchronizationContext(new SynchronizationContext((t, e) -> logger.error("thread: {}", t, e)))
            .setServiceConfigParser(new NameResolver.ServiceConfigParser() {
                @Override
                public NameResolver.ConfigOrError parseServiceConfig(Map<String, ?> rawServiceConfig) {
                    throw new UnsupportedOperationException();
                }
            })
            .build();
        factory = NettyDnsNameResolver.newFactory(nettyResolver, elg, new MetricRegistry());
    }

    @After
    public void tearDown() throws Exception {
        // JUnit runs @After even when @Before failed part-way, so fields may still be null.
        // try/finally ensures the event loop group is shut down even if closing the resolver throws.
        try {
            if (nettyResolver != null) {
                nettyResolver.close();
            }
        } finally {
            if (elg != null) {
                elg.shutdownGracefully(0, 1, SECONDS).get(1, MINUTES);
            }
        }
    }

    @Test
    public void unknownHost() throws Exception {
        testNotOk(URI.create("dns:///google.yandex-team.ru"));
    }

    @Test
    public void internalHost() throws Exception {
        testOk(URI.create("dns:///a.yandex-team.ru"));
    }

    @Test
    public void externalHost() throws Exception {
        testOk(URI.create("dns:///ya.ru"));
    }

    @Test
    public void withPort() throws Exception {
        testOk(URI.create("dns:///a.yandex-team.ru:8081"));
    }

    /**
     * Asserts that both resolvers report addresses and that the address groups are equal
     * after normalization.
     */
    private void testOk(URI targetUri) throws Exception {
        test(targetUri, (grpcResult, nettyResult) -> {
            assertEquals(nettyResult.toString(), Ok.class, nettyResult.getClass());
            assertEquals(grpcResult.toString(), Ok.class, grpcResult.getClass());

            var grpcGroups = normalizeGroups(((Ok) grpcResult).groups());
            var nettyGroups = normalizeGroups(((Ok) nettyResult).groups());
            assertEquals(grpcGroups, nettyGroups);
        });
    }

    /** Asserts that both resolvers report an error for the target. */
    private void testNotOk(URI targetUri) throws Exception {
        test(targetUri, (grpcResult, nettyResult) -> {
            assertEquals(nettyResult.toString(), NotOk.class, nettyResult.getClass());
            assertEquals(grpcResult.toString(), NotOk.class, grpcResult.getClass());
        });
    }

    /**
     * Resolves {@code targetUri} with the Netty-based and the default gRPC resolver, waits up
     * to one minute for each first result, and passes them to {@code assertions} as
     * (grpcResult, nettyResult).
     *
     * <p>Both resolvers are always shut down afterwards, even when waiting or the assertions
     * fail, so that the resources they hold while running are released.
     */
    private void test(URI targetUri, BiConsumer<Result, Result> assertions) throws Exception {
        var nettyListener = new TestListener();
        var grpcListener = new TestListener();

        NameResolver nettyNameResolver = newResolver(factory, targetUri);
        try {
            nettyNameResolver.start(nettyListener);

            NameResolver grpcNameResolver =
                newResolver(NameResolverRegistry.getDefaultRegistry().asFactory(), targetUri);
            try {
                grpcNameResolver.start(grpcListener);

                var nettyResult = nettyListener.result.get(1, MINUTES);
                var grpcResult = grpcListener.result.get(1, MINUTES);

                assertions.accept(grpcResult, nettyResult);
            } finally {
                grpcNameResolver.shutdown();
            }
        } finally {
            nettyNameResolver.shutdown();
        }
    }

    /**
     * Creates a resolver for {@code targetUri}, failing with a descriptive message instead of a
     * bare NPE when the factory returns null (i.e. it does not handle the target).
     */
    private NameResolver newResolver(NameResolver.Factory resolverFactory, URI targetUri) {
        return Objects.requireNonNull(
            resolverFactory.newNameResolver(targetUri, args),
            () -> "no name resolver for " + targetUri);
    }

    /**
     * Rebuilds each group from its sorted addresses only, so that groups coming from different
     * resolvers compare equal regardless of address order or group attributes.
     */
    private static List<EquivalentAddressGroup> normalizeGroups(List<EquivalentAddressGroup> groups) {
        return groups.stream()
            .map(g -> new EquivalentAddressGroup(normalizeAddresses(g.getAddresses())))
            .collect(toList());
    }

    /**
     * Sorts addresses by their textual IP form. Every address is cast to
     * {@link InetSocketAddress}, so a ClassCastException is thrown for any other address type.
     */
    private static List<SocketAddress> normalizeAddresses(List<SocketAddress> addresses) {
        return addresses.stream()
            .map(InetSocketAddress.class::cast)
            // Locale.ROOT keeps the sort key independent of the JVM default locale.
            .sorted(Comparator.comparing(a -> a.getAddress().getHostAddress().toLowerCase(Locale.ROOT)))
            .collect(toList());
    }

    /** Listener that captures the first result (addresses or error) reported by a resolver. */
    @ParametersAreNonnullByDefault
    private static class TestListener implements NameResolver.Listener {

        final CompletableFuture<Result> result = new CompletableFuture<>();

        @Override
        public void onAddresses(List<EquivalentAddressGroup> servers, Attributes attributes) {
            result.complete(new Ok(servers));
        }

        @Override
        public void onError(Status error) {
            result.complete(new NotOk(error));
        }
    }

    /** Outcome of a single resolution: either {@link Ok} or {@link NotOk}. */
    private sealed interface Result {}

    /** Successful resolution carrying the reported address groups. */
    @ParametersAreNonnullByDefault
    private record Ok(List<EquivalentAddressGroup> groups) implements Result {}

    /** Failed resolution carrying the reported error status. */
    @ParametersAreNonnullByDefault
    private record NotOk(Status status) implements Result {}

}
