package apiserver

import (
	"context"
	"errors"
	"reflect"
	"testing"

	"github.com/gofrs/uuid"
	"github.com/golang/mock/gomock"
	grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/stretchr/testify/suite"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"

	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/library/go/httputil/headers"
	"a.yandex-team.ru/library/go/test/grpctest"
	"a.yandex-team.ru/library/go/test/grpctest/testproto"
	"a.yandex-team.ru/library/go/yandex/blackbox"
	bbmocks "a.yandex-team.ru/library/go/yandex/blackbox/mocks"
	"a.yandex-team.ru/tasklet/experimental/internal/xgrpc"
	"a.yandex-team.ru/tasklet/experimental/internal/yandex/sandbox"

	"a.yandex-team.ru/tasklet/experimental/internal/consts"
	"a.yandex-team.ru/tasklet/experimental/internal/requestctx"
	testutils "a.yandex-team.ru/tasklet/experimental/internal/test_utils"
)

// TestMiddlewareAuth checks Middleware.auth both with authentication disabled
// (the subject comes from the test-user metadata key or TestDefaultUser) and
// with authentication enabled against a mocked blackbox client (OAuth tokens,
// Session_id cookies, forwarded-for / forwarded-host handling and error mapping).
func TestMiddlewareAuth(t *testing.T) {

	noAuthTests := []struct {
		name string
		md   metadata.MD
		want requestctx.AuthSubject
	}{
		{
			"no_auth_with_header",
			metadata.Pairs(xgrpc.TestUserMetadataKey, "imperator"),
			requestctx.NewUser("imperator"),
		},
		{
			"no_auth",
			metadata.Pairs(),
			requestctx.NewUser(TestDefaultUser),
		},
	}
	for _, tt := range noAuthTests {
		t.Run(
			tt.name, func(t *testing.T) {
				ctx := context.Background()
				mw, err := NewMiddleware(&MiddlewareConf{Auth: false}, &nop.Logger{}, nil)
				require.NoError(t, err)

				got, err := mw.auth(ctx, tt.md, "8.8.8.8")
				require.NoError(t, err)

				if !reflect.DeepEqual(got, tt.want) {
					t.Errorf("auth() got = %v, want %v", got, tt.want)
				}
			},
		)
	}

	authTests := []struct {
		name      string
		md        metadata.MD
		mockSetup func(*testing.T, *bbmocks.MockClient)
		want      requestctx.AuthSubject
		// err is the expected error returned by auth; nil means success is expected.
		err error
	}{
		{
			"token_auth",
			metadata.Pairs(headers.AuthorizationKey, oauthPrefix+" blah_token "),
			func(tt *testing.T, bb *bbmocks.MockClient) {
				bb.EXPECT().
					OAuth(gomock.Any(), gomock.Any()).
					DoAndReturn(
						func(ctx context.Context, request blackbox.OAuthRequest) (*blackbox.OAuthResponse, error) {
							require.Equal(tt, "blah_token", request.OAuthToken)
							require.Equal(tt, "8.8.8.8", request.UserIP)
							return &blackbox.OAuthResponse{
								User: blackbox.User{Login: "imperator"},
							}, nil
						},
					).
					Times(1)
			},
			requestctx.NewUser("imperator"),
			nil,
		},
		{
			"bb_fail",
			metadata.Pairs(headers.AuthorizationKey, oauthPrefix+"blah_token"),
			func(tt *testing.T, bb *bbmocks.MockClient) {
				bb.EXPECT().
					OAuth(gomock.Any(), gomock.Any()).
					Return(nil, errors.New("OBLOM ERROR")).
					Times(1)
			},
			requestctx.NewInvalid(),
			stAuthError.Err(),
		},
		{
			"bad_token",
			metadata.Pairs(headers.AuthorizationKey, oauthPrefix+"blah_token"),
			func(tt *testing.T, bb *bbmocks.MockClient) {
				bb.EXPECT().
					OAuth(gomock.Any(), gomock.Any()).
					Return(nil, &blackbox.UnauthorizedError{}).
					Times(1)
			},
			requestctx.NewInvalid(),
			stAuthFailed.Err(),
		},
		{
			"bad_header",
			metadata.Pairs(headers.AuthorizationKey, "ZZAuth BAD BAD BAD"),
			func(tt *testing.T, bb *bbmocks.MockClient) {
				// An unrecognized scheme must be rejected without calling blackbox.
				bb.EXPECT().
					OAuth(gomock.Any(), gomock.Any()).
					Times(0)
			},
			requestctx.NewInvalid(),
			stAuthError.Err(),
		},
		{
			"proxyfied_token_auth",
			metadata.Pairs(
				GrpcGatewayHeaderPrefix+headers.AuthorizationKey, oauthPrefix+" blah_token ",
				consts.ForwardedForHeader, "9.9.9.9",
				consts.ForwardedForHeader, "7.7.7.7",
			),
			func(tt *testing.T, bb *bbmocks.MockClient) {
				bb.EXPECT().
					OAuth(gomock.Any(), gomock.Any()).
					DoAndReturn(
						func(ctx context.Context, request blackbox.OAuthRequest) (*blackbox.OAuthResponse, error) {
							require.Equal(tt, "blah_token", request.OAuthToken)
							// The first forwarded-for value wins over the peer address.
							require.Equal(tt, "9.9.9.9", request.UserIP)
							return &blackbox.OAuthResponse{
								User: blackbox.User{Login: "imperator"},
							}, nil
						},
					).
					Times(1)
			},
			requestctx.NewUser("imperator"),
			nil,
		},
		{
			"proxyfied_token_auth_no_forwarded_ip",
			metadata.Pairs(
				GrpcGatewayHeaderPrefix+headers.AuthorizationKey, oauthPrefix+" blah_token ",
			),
			func(tt *testing.T, bb *bbmocks.MockClient) {
				bb.EXPECT().
					OAuth(gomock.Any(), gomock.Any()).
					DoAndReturn(
						func(ctx context.Context, request blackbox.OAuthRequest) (*blackbox.OAuthResponse, error) {
							require.Equal(tt, "blah_token", request.OAuthToken)
							// Without forwarded-for, the peer address passed to auth is used.
							require.Equal(tt, "8.8.8.8", request.UserIP)
							return &blackbox.OAuthResponse{
								User: blackbox.User{Login: "imperator"},
							}, nil
						},
					).
					Times(1)
			},
			requestctx.NewUser("imperator"),
			nil,
		},
		{
			"session_id_auth",
			metadata.Pairs(
				"x-forwarded-host", "badums",
				"grpcgateway-cookie", "yandexuid=1543; Session_id=secret; foo=bar",
				consts.ForwardedForHeader, "9.9.9.9",
			),
			func(tt *testing.T, bb *bbmocks.MockClient) {
				bb.EXPECT().
					SessionID(gomock.Any(), gomock.Any()).
					DoAndReturn(
						func(ctx context.Context, request blackbox.SessionIDRequest) (
							*blackbox.SessionIDResponse,
							error,
						) {
							require.Equal(tt, "secret", request.SessionID)
							require.Equal(tt, "9.9.9.9", request.UserIP)
							require.Equal(tt, "badums", request.Host)
							return &blackbox.SessionIDResponse{
								User: blackbox.User{Login: "benny_hill"},
							}, nil
						},
					).
					Times(1)
			},
			requestctx.NewUser("benny_hill"),
			nil,
		},
		{
			"session_id_auth_multi_forward",
			metadata.Pairs(
				"x-forwarded-host", "badums",
				"grpcgateway-cookie", "yandexuid=1543; Session_id=secret; foo=bar",
				consts.ForwardedForHeader, "9.9.9.9, 1.1.1.1",
				consts.ForwardedForHeader, "3.3.3.3",
			),
			func(tt *testing.T, bb *bbmocks.MockClient) {
				bb.EXPECT().
					SessionID(gomock.Any(), gomock.Any()).
					DoAndReturn(
						func(ctx context.Context, request blackbox.SessionIDRequest) (
							*blackbox.SessionIDResponse,
							error,
						) {
							require.Equal(tt, "secret", request.SessionID)
							require.Equal(tt, "9.9.9.9", request.UserIP)
							require.Equal(tt, "badums", request.Host)
							return &blackbox.SessionIDResponse{
								User: blackbox.User{Login: "benny_hill"},
							}, nil
						},
					).
					Times(1)
			},
			requestctx.NewUser("benny_hill"),
			nil,
		},
		{
			"bad_host_header",
			metadata.Pairs(
				consts.ForwardedHostHeader, "[2a02:6b8:c14:6c9e:0:522:b3b2:1]",
				"grpcgateway-cookie", "yandexuid=1543; Session_id=secret; foo=bar",
				consts.ForwardedForHeader, "9.9.9.9",
			),
			func(tt *testing.T, bb *bbmocks.MockClient) {
				bb.EXPECT().
					SessionID(gomock.Any(), gomock.Any()).
					DoAndReturn(
						func(ctx context.Context, request blackbox.SessionIDRequest) (
							*blackbox.SessionIDResponse,
							error,
						) {
							require.Equal(tt, "secret", request.SessionID)
							require.Equal(tt, "9.9.9.9", request.UserIP)
							// A forwarded host that is a bare IP literal is replaced by the default host.
							require.Equal(tt, "tasklets.in.yandex-team.ru", request.Host)
							return &blackbox.SessionIDResponse{
								User: blackbox.User{Login: "benny_hill"},
							}, nil
						},
					).
					Times(1)
			},
			requestctx.NewUser("benny_hill"),
			nil,
		},
	}
	for _, testCase := range authTests {
		t.Run(
			testCase.name, func(t *testing.T) {
				mw, err := NewMiddleware(&MiddlewareConf{Auth: false}, &nop.Logger{}, nil)
				require.NoError(t, err)

				ctx := context.Background()
				ctrl := gomock.NewController(t)
				defer ctrl.Finish()

				// Auth is enabled only after construction and the blackbox client
				// is swapped for the mock; presumably this keeps NewMiddleware from
				// building a real blackbox client — confirm.
				bbCli := bbmocks.NewMockClient(ctrl)
				mw.conf.Auth = true
				mw.bb = bbCli
				testCase.mockSetup(t, bbCli)
				got, err := mw.auth(ctx, testCase.md, "8.8.8.8")

				// Check both directions: an expected error must actually be
				// returned, and an unexpected one must fail the test.
				if testCase.err != nil {
					require.EqualError(t, err, testCase.err.Error())
				} else {
					require.NoError(t, err)
				}
				if !reflect.DeepEqual(got, testCase.want) {
					t.Errorf("auth() got = %v, want %v", got, testCase.want)
				}
			},
		)
	}
}

