package httpresolver

import (
	"context"
	"encoding/json"
	"errors"
	"io/ioutil"
	"net"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
	"time"

	"github.com/go-resty/resty/v2"
	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"
	"github.com/stretchr/testify/assert"

	pb "a.yandex-team.ru/infra/yp_service_discovery/api"
	"a.yandex-team.ru/infra/yp_service_discovery/golang/resolver"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/library/go/httputil/headers"
)

// TestNew verifies that New with no options yields a resolver that targets the
// production service-discovery HTTP endpoint, uses getClientName() as the
// client name, logs through a no-op logger and owns a resty client bound to
// the service URI.
func TestNew(t *testing.T) {
	got, err := New()
	assert.NoError(t, err)

	defaultURI := "http://" + resolver.ServiceDiscoveryHostProd + ":" + resolver.ServiceDiscoveryHTTPPort
	want := &Resolver{
		serviceURI: defaultURI,
		clientName: getClientName(),
		logger:     new(nop.Logger),
		httpc:      resty.New().SetHostURL(got.serviceURI).SetLogger(got.logger.Fmt()),
	}

	// resty.Client internals (including its marshaller funcs) and the RUID
	// callbacks cannot be compared meaningfully, so they are excluded here;
	// the endpoints RUID callback is checked for presence separately below.
	cmpOpts := cmp.Options{
		cmp.AllowUnexported(Resolver{}),
		cmpopts.IgnoreUnexported(resty.Client{}),
		cmpopts.IgnoreFields(resty.Client{}, "JSONMarshal", "JSONUnmarshal", "XMLMarshal", "XMLUnmarshal"),
		cmpopts.IgnoreTypes(ResolveEndpointsRUIDFunc(nil)),
		cmpopts.IgnoreTypes(ResolvePodsRUIDFunc(nil)),
		cmpopts.IgnoreTypes(ResolveNodeRUIDFunc(nil)),
	}

	diff := cmp.Diff(want, got, cmpOpts...)
	assert.True(t, cmp.Equal(want, got, cmpOpts...), diff)
	assert.NotNil(t, got.resolveEndpointsRUIDFunc)
}

// TestResolver_Close checks that closing a freshly constructed resolver
// reports no error and clears its HTTP client.
func TestResolver_Close(t *testing.T) {
	res, err := New()
	assert.NoError(t, err)

	assert.NoError(t, res.Close())
	assert.Nil(t, res.httpc)
}

// Test_getClientName checks how getClientName builds "<user>@<hostname>" from
// the SUDO_USER and USER environment variables, falling back to
// "go_resolver" when neither is set.
//
// The cases mutate process-wide environment variables. The original values
// are captured up front and restored when the test returns, so later tests
// in the package observe an unchanged environment.
func Test_getClientName(t *testing.T) {
	hostname, err := os.Hostname()
	assert.NoError(t, err)

	// restoreEnv snapshots key now and returns a func that puts it back
	// (re-setting the old value, or unsetting it if it was absent).
	restoreEnv := func(key string) func() {
		orig, present := os.LookupEnv(key)
		return func() {
			if present {
				_ = os.Setenv(key, orig)
			} else {
				_ = os.Unsetenv(key)
			}
		}
	}
	defer restoreEnv("SUDO_USER")()
	defer restoreEnv("USER")()

	testCases := []struct {
		name      string
		bootstrap func()
		expected  string
	}{
		{
			name: "unknown_user",
			bootstrap: func() {
				_ = os.Unsetenv("SUDO_USER")
				_ = os.Unsetenv("USER")
			},
			expected: "go_resolver@" + hostname,
		},
		{
			name: "sudo_user_env",
			bootstrap: func() {
				_ = os.Setenv("SUDO_USER", "volozh")
				_ = os.Unsetenv("USER")
			},
			expected: "volozh@" + hostname,
		},
		{
			name: "user_env",
			bootstrap: func() {
				_ = os.Unsetenv("SUDO_USER")
				_ = os.Setenv("USER", "volozh")
			},
			expected: "volozh@" + hostname,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			tc.bootstrap()
			assert.Equal(t, tc.expected, getClientName())
		})
	}
}

