package cfnversion

import (
	"context"
	"os"
	"testing"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/service/iam"
	"github.com/aws/aws-sdk-go/service/iam/iamiface"
	"github.com/pkg/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// mockIAM is a testify-based test double for the IAM client.
//
// It embeds mock.Mock to record calls and serve canned return values, and
// embeds iamiface.IAMAPI so the type satisfies the full IAM interface while
// only overriding the methods the tests need. The tests in this file build
// it as &mockIAM{}, which leaves the embedded interface nil, so calling any
// IAM method that is not overridden on this type would panic.
type mockIAM struct {
	mock.Mock
	iamiface.IAMAPI
}

// ListRoleTagsWithContext forwards the call to the embedded mock and returns
// whatever the matching expectation was configured to return.
//
// The context, the input and every request option are passed to the mock as
// individual arguments, in that order. Return value 0 is expected to be nil or
// a *iam.ListRoleTagsOutput; return value 1 is read with Arguments.Error.
func (m *mockIAM) ListRoleTagsWithContext(ctx context.Context, in *iam.ListRoleTagsInput, opt ...request.Option) (*iam.ListRoleTagsOutput, error) {
	// Flatten everything into a single []interface{} for mock.Called.
	recorded := make([]interface{}, 0, len(opt)+2)
	recorded = append(recorded, ctx, in)
	for _, o := range opt {
		recorded = append(recorded, o)
	}

	results := m.Called(recorded...)

	// An untyped nil output cannot be type-asserted, so it is mapped to a nil
	// pointer explicitly; anything else must be a *iam.ListRoleTagsOutput.
	var out *iam.ListRoleTagsOutput
	if first := results.Get(0); first != nil {
		out = first.(*iam.ListRoleTagsOutput)
	}
	return out, results.Error(1)
}

// TestMain replaces the package-level cache with one that has skipCache set
// before running the tests, then exits with the status reported by m.Run.
//
// The exit code is passed to os.Exit explicitly: Go toolchains older than 1.15
// do not propagate the result of m.Run when TestMain simply returns, which
// would make a failing run exit with status 0.
func TestMain(m *testing.M) {
	cache = versionCache{
		skipCache: true,
	}
	os.Exit(m.Run())
}

// TestVersion is a table-driven check of version.isGreaterThanOrEqual.
//
// Each case gives the {major, minor, patch} triple for the receiver and for
// the argument, plus the expected result of receiver.isGreaterThanOrEqual(arg).
func TestVersion(t *testing.T) {
	// build turns a {major, minor, patch} triple into a *version.
	build := func(parts [3]int) *version {
		return &version{
			major: parts[0],
			minor: parts[1],
			patch: parts[2],
		}
	}

	cases := []struct {
		left    [3]int
		right   [3]int
		wantGTE bool
	}{
		{[3]int{0, 0, 0}, [3]int{0, 0, 0}, true},
		{[3]int{1, 2, 3}, [3]int{0, 0, 0}, true},
		{[3]int{0, 0, 0}, [3]int{1, 2, 3}, false},
		{[3]int{1, 1, 1}, [3]int{1, 1, 2}, false},
		{[3]int{1, 1, 1}, [3]int{1, 2, 1}, false},
		{[3]int{1, 1, 1}, [3]int{2, 1, 1}, false},
		{[3]int{1, 1, 2}, [3]int{1, 1, 1}, true},
		{[3]int{1, 2, 1}, [3]int{1, 1, 1}, true},
		{[3]int{2, 1, 1}, [3]int{1, 1, 1}, true},
	}

	for _, tc := range cases {
		lhs, rhs := build(tc.left), build(tc.right)
		assert.Equal(t, tc.wantGTE, lhs.isGreaterThanOrEqual(rhs))
	}
}

// TestParse is a table-driven check of parse.
//
// For inputs marked expectedError only the presence of an error is checked;
// the parsed* columns (-1 placeholders) are ignored for those rows. For all
// other inputs parse must succeed and yield the listed major/minor/patch.
func TestParse(t *testing.T) {
	tts := []struct {
		s             string
		parsedMajor   int
		parsedMinor   int
		parsedPatch   int
		expectedError bool
	}{
		{"v1.3.0", 1, 3, 0, false},
		{"v10.11.12", 10, 11, 12, false},
		{"v100.011.222", 100, 11, 222, false},
		{"v0.0.0", 0, 0, 0, false},
		{"v.0.0", -1, -1, -1, true},
		{"v1.1", -1, -1, -1, true},
		{"1.3.0", -1, -1, -1, true},
		{"v1.3.0-beta", -1, -1, -1, true},
		{"totally-wrong", -1, -1, -1, true},
	}

	for _, tt := range tts {
		v, err := parse(tt.s)
		if tt.expectedError {
			assert.Error(t, err, "parse(%q)", tt.s)
			continue
		}

		// Check the error and the result before touching the fields, so an
		// unexpected parse failure is reported as a test failure for this
		// input instead of dereferencing a missing result.
		if !assert.NoError(t, err, "parse(%q)", tt.s) || !assert.NotNil(t, v, "parse(%q)", tt.s) {
			continue
		}
		assert.Equal(t, tt.parsedMajor, v.major, "major of %q", tt.s)
		assert.Equal(t, tt.parsedMinor, v.minor, "minor of %q", tt.s)
		assert.Equal(t, tt.parsedPatch, v.patch, "patch of %q", tt.s)
	}
}

func TestGet(t *testing.T) {
	ctx := context.Background()
	t.Run("HappyPath", func(t *testing.T) {
		mockClient := &mockIAM{}
		mockResp := &iam.ListRoleTagsOutput{
			Tags: []*iam.Tag{
				{
					Key:   aws.String("Version"),
					Value: aws.String("v1.3.0"),
				},
			},
		}
		mockClient.On("ListRoleTagsWithContext", mock.Anything, mock.Anything).Return(mockResp, nil)
		v, err := cache.getWithClient(ctx, mockClient)
		require.NoError(t, err)
		assert.Equal(t, 1, v.major)
		assert.Equal(t, 3, v.minor)
		assert.Equal(t, 0, v.patch)
	})

	t.Run("IAMError", func(t *testing.T) {
		mockClient := &mockIAM{}
		mockClient.On("ListRoleTagsWithContext", mock.Anything, mock.Anything).Return(nil, errors.New("oh no something bad"))
		v, err := cache.getWithClient(ctx, mockClient)
		assert.Error(t, err)
		assert.Nil(t, v)
	})

	t.Run("NoVersionTag", func(t *testing.T) {
		mockClient := &mockIAM{}
		mockResp := &iam.ListRoleTagsOutput{
			Tags: []*iam.Tag{
				{
					Key:   aws.String("SomethingThatIsNotVersion"),
					Value: aws.String("v1.3.0"),
				},
				{
					Key:   aws.String("AnotherNonVersionTag"),
					Value: aws.String("totally not relevant"),
				},
			},
		}
		mockClient.On("ListRoleTagsWithContext", mock.Anything, mock.Anything).Return(mockResp, nil)
		v, err := cache.getWithClient(ctx, mockClient)
		require.Equal(t, errVersionNotFound, err)
		assert.Nil(t, v)
	})
}

func TestRequire(t *testing.T) {
	ctx := context.Background()
	t.Run("HappyPathRequirePasses", func(t *testing.T) {
		mockClient := &mockIAM{}
		mockResp := &iam.ListRoleTagsOutput{
			Tags: []*iam.Tag{
				{
					Key:   aws.String("Version"),
					Value: aws.String("v1.3.0"),
				},
			},
		}
		mockClient.On("ListRoleTagsWithContext", mock.Anything, mock.Anything).Return(mockResp, nil)
		for _, requiredVersion := range []string{"v0.1.0", "v1.2.0", "v1.3.0"} {
			b, err := requireWithClient(ctx, mockClient, requiredVersion)
			require.NoError(t, err)
			assert.True(t, b)
		}
	})

	t.Run("HappyPathRequireFails", func(t *testing.T) {
		mockClient := &mockIAM{}
		mockResp := &iam.ListRoleTagsOutput{
			Tags: []*iam.Tag{
				{
					Key:   aws.String("Version"),
					Value: aws.String("v1.5.0"),
				},
			},
		}
		mockClient.On("ListRoleTagsWithContext", mock.Anything, mock.Anything).Return(mockResp, nil)
		for _, requiredVersion := range []string{"v1.6.0", "v2.0.0", "v1.5.1"} {
			b, err := requireWithClient(ctx, mockClient, requiredVersion)
			require.NoError(t, err)
			assert.False(t, b)
		}
	})

	t.Run("HappyPathAccessDenied", func(t *testing.T) {
		mockClient := &mockIAM{}
		mockClient.On("ListRoleTagsWithContext", mock.Anything, mock.Anything).Return(nil, awserr.New("AccessDenied", "access denied", errors.New("fail")))
		b, err := requireWithClient(ctx, mockClient, "v1.2.3")
		require.NoError(t, err)
		assert.False(t, b)
	})

	t.Run("HappyPathVersionNotFound", func(t *testing.T) {
		mockClient := &mockIAM{}
		mockResp := &iam.ListRoleTagsOutput{
			Tags: []*iam.Tag{},
		}
		mockClient.On("ListRoleTagsWithContext", mock.Anything, mock.Anything).Return(mockResp, nil)
		b, err := requireWithClient(ctx, mockClient, "v1.2.3")
		require.NoError(t, err)
		assert.False(t, b)
	})

	t.Run("SadPathGenericError", func(t *testing.T) {
		mockClient := &mockIAM{}
		mockClient.On("ListRoleTagsWithContext", mock.Anything, mock.Anything).Return(nil, errors.New("something bad happened"))
		b, err := requireWithClient(ctx, mockClient, "v1.2.3")
		require.Error(t, err)
		assert.False(t, b)
	})
}

// TestVersionCache checks a versionCache with skipCache disabled: once a
// version has been fetched, further lookups must not call IAM again.
func TestVersionCache(t *testing.T) {
	ctx := context.Background()

	t.Run("HappyPath", func(t *testing.T) {
		c := &versionCache{skipCache: false}
		mockClient := &mockIAM{}
		mockResp := &iam.ListRoleTagsOutput{
			Tags: []*iam.Tag{
				{
					Key:   aws.String("Version"),
					Value: aws.String("v1.3.0"),
				},
			},
		}
		// Only allow the IAM call to happen once: with Once(), a second call
		// into the mock is an unexpected call and fails the test.
		mockClient.On("ListRoleTagsWithContext", mock.Anything, mock.Anything).Return(mockResp, nil).Once()

		// The lookup itself is repeated, so every iteration after the first
		// has to be answered without going back to IAM.
		for i := 0; i < 5; i++ {
			v, err := c.getWithClient(ctx, mockClient)
			require.NoError(t, err)
			require.NotNil(t, v)
			assert.Equal(t, 1, v.major)
			assert.Equal(t, 3, v.minor)
			assert.Equal(t, 0, v.patch)
			mockClient.AssertNumberOfCalls(t, "ListRoleTagsWithContext", 1)
		}
	})

	t.Run("SadPath", func(t *testing.T) {
		c := &versionCache{skipCache: false}
		mockClient := &mockIAM{}
		mockClient.On("ListRoleTagsWithContext", mock.Anything, mock.Anything).Return(nil, errors.New("oh no something bad")).Once()

		v, err := c.getWithClient(ctx, mockClient)
		assert.Error(t, err)
		assert.Nil(t, v)
		mockClient.AssertNumberOfCalls(t, "ListRoleTagsWithContext", 1)
		// NOTE(review): only a single failing lookup is checked here. Whether
		// a failed lookup is remembered or retried on the next call is not
		// visible from this file; extend this subtest once that contract is
		// confirmed against versionCache.getWithClient.
	})
}