// NopSandboxAuth is an in-memory SandboxSessionChecker stub for tests.
// Lookups consult the *OK map first and then the *Err map. A session present
// in neither map indicates a misconfigured test case and causes a panic.
type NopSandboxAuth struct {
	// sandboxSessionsOK holds sandbox sessions that check successfully.
	sandboxSessionsOK          map[sandbox.SandboxSession]sandbox.SessionInfo
	// sandboxSessionsErr holds sandbox sessions whose check returns an error.
	sandboxSessionsErr         map[sandbox.SandboxSession]error
	// externalSandboxSessionsOK holds external sessions that check successfully.
	externalSandboxSessionsOK  map[sandbox.SandboxExternalSession]sandbox.ExternalSessionInfo
	// externalSandboxSessionsErr holds external sessions whose check returns an error.
	externalSandboxSessionsErr map[sandbox.SandboxExternalSession]error
}

// CheckSandboxSession returns the canned SessionInfo or error configured for
// session. It panics with a descriptive message if session was not configured,
// so an unexpected lookup fails the test loudly.
func (n *NopSandboxAuth) CheckSandboxSession(_ context.Context, session sandbox.SandboxSession) (
	sandbox.SessionInfo,
	error,
) {
	if info, ok := n.sandboxSessionsOK[session]; ok {
		return info, nil
	}
	if err, ok := n.sandboxSessionsErr[session]; ok {
		return sandbox.SessionInfo{}, err
	}
	panic("NopSandboxAuth: unexpected sandbox session " + string(session))
}

