package sns

import (
	"context"
	"encoding/json"
	"testing"

	"code.justin.tv/eventbus/controlplane/internal/policy"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/service/sns"
	"github.com/aws/aws-sdk-go/service/sns/snsiface"
	"github.com/pkg/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// mockSNS is an in-memory stand-in for the SNS client used by Manager.
// Only the topic-attribute methods defined in this file are implemented;
// calling any other snsiface.SNSAPI method hits the nil embedded interface.
type mockSNS struct {
	// policyMap holds the raw policy document stored for each topic ARN.
	policyMap map[string][]byte

	// SNSAPI is embedded (and left nil) so the mock satisfies the interface.
	snsiface.SNSAPI
}

// GetTopicAttributesWithContext returns the stored policy document for
// input.TopicArn under the "Policy" attribute key. If no policy is stored for
// that ARN, it returns an awserr carrying sns.ErrCodeNotFoundException.
func (m *mockSNS) GetTopicAttributesWithContext(ctx context.Context, input *sns.GetTopicAttributesInput, opts ...request.Option) (*sns.GetTopicAttributesOutput, error) {
	arn := aws.StringValue(input.TopicArn)
	doc, ok := m.policyMap[arn]
	if !ok {
		return nil, awserr.New(sns.ErrCodeNotFoundException, "not found", errors.New("not found"))
	}

	attrs := map[string]string{"Policy": string(doc)}
	return &sns.GetTopicAttributesOutput{Attributes: aws.StringMap(attrs)}, nil
}

// SetTopicAttributesWithContext stores input.AttributeValue as the policy
// document for input.TopicArn. input.AttributeName is not inspected, so any
// attribute write replaces the topic's stored policy.
func (m *mockSNS) SetTopicAttributesWithContext(ctx context.Context, input *sns.SetTopicAttributesInput, opts ...request.Option) (*sns.SetTopicAttributesOutput, error) {
	arn, value := aws.StringValue(input.TopicArn), aws.StringValue(input.AttributeValue)
	m.policyMap[arn] = []byte(value)

	return &sns.SetTopicAttributesOutput{}, nil
}

// assignPolicy seeds the mock with p, JSON-encoded, as the policy for topicARN.
// It panics if p cannot be marshaled, because that means the fixture is broken.
func (m *mockSNS) assignPolicy(topicARN string, p *policy.Policy) {
	encoded, err := json.Marshal(p)
	if err == nil {
		m.policyMap[topicARN] = encoded
		return
	}
	panic(err)
}

// TestAddRootARNToStatement exercises Manager.addRootARNToStatement and
// Manager.removeRootARNFromStatement against the in-memory mockSNS.
//
// The subtests are stateful and run in order: each one builds on the topic
// policy left behind by the previous one. Do not mark them t.Parallel, and do
// not expect them to pass when selected individually with -run.
func TestAddRootARNToStatement(t *testing.T) {
	ctx := context.Background()
	topicARN := "topic-arn"
	rootARN1 := "root-arn"
	rootARN2 := "new-root-arn"
	const existingSID = "amazing-sid"

	mockSNS := &mockSNS{
		policyMap: make(map[string][]byte),
	}
	manager := &Manager{
		client: mockSNS,
	}

	// Seed the topic with a pre-existing statement unrelated to publishing.
	// Editing the publish statement should leave it in place.
	mockSNS.assignPolicy(topicARN, &policy.Policy{
		Statement: []*policy.PolicyStatement{
			{
				SID: existingSID,
			},
		},
	})

	// storedPolicy decodes the policy document the mock currently holds for topicARN.
	storedPolicy := func(t *testing.T) *policy.Policy {
		t.Helper()
		p := &policy.Policy{}
		require.NoError(t, json.Unmarshal(mockSNS.policyMap[topicARN], p))
		return p
	}

	t.Run("AddPublishPermission", func(t *testing.T) {
		t.Run("First publisher", func(t *testing.T) {
			// require rather than assert: later subtests depend on this write.
			err := manager.addRootARNToStatement(ctx, topicARN, rootARN1, snsPublishStatementSID)
			require.NoError(t, err)

			p := storedPolicy(t)
			assert.True(t, p.ContainsSID(existingSID), "pre-existing statement should be preserved")
			require.True(t, p.ContainsSID(snsPublishStatementSID))

			// The second result of FindStatement is unused; a missing statement
			// is caught by the nil check so the field access below cannot panic.
			s, _ := p.FindStatement(snsPublishStatementSID)
			require.NotNil(t, s)
			assert.True(t, s.Principal.AWS.Contains(rootARN1))
		})

		t.Run("Second publisher", func(t *testing.T) {
			err := manager.addRootARNToStatement(ctx, topicARN, rootARN2, snsPublishStatementSID)
			require.NoError(t, err)

			s, _ := storedPolicy(t).FindStatement(snsPublishStatementSID)
			require.NotNil(t, s)
			assert.True(t, s.Principal.AWS.Contains(rootARN2))
			assert.Len(t, s.Principal.AWS, 2)

			assert.Equal(t, []string{"sns:Publish"}, []string(s.Action))
		})

		t.Run("Idempotent", func(t *testing.T) {
			// Re-adding an ARN that is already present must not duplicate it.
			err := manager.addRootARNToStatement(ctx, topicARN, rootARN2, snsPublishStatementSID)
			require.NoError(t, err)

			s, _ := storedPolicy(t).FindStatement(snsPublishStatementSID)
			require.NotNil(t, s)
			assert.True(t, s.Principal.AWS.Contains(rootARN2))
			assert.Len(t, s.Principal.AWS, 2)
		})
	})

	t.Run("RemovePublishPermission", func(t *testing.T) {
		t.Run("Remove when more than one principal", func(t *testing.T) {
			err := manager.removeRootARNFromStatement(ctx, topicARN, rootARN1, snsPublishStatementSID)
			require.NoError(t, err)

			s, _ := storedPolicy(t).FindStatement(snsPublishStatementSID)
			require.NotNil(t, s)
			assert.False(t, s.Principal.AWS.Contains(rootARN1))
			assert.Len(t, s.Principal.AWS, 1)
		})

		t.Run("Idempotent", func(t *testing.T) {
			// Removing an ARN that is already absent must succeed and change nothing.
			err := manager.removeRootARNFromStatement(ctx, topicARN, rootARN1, snsPublishStatementSID)
			require.NoError(t, err)

			s, _ := storedPolicy(t).FindStatement(snsPublishStatementSID)
			require.NotNil(t, s)
			assert.False(t, s.Principal.AWS.Contains(rootARN1))
			assert.Len(t, s.Principal.AWS, 1)
		})

		t.Run("Last principal in statement", func(t *testing.T) {
			err := manager.removeRootARNFromStatement(ctx, topicARN, rootARN2, snsPublishStatementSID)
			require.NoError(t, err)

			p := storedPolicy(t)
			// Removing the last principal drops the whole publish statement...
			s, _ := p.FindStatement(snsPublishStatementSID)
			assert.Nil(t, s)
			// ...but must not touch unrelated statements.
			assert.True(t, p.ContainsSID(existingSID), "pre-existing statement should be preserved")
		})
	})
}