// TestResolver_ResolveEndpoints exercises Resolver.ResolveEndpoints against a
// local httptest server in three scenarios:
//   - the server drops the connection (a net.Error is expected),
//   - the server answers 500 (an "unsupported status code" error is expected),
//   - the server answers with a JSON TRspResolveEndpoints, which must be
//     converted into the matching resolver.ResolveEndpointsResponse.
//
// Each bootstrap receives the running subtest's *testing.T, so assertions
// made inside the HTTP handler are reported against the correct subtest.
func TestResolver_ResolveEndpoints(t *testing.T) {
	now := time.Now()

	testCases := []struct {
		name        string
		bootstrap   func(t *testing.T) *httptest.Server
		expected    *resolver.ResolveEndpointsResponse
		expectedErr errChecker
	}{
		{
			name: "net_error",
			bootstrap: func(t *testing.T) *httptest.Server {
				var srv *httptest.Server
				srv = httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
					assert.Equal(t, "/resolve_endpoints/json", r.RequestURI)
					srv.CloseClientConnections() // close connection to provoke error
				}))
				return srv
			},
			expected:    nil,
			expectedErr: isNetError(),
		},
		{
			name: "non-2xx response",
			bootstrap: func(t *testing.T) *httptest.Server {
				srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					assert.Equal(t, "/resolve_endpoints/json", r.RequestURI)

					expectedBody, err := json.Marshal(pb.TReqResolveEndpoints{
						ClusterName:   "sas",
						EndpointSetId: "go_resolver.test",
						ClientName:    "shimba",
						Ruid:          "ololo",
					})
					assert.NoError(t, err)

					body, err := ioutil.ReadAll(r.Body)
					assert.NoError(t, err)
					assert.Equal(t, expectedBody, body)

					w.WriteHeader(http.StatusInternalServerError)
				}))
				return srv
			},
			expected:    nil,
			expectedErr: isError(errors.New("unsupported status code: 500")),
		},
		{
			name: "success",
			bootstrap: func(t *testing.T) *httptest.Server {
				srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					assert.Equal(t, "/resolve_endpoints/json", r.RequestURI)

					expectedBody, err := json.Marshal(pb.TReqResolveEndpoints{
						ClusterName:   "sas",
						EndpointSetId: "go_resolver.test",
						ClientName:    "shimba",
						Ruid:          "ololo",
					})
					assert.NoError(t, err)

					body, err := ioutil.ReadAll(r.Body)
					assert.NoError(t, err)
					assert.Equal(t, expectedBody, body)

					resp, err := json.Marshal(pb.TRspResolveEndpoints{
						Timestamp: uint64(now.Unix()),
						EndpointSet: &pb.TEndpointSet{
							EndpointSetId: "trololo",
							Endpoints: []*pb.TEndpoint{
								{
									Id:                   "shimba-boomba",
									Protocol:             "TCP",
									Fqdn:                 "myt1-1717-msk-yp-service-discovery-20075.gencfg-c.yandex.net",
									Ip6Address:           "2a02:6b8:c08:afa8:10d:bd77:7f89:0",
									Port:                 8080,
									LabelSelectorResults: []string{"looken-tooken"},
								},
							},
						},
						ResolveStatus: pb.EResolveStatus_OK,
						WatchToken:    "0",
						Host:          "sas3-1449-8be-sas-yp-service-d-419-22443.gencfg-c.yandex.net",
						Ruid:          "ololo",
					})
					assert.NoError(t, err)

					w.Header().Set(headers.ContentTypeKey, headers.TypeApplicationJSON.String())
					_, err = w.Write(resp)
					assert.NoError(t, err)
				}))
				return srv
			},
			expected: &resolver.ResolveEndpointsResponse{
				Timestamp: uint64(now.Unix()),
				EndpointSet: &resolver.EndpointSet{
					ID: "trololo",
					Endpoints: []*resolver.Endpoint{
						{
							ID:       "shimba-boomba",
							Protocol: "TCP",
							FQDN:     "myt1-1717-msk-yp-service-discovery-20075.gencfg-c.yandex.net",
							IPv6:     net.ParseIP("2a02:6b8:c08:afa8:10d:bd77:7f89:0"),
							Port:     8080,
							Labels:   []string{"looken-tooken"},
						},
					},
				},
				ResolveStatus: resolver.StatusEndpointOK,
				WatchToken:    "0",
				Host:          "sas3-1449-8be-sas-yp-service-d-419-22443.gencfg-c.yandex.net",
				RUID:          "ololo",
			},
			expectedErr: nil,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			srv := tc.bootstrap(t)
			// Close blocks until outstanding requests complete, so handler
			// assertions never run after this subtest has returned.
			defer srv.Close()

			r, err := New(
				WithServiceURI(srv.URL),
				WithClientName("shimba"),
				WithResolveEndpointsRUIDFunc(func(_ *pb.TReqResolveEndpoints) string {
					return "ololo"
				}),
			)
			if !assert.NoError(t, err) {
				return // r is unusable; bail out instead of panicking on nil
			}

			ctx := context.Background()
			resp, err := r.ResolveEndpoints(ctx, "sas", "go_resolver.test")

			if tc.expectedErr == nil {
				assert.NoError(t, err)
			} else if assert.Error(t, err) {
				// Only hand non-nil errors to the checker: checkers may call err.Error().
				assert.True(t, tc.expectedErr(err), "unexpected error %T: %+v", err, err)
			}

			assert.True(t, cmp.Equal(tc.expected, resp), cmp.Diff(tc.expected, resp))
		})
	}
}

