package grpcresolver

import (
	"context"
	"errors"
	"net"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/atomic"
	"golang.org/x/xerrors"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/connectivity"
	"google.golang.org/grpc/status"

	pb "a.yandex-team.ru/infra/yp_service_discovery/api"
	"a.yandex-team.ru/infra/yp_service_discovery/golang/resolver"
	"a.yandex-team.ru/library/go/core/log/nop"
)

// mockYPSDServer is a stub pb.TServiceDiscoveryServiceServer for tests.
// Each Mock* hook, when set, fully handles the matching RPC; when a hook is
// left nil the method answers with a canned reply instead (NOT_EXISTS for the
// resolve calls, "pong" for Ping).
type mockYPSDServer struct {
	MockResolveEndpoints func(context.Context, *pb.TReqResolveEndpoints) (*pb.TRspResolveEndpoints, error)
	MockResolveNode      func(context.Context, *pb.TReqResolveNode) (*pb.TRspResolveNode, error)
	MockResolvePods      func(context.Context, *pb.TReqResolvePods) (*pb.TRspResolvePods, error)
	MockPing             func(context.Context, *pb.TReqPing) (*pb.TRspPing, error)
}

// ResolveEndpoints delegates to MockResolveEndpoints, or reports the endpoint
// set as NOT_EXISTS (stamped with the current time, watch token "0") when no
// hook is installed.
func (m mockYPSDServer) ResolveEndpoints(ctx context.Context, req *pb.TReqResolveEndpoints) (*pb.TRspResolveEndpoints, error) {
	if hook := m.MockResolveEndpoints; hook != nil {
		return hook(ctx, req)
	}
	rsp := &pb.TRspResolveEndpoints{
		Timestamp:     uint64(time.Now().Unix()),
		ResolveStatus: pb.EResolveStatus_NOT_EXISTS,
		WatchToken:    "0",
	}
	return rsp, nil
}

// ResolveNode delegates to MockResolveNode, or reports the node as NOT_EXISTS
// (stamped with the current time) when no hook is installed.
func (m mockYPSDServer) ResolveNode(ctx context.Context, req *pb.TReqResolveNode) (*pb.TRspResolveNode, error) {
	if hook := m.MockResolveNode; hook != nil {
		return hook(ctx, req)
	}
	rsp := &pb.TRspResolveNode{
		Timestamp:     uint64(time.Now().Unix()),
		ResolveStatus: pb.EResolveStatus_NOT_EXISTS,
	}
	return rsp, nil
}

// ResolvePods delegates to MockResolvePods, or reports the pod set as
// NOT_EXISTS (stamped with the current time) when no hook is installed.
func (m mockYPSDServer) ResolvePods(ctx context.Context, req *pb.TReqResolvePods) (*pb.TRspResolvePods, error) {
	if hook := m.MockResolvePods; hook != nil {
		return hook(ctx, req)
	}
	rsp := &pb.TRspResolvePods{
		Timestamp:     uint64(time.Now().Unix()),
		ResolveStatus: pb.EResolveStatus_NOT_EXISTS,
	}
	return rsp, nil
}

// Ping delegates to MockPing, or answers "pong" when no hook is installed.
func (m mockYPSDServer) Ping(ctx context.Context, req *pb.TReqPing) (*pb.TRspPing, error) {
	if hook := m.MockPing; hook != nil {
		return hook(ctx, req)
	}
	rsp := &pb.TRspPing{
		Data: "pong",
	}
	return rsp, nil
}

