package sns

import (
	"context"
	"strconv"

	"code.justin.tv/eventbus/controlplane/internal/arn"
	"code.justin.tv/eventbus/controlplane/internal/logger"
	"code.justin.tv/eventbus/controlplane/internal/policy"
	"github.com/aws/aws-sdk-go/aws"
	awsarn "github.com/aws/aws-sdk-go/aws/arn"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/client"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/service/sns"
	"github.com/aws/aws-sdk-go/service/sns/snsiface"
	"github.com/pkg/errors"
	"go.uber.org/zap"
)

// CredentialFetcher supplies AWS credentials for acting in the AWS account
// identified by awsAccountID (presumably by assuming a role there, per the
// method name — confirm against implementations).
type CredentialFetcher interface {
	AssumeRoleCredentials(awsAccountID string) *credentials.Credentials
}

// Manager wraps the SNS operations used by the control plane: creating and
// checking topics, editing topic policy statements that grant AWS accounts
// publish or subscribe access, and managing SQS subscriptions to topics.
type Manager struct {
	// baseSession is used for creating clients that use credentials assumed in other accounts
	baseSession            client.ConfigProvider
	// client is the default SNS client, built from baseSession in NewManager.
	client                 snsiface.SNSAPI
	// credentialFetcher supplies credentials for an AWS account ID; assumedClient
	// uses it to act in the account that owns a given resource ARN.
	credentialFetcher      CredentialFetcher
	// encryptionAtRestKeyARN is sent as the KmsMasterKeyId attribute by CreateTopic.
	encryptionAtRestKeyARN string
	// logger may be nil, in which case logInfo does nothing.
	logger                 *logger.Logger
}

// NewManager builds a Manager whose default SNS client is created from sess.
// sess is also kept so per-account clients can later be created with
// credentials from credFetcher. encAtRestKeyARN is used as the KMS key for
// topics created by CreateTopic. logger may be nil to disable logging.
func NewManager(sess client.ConfigProvider, credFetcher CredentialFetcher, encAtRestKeyARN string, logger *logger.Logger) *Manager {
	mgr := &Manager{
		baseSession:            sess,
		credentialFetcher:      credFetcher,
		encryptionAtRestKeyARN: encAtRestKeyARN,
		logger:                 logger,
	}
	mgr.client = sns.New(sess)
	return mgr
}

// AllowAccountPublish adds the IAM root principal of awsAccountID to the
// publish statement of topicARN's policy.
func (m *Manager) AllowAccountPublish(ctx context.Context, topicARN, awsAccountID string) error {
	principal := arn.IAMRootARN(awsAccountID)
	return m.addRootARNToStatement(ctx, topicARN, principal, snsPublishStatementSID)
}

// DisallowAccountPublish removes the IAM root principal of awsAccountID from
// the publish statement of topicARN's policy.
func (m *Manager) DisallowAccountPublish(ctx context.Context, topicARN, awsAccountID string) error {
	principal := arn.IAMRootARN(awsAccountID)
	return m.removeRootARNFromStatement(ctx, topicARN, principal, snsPublishStatementSID)
}

// AllowAccountSubscribe adds the IAM root principal of awsAccountID to the
// subscribe statement of topicARN's policy.
func (m *Manager) AllowAccountSubscribe(ctx context.Context, topicARN, awsAccountID string) error {
	principal := arn.IAMRootARN(awsAccountID)
	return m.addRootARNToStatement(ctx, topicARN, principal, snsSubscribeStatementSID)
}

// DisallowAccountSubscribe removes the IAM root principal of awsAccountID
// from the subscribe statement of topicARN's policy.
func (m *Manager) DisallowAccountSubscribe(ctx context.Context, topicARN, awsAccountID string) error {
	principal := arn.IAMRootARN(awsAccountID)
	return m.removeRootARNFromStatement(ctx, topicARN, principal, snsSubscribeStatementSID)
}