// TestResolver_ResolvePods exercises Resolver.ResolvePods against a local
// httptest server in three scenarios:
//   - the server drops the connection (a net.Error is expected),
//   - the server answers 500 (an "unsupported status code" error is expected),
//   - the server answers with a JSON TRspResolvePods, which must be converted
//     into the matching resolver.ResolvePodsResponse.
//
// Each bootstrap receives the running subtest's *testing.T, so assertions
// made inside the HTTP handler are reported against the correct subtest.
func TestResolver_ResolvePods(t *testing.T) {
	now := time.Now()

	testCases := []struct {
		name        string
		bootstrap   func(t *testing.T) *httptest.Server
		expected    *resolver.ResolvePodsResponse
		expectedErr errChecker
	}{
		{
			name: "net_error",
			bootstrap: func(t *testing.T) *httptest.Server {
				var srv *httptest.Server
				srv = httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
					assert.Equal(t, "/resolve_pods/json", r.RequestURI)
					srv.CloseClientConnections() // close connection to provoke error
				}))
				return srv
			},
			expected:    nil,
			expectedErr: isNetError(),
		},
		{
			name: "non-2xx response",
			bootstrap: func(t *testing.T) *httptest.Server {
				srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					assert.Equal(t, "/resolve_pods/json", r.RequestURI)

					expectedBody, err := json.Marshal(pb.TReqResolvePods{
						ClusterName: "sas",
						PodSetId:    "go_resolver.test",
						ClientName:  "shimba",
						Ruid:        "ololo",
					})
					assert.NoError(t, err)

					body, err := ioutil.ReadAll(r.Body)
					assert.NoError(t, err)
					assert.Equal(t, expectedBody, body)

					w.WriteHeader(http.StatusInternalServerError)
				}))
				return srv
			},
			expected:    nil,
			expectedErr: isError(errors.New("unsupported status code: 500")),
		},
		{
			name: "success",
			bootstrap: func(t *testing.T) *httptest.Server {
				srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					assert.Equal(t, "/resolve_pods/json", r.RequestURI)

					expectedBody, err := json.Marshal(pb.TReqResolvePods{
						ClusterName: "sas",
						PodSetId:    "go_resolver.test",
						ClientName:  "shimba",
						Ruid:        "ololo",
					})
					assert.NoError(t, err)

					body, err := ioutil.ReadAll(r.Body)
					assert.NoError(t, err)
					assert.Equal(t, expectedBody, body)

					resp, err := json.Marshal(pb.TRspResolvePods{
						Timestamp: uint64(now.Unix()),
						PodSet: &pb.TPodSet{
							PodSetId: "sas-yt-seneca-sas-nodes-over-yp",
							Pods: []*pb.TPod{
								{
									Id:     "sas2-9558-node-seneca-sas",
									NodeId: "sas2-9558.search.yandex.net",
									Ip6AddressAllocations: []*pb.TPod_TIP6AddressAllocation{
										{
											Address:        "2a02:6b8:fc17:78d:10d:adb5:ccb9:0",
											VlanId:         "fastbone",
											PersistentFqdn: "fb-sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
											TransientFqdn:  "fb-sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
										},
									},
									Dns: &pb.TPod_TDns{
										PersistentFqdn: "sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
										TransientFqdn:  "sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
									},
								},
							},
						},
						ResolveStatus: pb.EResolveStatus_OK,
						Host:          "sas3-1449-8be-sas-yp-service-d-419-22443.gencfg-c.yandex.net",
						Ruid:          "ololo",
					})
					assert.NoError(t, err)

					w.Header().Set(headers.ContentTypeKey, headers.TypeApplicationJSON.String())
					_, err = w.Write(resp)
					assert.NoError(t, err)
				}))
				return srv
			},
			expected: &resolver.ResolvePodsResponse{
				Timestamp: uint64(now.Unix()),
				PodSet: &resolver.PodSet{
					ID: "sas-yt-seneca-sas-nodes-over-yp",
					Pods: []*resolver.Pod{
						{
							ID:     "sas2-9558-node-seneca-sas",
							NodeID: "sas2-9558.search.yandex.net",
							IP6AddressAllocations: []resolver.IP6AddressAllocation{
								{
									Address:        "2a02:6b8:fc17:78d:10d:adb5:ccb9:0",
									VLANID:         "fastbone",
									PersistentFQDN: "fb-sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
									TransientFQDN:  "fb-sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
								},
							},
							DNS: &resolver.DNS{
								PersistentFQDN: "sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
								TransientFQDN:  "sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
							},
						},
					},
				},
				ResolveStatus: resolver.StatusPodOK,
				Host:          "sas3-1449-8be-sas-yp-service-d-419-22443.gencfg-c.yandex.net",
				RUID:          "ololo",
			},
			expectedErr: nil,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			srv := tc.bootstrap(t)
			// Close blocks until outstanding requests complete, so handler
			// assertions never run after this subtest has returned.
			defer srv.Close()

			r, err := New(
				WithServiceURI(srv.URL),
				WithClientName("shimba"),
				WithResolvePodsRUIDFunc(func(_ *pb.TReqResolvePods) string {
					return "ololo"
				}),
			)
			if !assert.NoError(t, err) {
				return // r is unusable; bail out instead of panicking on nil
			}

			ctx := context.Background()
			resp, err := r.ResolvePods(ctx, "sas", "go_resolver.test")

			if tc.expectedErr == nil {
				assert.NoError(t, err)
			} else if assert.Error(t, err) {
				// Only hand non-nil errors to the checker: checkers may call err.Error().
				assert.True(t, tc.expectedErr(err), "unexpected error %T: %+v", err, err)
			}

			assert.True(t, cmp.Equal(tc.expected, resp), cmp.Diff(tc.expected, resp))
		})
	}
}