// newTestYPSDService starts a short-lived gRPC server backed by ypsdSrv on a
// random TCP port and returns the server, a client for it, the client's
// underlying connection and the listener.
//
// A nil ypsdSrv is replaced with a zero mockYPSDServer. Callers are expected to
// close the connection, stop the server and close the listener themselves.
// The server is additionally stopped from t.Cleanup, which also waits for the
// Serve goroutine to exit and reports any unexpected Serve error.
func newTestYPSDService(
	t *testing.T,
	ypsdSrv pb.TServiceDiscoveryServiceServer,
) (
	*grpc.Server,
	pb.TServiceDiscoveryServiceClient,
	*grpc.ClientConn,
	net.Listener,
) {
	lis, err := net.Listen("tcp", ":0")
	require.NoError(t, err)

	if ypsdSrv == nil {
		ypsdSrv = new(mockYPSDServer)
	}

	server := grpc.NewServer()
	pb.RegisterTServiceDiscoveryServiceServer(server, ypsdSrv)

	// Serve runs on its own goroutine, so it must neither call require
	// (t.FailNow is only valid on the test goroutine) nor write to the outer
	// err (that would race with the Dial below). Its result is handed back
	// over a buffered channel and checked at cleanup time instead.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- server.Serve(lis)
	}()
	t.Cleanup(func() {
		// Stopping again is harmless if the caller already did it.
		server.Stop()
		// Serve returns nil after Stop, or ErrServerStopped when Stop ran
		// before Serve got going; anything else is a genuine failure.
		if err := <-serveErr; err != nil && err != grpc.ErrServerStopped {
			t.Errorf("test gRPC server: %v", err)
		}
	})

	// No need to wait for Serve to start: the listener is already bound, so
	// incoming connections are queued until Serve begins accepting them.
	grpcConn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
	require.NoError(t, err)

	grpcClient := pb.NewTServiceDiscoveryServiceClient(grpcConn)

	return server, grpcClient, grpcConn, lis
}

func TestNew(t *testing.T) {
	srv, _, conn, lis := newTestYPSDService(t, nil)
	defer lis.Close()
	defer srv.Stop()
	defer conn.Close()

	r, err := New()
	assert.NoError(t, err)

	expected := &Resolver{
		serviceURI: resolver.ServiceDiscoveryHostProd + ":" + resolver.ServiceDiscoveryGRPCPort,
		clientName: getClientName(),
		logger:     new(nop.Logger),
		dialed:     atomic.Bool{},
	}

	opts := cmp.Options{
		cmp.AllowUnexported(Resolver{}, atomic.Bool{}),
		cmpopts.IgnoreUnexported(grpc.ClientConn{}, sync.Mutex{}),
		cmpopts.IgnoreInterfaces(struct {
			pb.TServiceDiscoveryServiceClient
		}{}),
		cmpopts.IgnoreTypes(ResolveEndpointsRUIDFunc(nil)),
		cmpopts.IgnoreTypes(ResolvePodsRUIDFunc(nil)),
		cmpopts.IgnoreTypes(ResolveNodeRUIDFunc(nil)),
		cmp.Comparer(func(x, y atomic.Bool) bool {
			return x.Load() == y.Load()
		}),
	}

	assert.True(t, cmp.Equal(expected, r, opts...), cmp.Diff(expected, r, opts...))
	assert.NotNil(t, r.resolveEndpointsRUIDFunc)
}

// TestResolver_Close checks that closing a resolver built from an injected
// gRPC client shuts the underlying connection down.
func TestResolver_Close(t *testing.T) {
	srv, client, conn, lis := newTestYPSDService(t, nil)
	defer lis.Close()
	defer srv.Stop()
	defer conn.Close()

	r, err := New(WithGRPCClient(client, conn))
	// require: r is unusable on error and would be dereferenced below.
	require.NoError(t, err)

	require.NoError(t, r.Close())
	// testify's argument order is (expected, actual).
	assert.Equal(t, connectivity.Shutdown, r.grpcConn.GetState())
}

func Test_getClientName(t *testing.T) {
	hostname, err := os.Hostname()
	assert.NoError(t, err)

	testCases := []struct {
		name      string
		bootstrap func()
		expected  string
	}{
		{
			name: "unknown_user",
			bootstrap: func() {
				_ = os.Unsetenv("SUDO_USER")
				_ = os.Unsetenv("USER")
			},
			expected: "go_resolver@" + hostname,
		},
		{
			name: "sudo_user_env",
			bootstrap: func() {
				_ = os.Setenv("SUDO_USER", "volozh")
				_ = os.Unsetenv("USER")
			},
			expected: "volozh@" + hostname,
		},
		{
			name: "user_env",
			bootstrap: func() {
				_ = os.Unsetenv("SUDO_USER")
				_ = os.Setenv("USER", "volozh")
			},
			expected: "volozh@" + hostname,
		},
	}

	for _, tc := range testCases {
		tc.bootstrap()
		assert.Equal(t, tc.expected, getClientName())
	}
}