// CheckExternalSession returns the canned ExternalSessionInfo or error
// configured for session. It panics with a descriptive message if session
// was not configured, so an unexpected lookup fails the test loudly.
func (n *NopSandboxAuth) CheckExternalSession(_ context.Context, session sandbox.SandboxExternalSession) (
	sandbox.ExternalSessionInfo,
	error,
) {
	if info, ok := n.externalSandboxSessionsOK[session]; ok {
		return info, nil
	}
	if err, ok := n.externalSandboxSessionsErr[session]; ok {
		return sandbox.ExternalSessionInfo{}, err
	}
	panic("NopSandboxAuth: unexpected external sandbox session " + string(session))
}

// TestMiddlewareAuthSandbox checks Middleware.auth for sandbox-session and
// external-sandbox-session authorization headers, using NopSandboxAuth as the
// session checker.
func TestMiddlewareAuthSandbox(t *testing.T) {
	authTests := []struct {
		name    string
		md      metadata.MD
		checker SandboxSessionChecker
		want    requestctx.AuthSubject
		// err is the expected error returned by auth; nil means success is expected.
		err error
	}{
		{
			"sandbox_ok",
			metadata.Pairs(headers.AuthorizationKey, sandboxSessionPrefix+" blah"),
			&NopSandboxAuth{
				sandboxSessionsOK: map[sandbox.SandboxSession]sandbox.SessionInfo{
					"blah": {
						Token:  "blah",
						Login:  "imperator",
						TaskID: 1543,
					},
				},
			},
			requestctx.NewSandboxTask(1543),
			nil,
		},
		{
			"sandbox_err",
			metadata.Pairs(headers.AuthorizationKey, sandboxSessionPrefix+" blah"),
			&NopSandboxAuth{
				sandboxSessionsErr: map[sandbox.SandboxSession]error{
					"blah": sandbox.ErrSandboxNotFound,
				},
			},
			requestctx.NewInvalid(),
			stAuthFailed.Err(),
		},
		{
			"sandbox_unexpected",
			metadata.Pairs(headers.AuthorizationKey, sandboxSessionPrefix+" blah"),
			&NopSandboxAuth{
				sandboxSessionsErr: map[sandbox.SandboxSession]error{
					"blah": errors.New("unexpected error"),
				},
			},
			// A failed check must not yield an authenticated subject.
			requestctx.NewInvalid(),
			stAuthError.Err(),
		},
		{
			"external_ok",
			metadata.Pairs(headers.AuthorizationKey, sandboxExternalSessionPrefix+" blah"),
			&NopSandboxAuth{
				externalSandboxSessionsOK: map[sandbox.SandboxExternalSession]sandbox.ExternalSessionInfo{
					"blah": {
						Token:       "blah",
						ExecutionID: "fcd4c272-e016-11ec-bc18-2be5e755e8d7",
						TaskID:      1543,
					},
				},
			},
			requestctx.NewExecutionID("fcd4c272-e016-11ec-bc18-2be5e755e8d7"),
			nil,
		},
		{
			"external_fail",
			metadata.Pairs(headers.AuthorizationKey, sandboxExternalSessionPrefix+" blah"),
			&NopSandboxAuth{
				externalSandboxSessionsErr: map[sandbox.SandboxExternalSession]error{
					"blah": sandbox.ErrSandboxNotFound,
				},
			},
			requestctx.NewInvalid(),
			stAuthFailed.Err(),
		},
		{
			"external_err",
			metadata.Pairs(headers.AuthorizationKey, sandboxExternalSessionPrefix+" blah"),
			&NopSandboxAuth{
				externalSandboxSessionsErr: map[sandbox.SandboxExternalSession]error{
					"blah": errors.New("oblom error"),
				},
			},
			requestctx.NewInvalid(),
			stAuthError.Err(),
		},
	}
	for _, testCase := range authTests {
		t.Run(
			testCase.name, func(t *testing.T) {
				mw, err := NewMiddleware(&MiddlewareConf{Auth: false}, &nop.Logger{}, testCase.checker)
				require.NoError(t, err)
				ctx := context.Background()
				// Auth is switched on only after construction, mirroring how
				// the middleware is prepared for the blackbox tests.
				mw.conf.Auth = true

				got, err := mw.auth(ctx, testCase.md, "8.8.8.8")

				// Verify the error in both directions, then always verify the subject.
				if testCase.err != nil {
					assert.EqualError(t, err, testCase.err.Error())
				} else {
					assert.NoError(t, err)
				}
				assert.EqualValues(t, testCase.want, got)
			},
		)
	}
}