// TestResolver_ResolveNode exercises Resolver.ResolveNode against a local
// httptest server in three scenarios:
//   - the server drops the connection (a net.Error is expected),
//   - the server answers 500 (an "unsupported status code" error is expected),
//   - the server answers with a JSON TRspResolveNode, which must be converted
//     into the matching resolver.ResolveNodeResponse.
//
// Each bootstrap receives the running subtest's *testing.T, so assertions
// made inside the HTTP handler are reported against the correct subtest.
func TestResolver_ResolveNode(t *testing.T) {
	now := time.Now()

	testCases := []struct {
		name        string
		bootstrap   func(t *testing.T) *httptest.Server
		expected    *resolver.ResolveNodeResponse
		expectedErr errChecker
	}{
		{
			name: "net_error",
			bootstrap: func(t *testing.T) *httptest.Server {
				var srv *httptest.Server
				srv = httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
					assert.Equal(t, "/resolve_node/json", r.RequestURI)
					srv.CloseClientConnections() // close connection to provoke error
				}))
				return srv
			},
			expected:    nil,
			expectedErr: isNetError(),
		},
		{
			name: "non-2xx response",
			bootstrap: func(t *testing.T) *httptest.Server {
				srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					assert.Equal(t, "/resolve_node/json", r.RequestURI)

					expectedBody, err := json.Marshal(pb.TReqResolveNode{
						ClusterName: "sas",
						NodeId:      "go_resolver.test",
						ClientName:  "shimba",
						Ruid:        "ololo",
					})
					assert.NoError(t, err)

					body, err := ioutil.ReadAll(r.Body)
					assert.NoError(t, err)
					assert.Equal(t, expectedBody, body)

					w.WriteHeader(http.StatusInternalServerError)
				}))
				return srv
			},
			expected:    nil,
			expectedErr: isError(errors.New("unsupported status code: 500")),
		},
		{
			name: "success",
			bootstrap: func(t *testing.T) *httptest.Server {
				srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					assert.Equal(t, "/resolve_node/json", r.RequestURI)

					expectedBody, err := json.Marshal(pb.TReqResolveNode{
						ClusterName: "sas",
						NodeId:      "go_resolver.test",
						ClientName:  "shimba",
						Ruid:        "ololo",
					})
					assert.NoError(t, err)

					body, err := ioutil.ReadAll(r.Body)
					assert.NoError(t, err)
					assert.Equal(t, expectedBody, body)

					resp, err := json.Marshal(pb.TRspResolveNode{
						Timestamp: uint64(now.Unix()),
						Pods: []*pb.TPod{
							{
								Id:     "sas2-9558-node-seneca-sas",
								NodeId: "sas2-9558.search.yandex.net",
								Ip6AddressAllocations: []*pb.TPod_TIP6AddressAllocation{
									{
										Address:        "2a02:6b8:fc17:78d:10d:adb5:ccb9:0",
										VlanId:         "fastbone",
										PersistentFqdn: "fb-sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
										TransientFqdn:  "fb-sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
									},
								},
								Dns: &pb.TPod_TDns{
									PersistentFqdn: "sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
									TransientFqdn:  "sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
								},
							},
						},
						ResolveStatus: pb.EResolveStatus_OK,
						Host:          "sas3-1449-8be-sas-yp-service-d-419-22443.gencfg-c.yandex.net",
						Ruid:          "ololo",
					})
					assert.NoError(t, err)

					w.Header().Set(headers.ContentTypeKey, headers.TypeApplicationJSON.String())
					_, err = w.Write(resp)
					assert.NoError(t, err)
				}))
				return srv
			},
			expected: &resolver.ResolveNodeResponse{
				Timestamp: uint64(now.Unix()),
				Pods: []*resolver.Pod{
					{
						ID:     "sas2-9558-node-seneca-sas",
						NodeID: "sas2-9558.search.yandex.net",
						IP6AddressAllocations: []resolver.IP6AddressAllocation{
							{
								Address:        "2a02:6b8:fc17:78d:10d:adb5:ccb9:0",
								VLANID:         "fastbone",
								PersistentFQDN: "fb-sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
								TransientFQDN:  "fb-sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
							},
						},
						DNS: &resolver.DNS{
							PersistentFQDN: "sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
							TransientFQDN:  "sas2-9558-1.sas2-9558-node-seneca-sas.sas.yp-c.yandex.net",
						},
					},
				},
				ResolveStatus: resolver.StatusPodOK,
				Host:          "sas3-1449-8be-sas-yp-service-d-419-22443.gencfg-c.yandex.net",
				RUID:          "ololo",
			},
			expectedErr: nil,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			srv := tc.bootstrap(t)
			// Close blocks until outstanding requests complete, so handler
			// assertions never run after this subtest has returned.
			defer srv.Close()

			r, err := New(
				WithServiceURI(srv.URL),
				WithClientName("shimba"),
				WithResolveNodeRUIDFunc(func(_ *pb.TReqResolveNode) string {
					return "ololo"
				}),
			)
			if !assert.NoError(t, err) {
				return // r is unusable; bail out instead of panicking on nil
			}

			ctx := context.Background()
			resp, err := r.ResolveNode(ctx, "sas", "go_resolver.test")

			if tc.expectedErr == nil {
				assert.NoError(t, err)
			} else if assert.Error(t, err) {
				// Only hand non-nil errors to the checker: checkers may call err.Error().
				assert.True(t, tc.expectedErr(err), "unexpected error %T: %+v", err, err)
			}

			assert.True(t, cmp.Equal(tc.expected, resp), cmp.Diff(tc.expected, resp))
		})
	}
}

// errChecker is the predicate the table-driven tests use to decide whether
// an error returned by the code under test is the expected one.
type errChecker func(err error) bool

// isNetError returns a checker reporting whether an error, or any error in
// its wrap chain, implements net.Error (errors.As semantics). A nil error
// never matches.
func isNetError() func(error) bool {
	matches := func(candidate error) bool {
		var asNet net.Error
		return errors.As(candidate, &asNet)
	}
	return matches
}

// isError returns a checker reporting whether an error matches target.
//
// An error matches if errors.Is(err, target) holds (identity or wrap chain),
// or if err's message contains target's message. The message fallback
// presumably exists because the code under test builds fresh error values
// with the same text rather than returning a sentinel.
//
// Nil values are handled explicitly instead of panicking on Error():
// a nil err never matches a non-nil target, and a nil target matches only a
// nil err.
func isError(target error) func(error) bool {
	return func(err error) bool {
		if err == nil || target == nil {
			return err == target
		}
		return errors.Is(err, target) || strings.Contains(err.Error(), target.Error())
	}
}
