package policy

import (
	"encoding/json"
	"testing"

	"code.justin.tv/eventbus/controlplane/infrastructure/mocks"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/sns"
	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// testPolicyString is a sample SQS queue policy fixture shared by several tests
// in this file. Its first statement gives "Action" as a JSON array and its
// second gives "Action" as a single string, so decoding of both forms is
// exercised.
const testPolicyString = `
{
	"Version": "2012-10-17",
	"Id": "arn:aws:sqs:us-west-2:297385687169:awesome-cool-queue/SQSDefaultPolicy",
	"Statement": [
		{
			"Sid": "thing",
			"Effect": "Allow",
			"Principal": {
				"AWS": "arn:aws:iam::297385687169:root"
			},
			"Action": [
				"SQS:GetQueueAttributes",
				"SQS:GetQueueUrl"
			],
			"Resource": "arn:aws:sqs:us-west-2:297385687169:awesome-cool-queue"
		},
		{
			"Sid": "MyAmazingLabel",
			"Effect": "Allow",
			"Principal": {
			  "AWS": "arn:aws:iam::859517684765:root"
			},
			"Action": "SQS:SendMessage",
			"Resource": "arn:aws:sqs:us-west-2:297385687169:awesome-cool-queue"
		}
	]
}
`

// TestPolicyMarshaling decodes testPolicyString and checks its statement and
// action counts. It then verifies that the decoded Policy survives a JSON
// round-trip: marshaling it, decoding those marshaled bytes, and marshaling
// again must produce identical output.
func TestPolicyMarshaling(t *testing.T) {
	p := &Policy{}

	err := json.Unmarshal([]byte(testPolicyString), p)
	require.NoError(t, err)

	// require (not assert) so a decode regression fails here instead of
	// panicking on the indexing below.
	require.Len(t, p.Statement, 2)
	assert.Len(t, p.Statement[0].Action, 2) // Action given as a JSON array
	assert.Len(t, p.Statement[1].Action, 1) // Action given as a single string

	// Marshal back to bytes, re-unmarshal, marshal again and make sure the bytes are the same.
	b, err := json.Marshal(p)
	require.NoError(t, err)

	// Decode the marshaled output rather than the original fixture; otherwise
	// the round-trip through our own encoding is never exercised.
	roundTripped := &Policy{}
	err = json.Unmarshal(b, roundTripped)
	require.NoError(t, err)

	b2, err := json.Marshal(roundTripped)
	require.NoError(t, err)

	// Compare as strings so a mismatch prints a readable JSON diff.
	assert.Equal(t, string(b), string(b2))
}

// TestPolicyContainsSID decodes testPolicyString and checks ContainsSID for the
// two statement IDs present in the fixture and for one that is absent.
func TestPolicyContainsSID(t *testing.T) {
	policy := new(Policy)
	require.NoError(t, json.Unmarshal([]byte(testPolicyString), policy))

	expectations := []struct {
		sid     string
		present bool
	}{
		{sid: "thing", present: true},
		{sid: "MyAmazingLabel", present: true},
		{sid: "NonExistentSID", present: false},
	}
	for _, e := range expectations {
		assert.Equal(t, e.present, policy.ContainsSID(e.sid))
	}
}

// TestGetQueuePolicy exercises GetQueuePolicy against a mock SQS client in
// three cases:
//   - the "Policy" queue attribute decodes successfully;
//   - the GetQueueAttributes call errors;
//   - the attribute holds data that is not valid JSON.
func TestGetQueuePolicy(t *testing.T) {
	ctx := mocks.DefaultBehavior()
	queuePolicy := &Policy{
		Statement: []*PolicyStatement{
			{
				SID: "amazing-sid",
			},
		},
	}
	queuePolicyBytes, err := json.Marshal(queuePolicy)
	require.NoError(t, err)

	// Create a mock SQS client whose default GetQueueAttributes behavior returns
	// the policy marshaled above.
	testSQS := mocks.NewMockSQS()
	respQueueAttr := &sqs.GetQueueAttributesOutput{
		Attributes: map[string]*string{
			"Policy": aws.String(string(queuePolicyBytes)),
		},
	}
	testSQS.On("GetQueueAttributesWithContext", mock.MatchedBy(mocks.IsDefaultBehavior(mocks.SQSGetQueueAttributes)), mock.Anything, mock.Anything).Return(respQueueAttr, nil)

	t.Run("HappyPath", func(t *testing.T) {
		p, err := GetQueuePolicy(ctx, testSQS, "validURL")
		// require (not assert): on failure p may be nil, and the indexing below
		// would panic instead of reporting the error.
		require.NoError(t, err)
		require.NotNil(t, p)
		require.NotEmpty(t, p.Statement)
		assert.Equal(t, "amazing-sid", p.Statement[0].SID)
	})

	t.Run("ErrorAPI", func(t *testing.T) {
		// Run the function with a context indicating the mock should error on sqs:GetQueueAttributes
		p, err := GetQueuePolicy(mocks.WithBehavior(ctx, mocks.SQSGetQueueAttributes, mocks.Error), testSQS, "someURL")
		assert.Error(t, err)
		assert.Nil(t, p)
	})

	// Sad path: the attribute holds data that cannot be unmarshaled.
	// This case registers its own custom mock behavior.
	t.Run("ErrorUnmarshal", func(t *testing.T) {
		// Define a custom behavior and return value for GetQueueAttributes
		badReturnData := "sqs-get-queue-attrs-bad-data"
		badResp := &sqs.GetQueueAttributesOutput{
			Attributes: map[string]*string{
				"Policy": aws.String("~you cant unmarshal this!~"),
			},
		}
		// Register the custom behavior to GetQueueAttributes
		testSQS.On("GetQueueAttributesWithContext", mock.MatchedBy(mocks.IsBehavior(mocks.SQSGetQueueAttributes, badReturnData)), mock.Anything, mock.Anything).Return(badResp, nil)
		// Call the function we are testing using our custom behavior
		p, err := GetQueuePolicy(mocks.WithBehavior(ctx, mocks.SQSGetQueueAttributes, badReturnData), testSQS, "validURL")
		assert.Error(t, err)
		assert.Nil(t, p)
	})
}

// TestGetTopicPolicy exercises GetTopicPolicy against a mock SNS client in
// three cases:
//   - the "Policy" topic attribute decodes successfully;
//   - the GetTopicAttributes call errors;
//   - the attribute holds data that is not valid JSON.
func TestGetTopicPolicy(t *testing.T) {
	ctx := mocks.DefaultBehavior()
	topicPolicy := &Policy{
		Statement: []*PolicyStatement{
			{
				SID: "amazing-sid",
			},
		},
	}
	topicPolicyBytes, err := json.Marshal(topicPolicy)
	require.NoError(t, err)

	// Create a mock SNS client whose default GetTopicAttributes behavior returns
	// the policy marshaled above.
	testSNS := mocks.NewMockSNS()
	respTopicAttr := &sns.GetTopicAttributesOutput{
		Attributes: map[string]*string{
			"Policy": aws.String(string(topicPolicyBytes)),
		},
	}
	testSNS.On("GetTopicAttributesWithContext", mock.MatchedBy(mocks.IsDefaultBehavior(mocks.SNSGetTopicAttributes)), mock.Anything, mock.Anything).Return(respTopicAttr, nil)

	t.Run("HappyPath", func(t *testing.T) {
		p, err := GetTopicPolicy(ctx, testSNS, "validARN")
		// require (not assert): on failure p may be nil, and the indexing below
		// would panic instead of reporting the error.
		require.NoError(t, err)
		require.NotNil(t, p)
		require.NotEmpty(t, p.Statement)
		assert.Equal(t, "amazing-sid", p.Statement[0].SID)
	})

	t.Run("ErrorAPI", func(t *testing.T) {
		// Run the function with a context indicating the mock should error on sns:GetTopicAttributes
		p, err := GetTopicPolicy(mocks.WithBehavior(ctx, mocks.SNSGetTopicAttributes, mocks.Error), testSNS, "someARN")
		assert.Error(t, err)
		assert.Nil(t, p)
	})

	// Sad path: the attribute holds data that cannot be unmarshaled.
	// This case registers its own custom mock behavior.
	t.Run("ErrorUnmarshal", func(t *testing.T) {
		// Define a custom behavior and return value for GetTopicAttributes
		badReturnData := "sns-get-topic-attrs-bad-data"
		badResp := &sns.GetTopicAttributesOutput{
			Attributes: map[string]*string{
				"Policy": aws.String("~you cant unmarshal this!~"),
			},
		}
		// Register the custom behavior to GetTopicAttributes
		testSNS.On("GetTopicAttributesWithContext", mock.MatchedBy(mocks.IsBehavior(mocks.SNSGetTopicAttributes, badReturnData)), mock.Anything, mock.Anything).Return(badResp, nil)
		// Call the function we are testing using our custom behavior
		p, err := GetTopicPolicy(mocks.WithBehavior(ctx, mocks.SNSGetTopicAttributes, badReturnData), testSNS, "validARN")
		assert.Error(t, err)
		assert.Nil(t, p)
	})
}

// TestPolicyStatements covers decoding individual policy statements in the
// shapes used by SQS and SNS policies. That includes single-value and list
// forms of Principal.AWS and Action, and ArnEquals/StringEquals conditions.
// It also covers JSON encoding of a statement whose Action is empty.
//
// Preconditions that guard a later pointer dereference or slice index use
// require rather than assert, so a regression fails the subtest cleanly
// instead of panicking.
func TestPolicyStatements(t *testing.T) {
	t.Run("Unmarshal Various Cases", func(t *testing.T) {
		t.Run("Typical SQS own policy", func(t *testing.T) {
			statement, err := unmarshalStatement(`{
				"Sid": "thing",
				"Effect": "Allow",
				"Principal": {"AWS": "arn:aws:iam::1234:root"},
				"Action": ["SQS:GetQueueAttributes","SQS:GetQueueUrl"],
				"Resource": "arn:aws:abc"
			}`)
			require.NoError(t, err)
			assert.Equal(t, "Allow", statement.Effect)
			assert.Equal(t, "arn:aws:abc", statement.Resource)
			assert.Equal(t, OneOrManyString{"SQS:GetQueueAttributes", "SQS:GetQueueUrl"}, statement.Action)
			require.NotNil(t, statement.Principal)
			assert.Len(t, statement.Principal.AWS, 1)
			assert.True(t, statement.Principal.AWS.Contains("arn:aws:iam::1234:root"))
		})

		t.Run("Typical SQS SendMessage policy", func(t *testing.T) {
			statement, err := unmarshalStatement(`{
				"Sid": "abcd",
				"Effect": "Allow",
				"Principal": {"AWS": "*"},
				"Action": "SQS:SendMessage",
				"Resource": "arn:aws:abc",
				"Condition": {
					"ArnEquals": {
						"aws:SourceArn": "arn:aws:sns:us-west-2:1234:eventbus_production_ClockTick"
					}
				}
			}`)
			require.NoError(t, err)
			assert.Equal(t, "Allow", statement.Effect)
			assert.Equal(t, OneOrManyString{"SQS:SendMessage"}, statement.Action)
			require.NotNil(t, statement.Principal)
			require.Len(t, statement.Principal.AWS, 1)
			assert.Equal(t, "*", statement.Principal.AWS[0])
			require.NotNil(t, statement.Condition)
			assert.Len(t, statement.Condition.ArnEquals, 1)
			sourceArnsPtr := statement.Condition.ArnEquals["aws:SourceArn"]
			require.NotNil(t, sourceArnsPtr)
			sourceArns := *sourceArnsPtr
			require.Len(t, sourceArns, 1)
			assert.Equal(t, "arn:aws:sns:us-west-2:1234:eventbus_production_ClockTick", sourceArns[0])
		})

		t.Run("Typical SQS SendMessage policy with multi source ARNs", func(t *testing.T) {
			statement, err := unmarshalStatement(`{
				"Sid": "abcd",
				"Effect": "Allow",
				"Principal": {"AWS": "*"},
				"Action": "SQS:SendMessage",
				"Resource": "arn:aws:abc",
				"Condition": {
					"ArnEquals": {
						"aws:SourceArn": [
							"arn:aws:sns:us-west-2:1234:eventbus_local_ClockTick",
							"arn:aws:sns:us-west-2:1234:eventbus_local_QuoteCreate",
							"arn:aws:sns:us-west-2:1234:eventbus_local_FooBar"
						]
					}
				}
			}`)
			require.NoError(t, err)
			require.NotNil(t, statement.Principal)
			require.Len(t, statement.Principal.AWS, 1)
			assert.Equal(t, "*", statement.Principal.AWS[0])
			require.NotNil(t, statement.Condition)
			assert.Len(t, statement.Condition.ArnEquals, 1)
			sourceArnsPtr := statement.Condition.ArnEquals["aws:SourceArn"]
			require.NotNil(t, sourceArnsPtr)
			sourceArns := *sourceArnsPtr
			require.Len(t, sourceArns, 3)
			assert.Equal(t, "arn:aws:sns:us-west-2:1234:eventbus_local_ClockTick", sourceArns[0])
			assert.Equal(t, "arn:aws:sns:us-west-2:1234:eventbus_local_QuoteCreate", sourceArns[1])
			assert.Equal(t, "arn:aws:sns:us-west-2:1234:eventbus_local_FooBar", sourceArns[2])
		})

		t.Run("Typical SNS Subscribe Policy", func(t *testing.T) {
			statement, err := unmarshalStatement(`{
				"Sid": "target-subscribe-abc",
				"Effect": "Allow",
				"Principal": {
					"AWS": ["arn:aws:iam::1234:root", "arn:aws:iam::5678:root"]
				},
				"Action": "SNS:Subscribe",
				"Resource": "arn:aws:sns:us-west-2:1234:eventbus_local_ClockTick"
			}`)
			require.NoError(t, err)
			require.NotNil(t, statement.Principal)
			assert.Equal(t, OneOrManyString{"arn:aws:iam::1234:root", "arn:aws:iam::5678:root"}, statement.Principal.AWS)
			assert.Nil(t, statement.Condition)
			assert.Equal(t, OneOrManyString{"SNS:Subscribe"}, statement.Action)
		})

		t.Run("SNS Default Statement", func(t *testing.T) {
			statement, err := unmarshalStatement(`{
				"Sid": "__default_statement_ID",
				"Effect": "Allow",
				"Principal": {"AWS": "*"},
				"Action": ["SNS:GetTopicAttributes","SNS:SetTopicAttributes"],
				"Resource": "arn:aws:sns:us-west-2:297385687169:eventbus_production_ClockTick",
				"Condition": {
					"StringEquals": {
						"AWS:SourceOwner": "1234"
					}
				}

			}`)
			require.NoError(t, err)
			require.NotNil(t, statement.Condition)
			assert.Len(t, statement.Condition.StringEquals, 1)
			sourceOwner := statement.Condition.StringEquals["AWS:SourceOwner"]
			require.NotNil(t, sourceOwner)
			assert.Len(t, *sourceOwner, 1)
			assert.True(t, sourceOwner.Contains("1234"))
		})
	})

	t.Run("SpecialMarshaling", func(t *testing.T) {
		t.Run("Empty action", func(t *testing.T) {
			statement := &PolicyStatement{
				SID:    "",
				Effect: "Allow",
				Principal: &PolicyStatementPrincipal{
					AWS: OneOrManyString{"*"},
				},
				Action: OneOrManyString{},
			}
			buf, err := json.Marshal(statement)
			require.NoError(t, err)
			assert.Equal(t, `{"Sid":"","Principal":{"AWS":"*"},"Action":[],"Effect":"Allow"}`, string(buf))
		})
	})
}

// TestOneOrManyString checks OneOrManyString.Remove on the first, middle and
// last element of a list, and on the only element of a single-item list.
func TestOneOrManyString(t *testing.T) {
	t.Run("Remove", func(t *testing.T) {
		cases := []struct {
			target string
			input  OneOrManyString
			want   OneOrManyString
		}{
			{target: "a", input: OneOrManyString{"a", "b", "c"}, want: OneOrManyString{"b", "c"}},
			{target: "b", input: OneOrManyString{"a", "b", "c"}, want: OneOrManyString{"a", "c"}},
			{target: "c", input: OneOrManyString{"a", "b", "c"}, want: OneOrManyString{"a", "b"}},
			{target: "d", input: OneOrManyString{"d"}, want: nil},
		}
		for _, tc := range cases {
			got := tc.input
			got.Remove(tc.target)
			// Removing the last remaining element only asserts emptiness, because
			// the result may be either a nil or a zero-length slice.
			if len(tc.want) == 0 {
				assert.Len(t, got, 0)
				continue
			}
			assert.Equal(t, tc.want, got)
		}
	})
}

// TestPrincipalUnmarshal checks that a statement whose "Principal" is the bare
// string "*" (rather than an object such as {"AWS": ...}) decodes into a
// Principal whose AWS list holds exactly "*".
func TestPrincipalUnmarshal(t *testing.T) {
	validPolicy := `
{
	"Version": "2012-10-17",
	"Statement": [{
		"Effect": "Allow",
		"Principal": "*",
		"Action": "sqs:SendMessage",
		"Resource": "*",
		"Condition":{
			"ArnEquals":{
				"aws:SourceArn": "arn:aws:*:*:${event_bus_account_id}:*"
			}
		}
	}]
}`
	p := &Policy{}

	err := json.Unmarshal([]byte(validPolicy), p)
	require.NoError(t, err)

	// require (not assert) so a decode regression fails here instead of
	// panicking on the index/dereference that follows.
	require.Len(t, p.Statement, 1)
	statement := p.Statement[0]
	assert.Len(t, statement.Action, 1)
	require.NotNil(t, statement.Principal)
	assert.Len(t, statement.Principal.AWS, 1)
	assert.True(t, statement.Principal.AWS.Contains("*"))
}

// unmarshalStatement decodes a single JSON policy statement from input.
// It always returns a non-nil *PolicyStatement, even when decoding fails.
// On error the returned value may be partially populated.
func unmarshalStatement(input string) (*PolicyStatement, error) {
	decoded := new(PolicyStatement)
	if err := json.Unmarshal([]byte(input), decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