// TestFeatureFlags checks that Middleware.setFeatureFlags decodes each
// TaskletFeaturePrefix-ed metadata value as JSON into the request context,
// and that values which are not valid JSON are left unset.
func TestFeatureFlags(t *testing.T) {
	featureTests := []struct {
		name  string
		md    metadata.MD
		check func(*testing.T, context.Context)
	}{
		{
			name: "simple_feature",
			md:   metadata.Pairs(TaskletFeaturePrefix+"feat1", "{\"foo\": \"bar\"}"),
			check: func(tt *testing.T, ctx context.Context) {
				got := requestctx.GetFeature(ctx, "feat1")
				require.NotNil(tt, got)
				require.Equal(tt, map[string]interface{}{"foo": "bar"}, got)
			},
		},
		{
			name: "multy_feature",
			md: metadata.Pairs(
				TaskletFeaturePrefix+"feat1", "\"good\"",
				TaskletFeaturePrefix+"feat2", "1543",
				TaskletFeaturePrefix+"feat3", "[1543, 57]",
				TaskletFeaturePrefix+"feat4", "true",
			),
			check: func(tt *testing.T, ctx context.Context) {
				// JSON numbers decode to float64 and arrays to []interface{}.
				expectations := []struct {
					feature string
					want    interface{}
				}{
					{"feat1", "good"},
					{"feat2", float64(1543)},
					{"feat3", []interface{}{float64(1543), float64(57)}},
					{"feat4", true},
				}
				for _, exp := range expectations {
					got := requestctx.GetFeature(ctx, exp.feature)
					require.NotNil(tt, got)
					require.Equal(tt, exp.want, got)
				}
			},
		},
		{
			name: "invalid_features",
			md: metadata.Pairs(
				TaskletFeaturePrefix+"feat1", "{\"foo\": inv,alid}",
				TaskletFeaturePrefix+"feat2", "{1543}",
				TaskletFeaturePrefix+"feat3", "bad",
				TaskletFeaturePrefix+"feat4", "{\"a\": \"b\"}",
				TaskletFeaturePrefix+"feat5", "{1543zzz}",
			),
			check: func(tt *testing.T, ctx context.Context) {
				// Malformed values are dropped without affecting valid neighbours.
				for _, broken := range []string{"feat1", "feat2", "feat3", "feat5"} {
					require.Nil(tt, requestctx.GetFeature(ctx, broken))
				}
				require.Equal(tt, map[string]interface{}{"a": "b"}, requestctx.GetFeature(ctx, "feat4"))
			},
		},
		{
			name: "complex_feature",
			md: metadata.Pairs(
				TaskletFeaturePrefix+"feat1",
				"{\"foo\": {\"bar\": \"baz\"}, \"z\": \"x\"}",
			),
			check: func(tt *testing.T, ctx context.Context) {
				got := requestctx.GetFeature(ctx, "feat1")
				require.NotNil(tt, got)
				require.Equal(
					tt,
					map[string]interface{}{
						"foo": map[string]interface{}{"bar": "baz"},
						"z":   "x",
					},
					got,
				)
			},
		},
	}

	for _, tc := range featureTests {
		t.Run(tc.name, func(t *testing.T) {
			mw, err := NewMiddleware(&MiddlewareConf{Auth: false}, &nop.Logger{}, nil)
			require.NoError(t, err)

			tc.check(t, mw.setFeatureFlags(context.Background(), tc.md))
		})
	}
}

