package validation

import (
	"context"

	"github.com/aws/aws-sdk-go/aws/client"
	"github.com/aws/aws-sdk-go/service/sns"
	"github.com/aws/aws-sdk-go/service/sns/snsiface"
	"github.com/pkg/errors"
	"go.uber.org/zap"

	"code.justin.tv/eventbus/controlplane/internal/arn"
	"code.justin.tv/eventbus/controlplane/internal/db"
	"code.justin.tv/eventbus/controlplane/internal/logger"
	"code.justin.tv/eventbus/controlplane/internal/policy"
)

// Report messages produced by Publication.Validate when a publication's
// grantee cannot be matched against its SNS topic policy.
const (
	// ErrPublicationNoPolicyStatement is reported when the topic policy has no
	// statement whose Sid is snsPolicySID.
	ErrPublicationNoPolicyStatement = "sns topic publisher policy not found"

	// ErrPublicationMissingInPolicy is reported when the statement exists but
	// none of its AWS principals resolve to the publication's grantee account.
	ErrPublicationMissingInPolicy = "sns topic policy does not include this publisher"
)

// snsPolicySID is the Sid of the SNS topic policy statement whose AWS
// principals Publication.Validate checks for the grantee account.
const snsPolicySID = "give-publishers-publish"

// Publication is a validation Item for a single stored publication. It pairs
// the db.Publication record with the grantee account ID and event stream
// details resolved by Publications, plus the SNS client and topic ARN that
// Validate uses to fetch the topic policy.
type Publication struct {
	*db.Publication
	// Grantee is the AWS account ID that Validate expects to find among the
	// policy statement's AWS principals.
	Grantee     string
	EventType   string
	Environment string

	// snsTopicARN identifies the topic whose policy Validate inspects.
	snsTopicARN string
	snsClient   snsiface.SNSAPI
}

// ID returns this publication's item identifier, as computed by itemID.
func (e *Publication) ID() string {
	id := itemID(e)
	return id
}

// Type reports the item kind name used for publication validation items.
func (e *Publication) Type() string {
	const kind = "Publication"
	return kind
}

// Attributes returns the descriptive key/value pairs for this publication in
// the fixed order Grantee, EventType, Environment.
func (e *Publication) Attributes() []*ItemAttribute {
	grantee := &ItemAttribute{Key: "Grantee", Value: e.Grantee}
	eventType := &ItemAttribute{Key: "EventType", Value: e.EventType}
	environment := &ItemAttribute{Key: "Environment", Value: e.Environment}

	return []*ItemAttribute{grantee, eventType, environment}
}

// Validate checks that the publication's grantee account is listed among the
// AWS principals of the SNS topic policy statement whose Sid is snsPolicySID.
//
// An error is returned only when the topic policy cannot be fetched. Each of
// the following is reported through the returned *Report, with severity
// chosen by ReportWithEnvironmentSeverity for the publication's environment:
//   - the statement is missing;
//   - no principal resolves to the grantee.
//
// Principals whose account ID cannot be parsed are logged and skipped.
func (e *Publication) Validate(ctx context.Context) (*Report, error) {
	topicPolicy, err := policy.GetTopicPolicy(ctx, e.snsClient, e.snsTopicARN)
	if err != nil {
		return nil, errors.Wrap(err, "could not get topic policy")
	}

	stmt, idx := topicPolicy.FindStatement(snsPolicySID)
	if idx < 0 {
		return ReportWithEnvironmentSeverity(e, e.Environment, ErrPublicationNoPolicyStatement), nil
	}

	for _, principal := range stmt.Principal.AWS {
		accountID, parseErr := arn.AccountID(principal)
		switch {
		case parseErr != nil:
			// Tolerate malformed entries; one bad principal should not hide a match.
			logger.FromContext(ctx).Warn("unable to parse account id in sns policy statement", zap.String("unparsableString", principal))
		case accountID == e.Grantee:
			return Ok(e), nil
		}
	}

	return ReportWithEnvironmentSeverity(e, e.Environment, ErrPublicationMissingInPolicy), nil
}

// Publications loads every publication from db and wraps each one as a
// validation Item bound to its event stream's SNS topic.
//
// The grantee is resolved per publication:
//   - A linked account contributes its AWS account ID.
//   - A linked IAM role contributes the account ID parsed from the role ARN.
//     When both are set, the IAM role takes precedence.
//
// A publication is logged and left out of the result when any of these cannot
// be resolved: its account, its IAM role, its role ARN, or its event stream.
// An error is returned only when loading publications, event streams,
// accounts, or IAM roles fails.
func Publications(ctx context.Context, sess client.ConfigProvider, db db.DB) ([]Item, error) {
	snsClient := sns.New(sess)
	log := logger.FromContext(ctx)

	publications, err := db.Publications(ctx)
	if err != nil {
		return nil, errors.Wrap(err, "could not get publications")
	}

	eventStreams, err := getEventStreamMap(ctx, db)
	if err != nil {
		return nil, errors.Wrap(err, "could not get event streams")
	}

	accounts, err := getAccountMap(ctx, db)
	if err != nil {
		return nil, errors.Wrap(err, "could not get accounts")
	}

	iamRoles, err := getIAMRoleMap(ctx, db)
	if err != nil {
		return nil, errors.Wrap(err, "could not get iam roles")
	}

	// Allocate capacity only, and append below. Pre-sizing the length would
	// leave a nil Item at each skipped index, and calling a method on it panics.
	items := make([]Item, 0, len(publications))
	for _, publication := range publications {
		var grantee string
		if publication.AccountID.Valid {
			acct, found := accounts.Get(int(publication.AccountID.Int64))
			if !found {
				log.Warn("could not find account for publication", zap.Object("publication", publication))
				continue
			}
			grantee = acct.AWSAccountID
		}
		if publication.IAMRoleID.Valid {
			iamRole, found := iamRoles.Get(int(publication.IAMRoleID.Int64))
			if !found {
				log.Warn("could not find iam role for publication", zap.Object("publication", publication))
				continue
			}
			grantee, err = arn.AccountID(iamRole.ARN)
			if err != nil {
				log.Warn("could not parse account id from publication iam role arn",
					zap.Object("publication", publication), zap.Error(err))
				continue
			}
		}

		eventStream, found := eventStreams.Get(publication.EventStreamID)
		if !found {
			log.Warn("could not find event stream for publication", zap.Object("publication", publication))
			continue
		}

		items = append(items, &Publication{
			Publication: publication,
			Grantee:     grantee,
			EventType:   eventStream.EventType.Name,
			Environment: eventStream.Environment,
			snsClient:   snsClient,
			snsTopicARN: eventStream.SNSDetails.SNSTopicARN,
		})
	}

	return items, nil
}