// SubscriptionExists reports whether subscriptionARN is one of topicARN's
// subscriptions. Result pages are walked until a match is found or the last
// page has been processed; any listing error is returned unchanged.
func (m *Manager) SubscriptionExists(ctx context.Context, topicARN, subscriptionARN string) (bool, error) {
	exists := false
	input := &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicARN)}

	err := m.client.ListSubscriptionsByTopicPagesWithContext(ctx, input,
		func(page *sns.ListSubscriptionsByTopicOutput, isLast bool) bool {
			for _, sub := range page.Subscriptions {
				if aws.StringValue(sub.SubscriptionArn) != subscriptionARN {
					continue
				}
				exists = true
				return false // match found; stop requesting further pages
			}
			return !isLast
		})
	if err != nil {
		return false, err
	}
	return exists, nil
}

// Subscribe subscribes the SQS queue queueARN to topicARN and returns the
// resulting subscription ARN. The call is made with a client using
// credentials for the AWS account parsed from queueARN (see assumedClient).
func (m *Manager) Subscribe(ctx context.Context, topicARN, queueARN string) (string, error) {
	snsClient, err := m.assumedClient(queueARN)
	if err != nil {
		return "", errors.Wrap(err, "could not create sns client")
	}

	input := &sns.SubscribeInput{
		Protocol:              aws.String("sqs"),
		TopicArn:              aws.String(topicARN),
		Endpoint:              aws.String(queueARN),
		ReturnSubscriptionArn: aws.Bool(true),
	}
	out, err := snsClient.SubscribeWithContext(ctx, input)
	if err != nil {
		return "", errors.Wrap(err, "could not subscribe")
	}
	return aws.StringValue(out.SubscriptionArn), nil
}

// Unsubscribe deletes the subscription subscriptionARN. The call is made with
// a client using credentials for the AWS account parsed from queueARN.
func (m *Manager) Unsubscribe(ctx context.Context, queueARN, subscriptionARN string) error {
	snsClient, err := m.assumedClient(queueARN)
	if err != nil {
		return errors.Wrap(err, "could not create sns client")
	}

	input := &sns.UnsubscribeInput{SubscriptionArn: aws.String(subscriptionARN)}
	if _, err := snsClient.UnsubscribeWithContext(ctx, input); err != nil {
		return errors.Wrap(err, "could not unsubscribe")
	}
	return nil
}

// CreateTopic issues a CreateTopic request for the topic named
// "eventbus-<environment>-<eventType>". The Manager's encryptionAtRestKeyARN is
// sent as the KmsMasterKeyId attribute. It returns the topic ARN from the
// response; errors from SNS are returned unwrapped.
func (m *Manager) CreateTopic(ctx context.Context, eventType, environment string) (string, error) {
	attrs := map[string]*string{
		"KmsMasterKeyId": aws.String(m.encryptionAtRestKeyARN),
	}
	name := "eventbus-" + environment + "-" + eventType

	resp, err := m.client.CreateTopicWithContext(ctx, &sns.CreateTopicInput{
		Name:       aws.String(name),
		Attributes: attrs,
	})
	if err != nil {
		return "", err
	}
	return aws.StringValue(resp.TopicArn), nil
}

// TopicExists reports whether topicARN names an existing SNS topic by fetching
// its attributes. An SNS NotFound error yields (false, nil); any other error
// is returned to the caller.
func (m *Manager) TopicExists(ctx context.Context, topicARN string) (bool, error) {
	_, err := m.client.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{
		TopicArn: aws.String(topicARN),
	})
	if err == nil {
		return true, nil
	}
	if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == sns.ErrCodeNotFoundException {
		return false, nil
	}
	return false, err
}