// TestResolver_ResolveEndpoints checks that ResolveEndpoints wraps gRPC errors
// with the endpoint-set ID and cluster, and converts a successful protobuf
// reply into resolver.ResolveEndpointsResponse. Each case runs against both a
// resolver that dials by address and one given a ready gRPC client.
func TestResolver_ResolveEndpoints(t *testing.T) {
	now := time.Now()

	testCases := []struct {
		name        string
		server      pb.TServiceDiscoveryServiceServer
		expected    *resolver.ResolveEndpointsResponse
		expectedErr error
	}{
		{
			name: "bad_response",
			server: mockYPSDServer{
				MockResolveEndpoints: func(ctx context.Context, r *pb.TReqResolveEndpoints) (*pb.TRspResolveEndpoints, error) {
					return nil, status.Error(codes.Internal, "unexpected error")
				},
			},
			expected:    nil,
			expectedErr: errors.New("failed to resolve endpoint-set go_resolver.test@sas: rpc error: code = Internal desc = unexpected error"),
		},
		{
			name: "success",
			server: mockYPSDServer{
				MockResolveEndpoints: func(ctx context.Context, r *pb.TReqResolveEndpoints) (*pb.TRspResolveEndpoints, error) {
					resp := &pb.TRspResolveEndpoints{
						Timestamp: uint64(now.Unix()),
						EndpointSet: &pb.TEndpointSet{
							EndpointSetId: "trololo",
							Endpoints: []*pb.TEndpoint{
								{
									Id:                   "shimba-boomba",
									Protocol:             "TCP",
									Fqdn:                 "myt1-1717-msk-yp-service-discovery-20075.gencfg-c.yandex.net",
									Ip6Address:           "2a02:6b8:c08:afa8:10d:bd77:7f89:0",
									Port:                 8080,
									LabelSelectorResults: []string{"looken-tooken"},
								},
							},
						},
						ResolveStatus: pb.EResolveStatus_OK,
						WatchToken:    "0",
						Host:          "sas3-1449-8be-sas-yp-service-d-419-22443.gencfg-c.yandex.net",
						Ruid:          "ololo",
					}
					return resp, nil
				},
			},
			expected: &resolver.ResolveEndpointsResponse{
				Timestamp: uint64(now.Unix()),
				EndpointSet: &resolver.EndpointSet{
					ID: "trololo",
					Endpoints: []*resolver.Endpoint{
						{
							ID:       "shimba-boomba",
							Protocol: "TCP",
							FQDN:     "myt1-1717-msk-yp-service-discovery-20075.gencfg-c.yandex.net",
							IPv6:     net.ParseIP("2a02:6b8:c08:afa8:10d:bd77:7f89:0"),
							Port:     8080,
							Labels:   []string{"looken-tooken"},
						},
					},
				},
				ResolveStatus: resolver.StatusEndpointOK,
				WatchToken:    "0",
				Host:          "sas3-1449-8be-sas-yp-service-d-419-22443.gencfg-c.yandex.net",
				RUID:          "ololo",
			},
			expectedErr: nil,
		},
	}

	for _, tc := range testCases {
		for _, state := range []string{"WithServiceURI", "WithGRPCClient"} {
			t.Run(tc.name+"_"+state, func(t *testing.T) {
				srv, client, conn, lis := newTestYPSDService(t, tc.server)
				defer lis.Close()
				defer srv.Stop()
				defer conn.Close()

				// Wire the resolver to the test server either by address or
				// by injecting the already-built client.
				opt := WithServiceURI(lis.Addr().String())
				if state == "WithGRPCClient" {
					opt = WithGRPCClient(client, conn)
				}

				r, err := New(
					opt,
					WithClientName("shimba"),
					WithResolveEndpointsRUIDFunc(func(_ *pb.TReqResolveEndpoints) string {
						return "ololo"
					}),
				)
				// require: r is unusable on error and is used right below.
				require.NoError(t, err)
				defer r.Close()

				ctx := context.Background()
				resp, err := r.ResolveEndpoints(ctx, "sas", "go_resolver.test")

				if tc.expectedErr == nil {
					assert.NoError(t, err)
				} else {
					// require first: err.Error() below panics on a nil error.
					require.Error(t, err)
					assert.True(t,
						xerrors.Is(err, tc.expectedErr) ||
							strings.Contains(err.Error(), tc.expectedErr.Error()),
						err.Error(),
					)
				}

				assert.True(t, cmp.Equal(tc.expected, resp), cmp.Diff(tc.expected, resp))
			})
		}
	}
}