// pingRequest is the request sent by the RequestIDServerTestSuite cases.
var pingRequest = &testproto.PingRequest{
	Value:       "something",
	SleepTimeMs: 9999,
}

// RequestIDServerTestSuite runs ping requests through a gRPC server configured
// with the request-ID interceptor and inspects the response header.
type RequestIDServerTestSuite struct {
	*grpctest.InterceptorTestSuite
}

// TestUnaryNoReqID checks that a request without a request ID gets exactly one
// generated request ID in the response header, and that it parses as a UUID.
func (s *RequestIDServerTestSuite) TestUnaryNoReqID() {
	var header metadata.MD
	pingResponse, err := s.Client.Ping(context.Background(), pingRequest, grpc.Header(&header))
	s.Require().NoError(err)
	s.Require().Equal(pingRequest.Value, pingResponse.Value)

	requestIDs := header.Get(consts.RequestIDHeader)
	s.Require().Len(requestIDs, 1)
	// Report the parse error directly instead of wrapping it in a panic check.
	_, err = uuid.FromString(requestIDs[0])
	s.Require().NoError(err, "generated request ID is not a valid UUID")
}

// TestUnaryWithReqID checks that a client-supplied request ID is echoed back
// unchanged as the single request ID in the response header.
func (s *RequestIDServerTestSuite) TestUnaryWithReqID() {
	requestID := uuid.Must(uuid.NewV4()).String()
	var header metadata.MD
	pingResponse, err := s.Client.Ping(
		metadata.NewOutgoingContext(
			context.Background(),
			metadata.Pairs(consts.RequestIDHeader, requestID),
		),
		pingRequest,
		grpc.Header(&header),
	)
	s.Require().NoError(err)
	s.Require().Equal(pingRequest.Value, pingResponse.Value)

	requestIDs := header.Get(consts.RequestIDHeader)
	s.Require().Len(requestIDs, 1)
	s.Require().Equal(requestID, requestIDs[0])
}

// TestRequestIDTestSuite runs RequestIDServerTestSuite against a test gRPC
// server whose only unary interceptor is the request-ID generator.
func TestRequestIDTestSuite(t *testing.T) {
	// NB: Requires bundled certificate resource
	testutils.EnsureArcadiaTest(t)

	logger := testutils.MakeLogger("/dev/null")
	requestIDInterceptor := grpcmiddleware.ChainUnaryServer(
		UnaryRequestIDGenerator(logger.WithName("request_id_middleware")),
	)
	interceptorSuite := &grpctest.InterceptorTestSuite{
		ServerOpts: []grpc.ServerOption{grpc.UnaryInterceptor(requestIDInterceptor)},
	}

	suite.Run(t, &RequestIDServerTestSuite{InterceptorTestSuite: interceptorSuite})
}