// addRootARNToStatement adds rootARN to the AWS principals of the policy
// statement with SID statementID on the topic at topicARN.
//
// The statement is created (with the actions from permissionsForStatement and
// the topic as its resource) if it does not yet exist. If rootARN is already a
// principal, the policy is not rewritten and nil is returned. An error is
// returned when statementID has no entry in permissionsForStatement, when
// adding the principal would exceed snsPolicyMaxPrincipals, or when reading or
// writing the topic policy fails.
func (m *Manager) addRootARNToStatement(ctx context.Context, topicARN, rootARN, statementID string) error {
	topicPolicy, err := policy.GetTopicPolicy(ctx, m.client, topicARN)
	if err != nil {
		return errors.Wrap(err, "cannot get topic policy")
	}

	permissions, found := permissionsForStatement[statementID]
	if !found {
		// err is nil at this point; errors.Wrap(nil, ...) returns nil, which would
		// report success without granting anything. Build a fresh error instead.
		return errors.New("no permissions associated with statement id " + statementID)
	}

	statement, _ := topicPolicy.FindStatement(statementID)
	if statement == nil {
		statement = &policy.PolicyStatement{
			SID:       statementID,
			Effect:    "Allow",
			Principal: &policy.PolicyStatementPrincipal{},
			Action:    policy.OneOrManyString(permissions),
			Resource:  topicARN,
		}
		topicPolicy.Statement = append(topicPolicy.Statement, statement)
	}
	if statement.Principal == nil {
		// An existing statement may lack a Principal block; give it an empty one so
		// its AWS principal list can be inspected and extended without a nil deref.
		statement.Principal = &policy.PolicyStatementPrincipal{}
	}

	if statement.Principal.AWS.Contains(rootARN) {
		m.logInfo(
			"aws account root arn already in topic policy",
			zap.String("topicArn", topicARN),
			zap.String("rootArn", rootARN),
			zap.String("statementId", statementID),
		)
		return nil
	}

	// Check the limit before mutating so the in-memory policy is untouched on error.
	if len(statement.Principal.AWS) >= snsPolicyMaxPrincipals {
		return errors.New("too many principals in sns policy, max " + strconv.Itoa(snsPolicyMaxPrincipals))
	}
	statement.Principal.AWS = append(statement.Principal.AWS, rootARN)

	if err := policy.SetTopicPolicy(ctx, m.client, topicARN, topicPolicy); err != nil {
		return errors.Wrap(err, "could not add permissions to topic policy")
	}

	return nil
}

// removeRootARNFromStatement removes rootARN from the AWS principals of the
// policy statement with SID statementID on the topic at topicARN.
//
// If the statement does not exist, or rootARN is not among its principals, the
// policy is left unchanged and nil is returned. When rootARN is the statement's
// only principal, the whole statement is removed rather than leaving it with no
// principals. Otherwise only rootARN is removed. The updated policy is then
// written back to the topic.
func (m *Manager) removeRootARNFromStatement(ctx context.Context, topicARN, rootARN, statementID string) error {
	topicPolicy, err := policy.GetTopicPolicy(ctx, m.client, topicARN)
	if err != nil {
		return errors.Wrap(err, "could not get topic policy")
	}

	statement, _ := topicPolicy.FindStatement(statementID)
	if statement == nil {
		return nil
	}

	// A statement without a Principal block cannot list rootARN; checking for nil
	// first avoids dereferencing a nil pointer when reading its AWS principals.
	if statement.Principal == nil || !statement.Principal.AWS.Contains(rootARN) {
		m.logInfo(
			"account root arn not found in topic policy",
			zap.String("topicARN", topicARN),
			zap.String("rootArn", rootARN),
			zap.Any("statement", statement),
			zap.Any("policy", topicPolicy),
		)
		return nil
	}

	if len(statement.Principal.AWS) == 1 {
		// we're about to remove the last Principal, have to remove the whole statement
		topicPolicy.RemoveStatement(statementID)
	} else {
		statement.Principal.AWS.Remove(rootARN)
	}

	if err := policy.SetTopicPolicy(ctx, m.client, topicARN, topicPolicy); err != nil {
		return errors.Wrap(err, "could not set topic policy")
	}

	return nil
}

// logInfo emits an info-level message with the given fields when the Manager
// has a logger; with a nil logger it is a no-op.
func (m *Manager) logInfo(msg string, fields ...zap.Field) {
	if m.logger == nil {
		return
	}
	m.logger.Info(msg, fields...)
}

// assumedClient returns an SNS client that acts in the AWS account owning
// resourceARN. The account ID is parsed from the ARN and passed to the
// credential fetcher; the client is built from the base session with the
// returned credentials. An error is returned if resourceARN cannot be parsed.
//
// The parameter is deliberately not named "arn", which would shadow the
// imported internal arn package within this function.
func (m *Manager) assumedClient(resourceARN string) (snsiface.SNSAPI, error) {
	parsed, err := awsarn.Parse(resourceARN)
	if err != nil {
		return nil, errors.Wrap(err, "could not parse arn "+resourceARN)
	}

	creds := m.credentialFetcher.AssumeRoleCredentials(parsed.AccountID)
	return sns.New(m.baseSession, &aws.Config{Credentials: creds}), nil
}