// TestResolver_ResolvePods checks that ResolvePods wraps gRPC errors with the
// pod-set ID and cluster, and converts a successful protobuf reply into
// resolver.ResolvePodsResponse. Each case runs against both a resolver that
// dials by address and one given a ready gRPC client.
func TestResolver_ResolvePods(t *testing.T) {
	now := time.Now()

	testCases := []struct {
		name        string
		server      pb.TServiceDiscoveryServiceServer
		expected    *resolver.ResolvePodsResponse
		expectedErr error
	}{
		{
			name: "bad_response",
			server: mockYPSDServer{
				MockResolvePods: func(ctx context.Context, r *pb.TReqResolvePods) (*pb.TRspResolvePods, error) {
					return nil, status.Error(codes.Internal, "unexpected error")
				},
			},
			expected:    nil,
			expectedErr: errors.New("failed to resolve pod-set go_resolver.test@sas: rpc error: code = Internal desc = unexpected error"),
		},
		{
			name: "success",
			server: mockYPSDServer{
				MockResolvePods: func(ctx context.Context, r *pb.TReqResolvePods) (*pb.TRspResolvePods, error) {
					resp := &pb.TRspResolvePods{
						Timestamp: uint64(now.Unix()),
						PodSet: &pb.TPodSet{
							PodSetId: "sas-yt-seneca-sas-nodes-over-yp",
							Pods: []*pb.TPod{
								{
									Id:     "sas2-9558-node-seneca-sas",
									NodeId: "sas2-9558.search.yandex.net",
									Ip6AddressAllocations: []*pb.TPod_TIP6AddressAllocation{
										{
											Address:        "2a02:6b8:fc17:78d:10d:adb5:ccb9:0",
											VlanId:         "fastbone",
											PersistentFqdn: "fb-sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
											TransientFqdn:  "fb-sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
										},
									},
									Dns: &pb.TPod_TDns{
										PersistentFqdn: "sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
										TransientFqdn:  "sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
									},
								},
							},
						},
						ResolveStatus: pb.EResolveStatus_OK,
						Host:          "sas3-1449-8be-sas-yp-service-d-419-22443.gencfg-c.yandex.net",
						Ruid:          "ololo",
					}
					return resp, nil
				},
			},
			expected: &resolver.ResolvePodsResponse{
				Timestamp: uint64(now.Unix()),
				PodSet: &resolver.PodSet{
					ID: "sas-yt-seneca-sas-nodes-over-yp",
					Pods: []*resolver.Pod{
						{
							ID:     "sas2-9558-node-seneca-sas",
							NodeID: "sas2-9558.search.yandex.net",
							IP6AddressAllocations: []resolver.IP6AddressAllocation{
								{
									Address:         "2a02:6b8:fc17:78d:10d:adb5:ccb9:0",
									VLANID:          "fastbone",
									PersistentFQDN:  "fb-sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
									TransientFQDN:   "fb-sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
									VirtualServices: []resolver.VirtualService{},
								},
							},
							IP6SubnetAllocations: []resolver.IP6SubnetAllocation{},
							DNS: &resolver.DNS{
								PersistentFQDN: "sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
								TransientFQDN:  "sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
							},
						},
					},
				},
				ResolveStatus: resolver.StatusPodOK,
				Host:          "sas3-1449-8be-sas-yp-service-d-419-22443.gencfg-c.yandex.net",
				RUID:          "ololo",
			},
			expectedErr: nil,
		},
	}

	for _, tc := range testCases {
		for _, state := range []string{"WithServiceURI", "WithGRPCClient"} {
			t.Run(tc.name+"_"+state, func(t *testing.T) {
				srv, client, conn, lis := newTestYPSDService(t, tc.server)
				defer lis.Close()
				defer srv.Stop()
				defer conn.Close()

				// Wire the resolver to the test server either by address or
				// by injecting the already-built client.
				opt := WithServiceURI(lis.Addr().String())
				if state == "WithGRPCClient" {
					opt = WithGRPCClient(client, conn)
				}

				r, err := New(
					opt,
					WithClientName("shimba"),
					WithResolvePodsRUIDFunc(func(_ *pb.TReqResolvePods) string {
						return "ololo"
					}),
				)
				// require: r is unusable on error and is used right below.
				require.NoError(t, err)
				defer r.Close()

				ctx := context.Background()
				resp, err := r.ResolvePods(ctx, "sas", "go_resolver.test")

				if tc.expectedErr == nil {
					assert.NoError(t, err)
				} else {
					// require first: err.Error() below panics on a nil error.
					require.Error(t, err)
					assert.True(t,
						xerrors.Is(err, tc.expectedErr) ||
							strings.Contains(err.Error(), tc.expectedErr.Error()),
						err.Error(),
					)
				}

				assert.True(t, cmp.Equal(tc.expected, resp), cmp.Diff(tc.expected, resp))
			})
		}
	}
}

// TestResolver_ResolveNode checks that ResolveNode surfaces gRPC errors and
// converts a successful protobuf reply into resolver.ResolveNodeResponse.
// Each case runs against both a resolver that dials by address and one given
// a ready gRPC client.
func TestResolver_ResolveNode(t *testing.T) {
	now := time.Now()

	testCases := []struct {
		name        string
		server      pb.TServiceDiscoveryServiceServer
		expected    *resolver.ResolveNodeResponse
		expectedErr error
	}{
		{
			name: "bad_response",
			server: mockYPSDServer{
				MockResolveNode: func(ctx context.Context, r *pb.TReqResolveNode) (*pb.TRspResolveNode, error) {
					return nil, status.Error(codes.Internal, "unexpected error")
				},
			},
			expected: nil,
			// NOTE(review): this message says "pod-set" although the call
			// resolves a node; the check only passes if the implementation
			// uses the same wording. Confirm that this is intended.
			expectedErr: errors.New("failed to resolve pod-set go_resolver.test@sas: rpc error: code = Internal desc = unexpected error"),
		},
		{
			name: "success",
			server: mockYPSDServer{
				MockResolveNode: func(ctx context.Context, r *pb.TReqResolveNode) (*pb.TRspResolveNode, error) {
					resp := &pb.TRspResolveNode{
						Timestamp: uint64(now.Unix()),
						Pods: []*pb.TPod{
							{
								Id:     "sas2-9558-node-seneca-sas",
								NodeId: "sas2-9558.search.yandex.net",
								Ip6AddressAllocations: []*pb.TPod_TIP6AddressAllocation{
									{
										Address:        "2a02:6b8:fc17:78d:10d:adb5:ccb9:0",
										VlanId:         "fastbone",
										PersistentFqdn: "fb-sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
										TransientFqdn:  "fb-sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
									},
								},
								Dns: &pb.TPod_TDns{
									PersistentFqdn: "sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
									TransientFqdn:  "sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
								},
							},
						},
						ResolveStatus: pb.EResolveStatus_OK,
						Host:          "sas3-1449-8be-sas-yp-service-d-419-22443.gencfg-c.yandex.net",
						Ruid:          "ololo",
					}
					return resp, nil
				},
			},
			expected: &resolver.ResolveNodeResponse{
				Timestamp: uint64(now.Unix()),
				Pods: []*resolver.Pod{
					{
						ID:     "sas2-9558-node-seneca-sas",
						NodeID: "sas2-9558.search.yandex.net",
						IP6AddressAllocations: []resolver.IP6AddressAllocation{
							{
								Address:         "2a02:6b8:fc17:78d:10d:adb5:ccb9:0",
								VLANID:          "fastbone",
								PersistentFQDN:  "fb-sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
								TransientFQDN:   "fb-sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
								VirtualServices: []resolver.VirtualService{},
							},
						},
						IP6SubnetAllocations: []resolver.IP6SubnetAllocation{},
						DNS: &resolver.DNS{
							PersistentFQDN: "sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
							TransientFQDN:  "sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
						},
					},
				},
				ResolveStatus: resolver.StatusPodOK,
				Host:          "sas3-1449-8be-sas-yp-service-d-419-22443.gencfg-c.yandex.net",
				RUID:          "ololo",
			},
			expectedErr: nil,
		},
	}

	for _, tc := range testCases {
		for _, state := range []string{"WithServiceURI", "WithGRPCClient"} {
			t.Run(tc.name+"_"+state, func(t *testing.T) {
				srv, client, conn, lis := newTestYPSDService(t, tc.server)
				defer lis.Close()
				defer srv.Stop()
				defer conn.Close()

				// Wire the resolver to the test server either by address or
				// by injecting the already-built client.
				opt := WithServiceURI(lis.Addr().String())
				if state == "WithGRPCClient" {
					opt = WithGRPCClient(client, conn)
				}

				r, err := New(
					opt,
					WithClientName("shimba"),
					WithResolveNodeRUIDFunc(func(_ *pb.TReqResolveNode) string {
						return "ololo"
					}),
				)
				// require: r is unusable on error and is used right below.
				require.NoError(t, err)
				defer r.Close()

				ctx := context.Background()
				resp, err := r.ResolveNode(ctx, "sas", "go_resolver.test")

				if tc.expectedErr == nil {
					assert.NoError(t, err)
				} else {
					// require first: err.Error() below panics on a nil error.
					require.Error(t, err)
					assert.True(t,
						xerrors.Is(err, tc.expectedErr) ||
							strings.Contains(err.Error(), tc.expectedErr.Error()),
						err.Error(),
					)
				}

				assert.True(t, cmp.Equal(tc.expected, resp), cmp.Diff(tc.expected, resp))
			})
		}
	}
}

// TestResolver_dialGRPCLazyRetrySync checks that sequential ResolveEndpoints
// calls against an unreachable service URI each fail and leave the resolver
// undialed, with no connection or client stored.
func TestResolver_dialGRPCLazyRetrySync(t *testing.T) {
	r, err := New(WithServiceURI("shimba.boomba"))
	// require: r is unusable on error and is used right below.
	require.NoError(t, err)
	defer r.Close()

	for i := 0; i < 3; i++ {
		_, err = r.ResolveEndpoints(context.Background(), resolver.ClusterSAS, "ololo.trololo")
		assert.Error(t, err)
		assert.False(t, r.dialed.Load())
		assert.Nil(t, r.grpcConn)
		assert.Nil(t, r.grpcClient)
	}
}

// TestResolver_dialGRPCLazyRetryAsync checks that concurrent ResolveEndpoints
// calls against an unreachable service URI all fail and leave the resolver
// undialed, with no connection or client stored.
func TestResolver_dialGRPCLazyRetryAsync(t *testing.T) {
	r, err := New(WithServiceURI("shimba.boomba"))
	// require: r is unusable on error and is used right below.
	require.NoError(t, err)
	defer r.Close()

	const callers = 3

	var wg sync.WaitGroup
	wg.Add(callers)
	for i := 0; i < callers; i++ {
		go func() {
			defer wg.Done()

			// Only assert-style checks here: require's FailNow must run on
			// the test goroutine.
			_, err := r.ResolveEndpoints(context.Background(), "kek", "cheburek")
			assert.Error(t, err)
			assert.False(t, r.dialed.Load())
		}()
	}
	wg.Wait()

	// grpcConn and grpcClient are read here without synchronisation, so they
	// are inspected only after every caller has returned. Reading them
	// mid-flight could race with a concurrent dial attempt if that path
	// assigns them.
	assert.False(t, r.dialed.Load())
	assert.Nil(t, r.grpcConn)
	assert.Nil(t, r.grpcClient)
}
