package policy

import (
	"context"
	"encoding/json"
	"fmt"

	kmsclient "code.justin.tv/eventbus/controlplane/internal/clients/kms"
	"code.justin.tv/eventbus/controlplane/internal/containers"
	"code.justin.tv/eventbus/controlplane/internal/logger"
	"code.justin.tv/eventbus/controlplane/internal/sqsutil"
	"code.justin.tv/eventbus/controlplane/internal/uuid"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/kms"
	"github.com/aws/aws-sdk-go/service/sns"
	"github.com/aws/aws-sdk-go/service/sns/snsiface"
	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
	"github.com/pkg/errors"
	"go.uber.org/zap"
)

// sqsDefaultPolicyStatementSid is the Sid given to the single statement that
// DefaultQueuePolicy places in a new queue policy.
const sqsDefaultPolicyStatementSid = "eventbus"

// This file provides types for marshaling and unmarshaling AWS policy documents,
// along with helpers that read and write those policies on SQS queues, SNS
// topics, and KMS keys.

// Policy mirrors the top-level JSON structure of an AWS policy document. In
// this file it is used for SQS queue, SNS topic and KMS key policies.
//
// Version is the policy language version (DefaultQueuePolicy sets
// "2008-10-17"). ID is serialized as the "Id" element. Version and Id are
// always emitted, even when empty. Statement holds the policy statements and
// is omitted from JSON output when empty.
type Policy struct {
	Version   string             `json:"Version"`
	ID        string             `json:"Id"`
	Statement []*PolicyStatement `json:"Statement,omitempty"`
}

// PolicyStatementPrincipal is the "Principal" element of a policy statement.
// AWS and Service hold the values found under the "AWS" and "Service" keys.
// In JSON each may be a single string or an array (see OneOrManyString), and
// each is omitted from JSON output when empty.
//
// The type's UnmarshalJSON method also accepts a bare string principal
// (e.g. "*") and stores it in AWS. Because there is no custom marshaler, such a
// principal is written back in object form ({"AWS": ...}).
type PolicyStatementPrincipal struct {
	AWS     OneOrManyString `json:"AWS,omitempty"`
	Service OneOrManyString `json:"Service,omitempty"`
}

// PolicyStatementCondition is the "Condition" element of a policy statement.
// Each map is keyed by condition key (e.g. "aws:SourceArn", as used by
// DefaultQueuePolicy) and maps to one or more values.
//
// NOTE(review): only the ArnEquals and StringEquals operators are modeled.
// encoding/json ignores unknown keys when decoding. Any other condition
// operator in a fetched policy is therefore dropped, and is lost if the policy
// is written back with one of the Set*Policy helpers.
type PolicyStatementCondition struct {
	ArnEquals    map[string]*OneOrManyString `json:"ArnEquals,omitempty"`
	StringEquals map[string]*OneOrManyString `json:"StringEquals,omitempty"`
}

// PolicyStatement is one entry of a policy document's "Statement" array.
// Sid, Action and Effect are always emitted. Principal, Resource and Condition
// are omitted from JSON output when nil or empty.
//
// NOTE(review): Resource is modeled as a single string, so json.Unmarshal
// returns an error for a statement whose "Resource" is a JSON array. Statement
// elements not modeled here (e.g. NotAction, NotPrincipal) are silently
// dropped on decode.
type PolicyStatement struct {
	SID       string                    `json:"Sid"`
	Principal *PolicyStatementPrincipal `json:"Principal,omitempty"`
	Action    OneOrManyString           `json:"Action"`
	Effect    string                    `json:"Effect"`
	Resource  string                    `json:"Resource,omitempty"`
	Condition *PolicyStatementCondition `json:"Condition,omitempty"`
}

// OneOrManyString is a list of strings that, in JSON, may appear either as a
// single string or as an array of strings. It is used here for policy fields
// such as Action, the Principal values and condition values. Its UnmarshalJSON
// and MarshalJSON methods handle both forms.
type OneOrManyString []string

// UnmarshalJSON accepts either a JSON array of strings or a single JSON
// string. A single string is stored as a one-element list.
//
// JSON null leaves the receiver unchanged, following the encoding/json
// convention for Unmarshaler implementations. Any other JSON value (number,
// bool, object) returns an error.
func (s *OneOrManyString) UnmarshalJSON(b []byte) error {
	switch {
	case string(b) == "null":
		return nil
	case len(b) > 0 && b[0] == '[':
		var values []string
		if err := json.Unmarshal(b, &values); err != nil {
			return err
		}
		*s = values
		return nil
	default:
		// Decode through encoding/json rather than slicing off the quotes, so
		// escape sequences (\/, \u0026, \") are resolved and non-string input
		// produces an error instead of a garbage value or a slice-bounds panic.
		var single string
		if err := json.Unmarshal(b, &single); err != nil {
			return err
		}
		*s = OneOrManyString{single}
		return nil
	}
}

// MarshalJSON emits a bare JSON string when exactly one value is present and
// a JSON array otherwise. A nil list encodes as null and an empty non-nil list
// as [].
//
// It uses a value receiver so the compact single-string form is also produced
// for non-addressable values, e.g. a OneOrManyString field inside a struct
// passed to json.Marshal by value. *OneOrManyString still implements
// json.Marshaler through its method set.
func (s OneOrManyString) MarshalJSON() ([]byte, error) {
	if len(s) == 1 {
		return json.Marshal(s[0])
	}
	// Convert to []string so encoding/json does not recurse back into this method.
	return json.Marshal([]string(s))
}

// Remove removes query from the receiver according to
// containers.StringArray.Remove. The receiver is converted as a pointer, so
// that method can update the underlying slice in place.
func (s *OneOrManyString) Remove(query string) {
	asArray := (*containers.StringArray)(s)
	asArray.Remove(query)
}

// Contains reports whether query is one of the values, as decided by
// containers.StringArray.Contains.
func (s OneOrManyString) Contains(query string) bool {
	asArray := containers.StringArray(s)
	return asArray.Contains(query)
}

// ContainsEqualFold reports whether query matches one of the values, as
// decided by containers.StringArray.ContainsEqualFold (presumably a
// case-insensitive comparison, per its name).
func (s OneOrManyString) ContainsEqualFold(query string) bool {
	asArray := containers.StringArray(s)
	return asArray.ContainsEqualFold(query)
}

// UnmarshalJSON decodes a statement's "Principal" element. The element may be
// a bare JSON string (e.g. "*"), which is stored in AWS, or an object with
// "AWS" and/or "Service" keys.
//
// JSON null is a no-op, following the encoding/json convention for
// Unmarshaler implementations. Empty input and any other JSON type are
// rejected with an error instead of panicking.
func (p *PolicyStatementPrincipal) UnmarshalJSON(b []byte) error {
	// Guard the b[0] lookup below against empty input on direct calls.
	if len(b) == 0 {
		return errors.New("unrecognized value for Principal")
	}

	switch b[0] {
	case 'n':
		if string(b) == "null" {
			return nil
		}
	case '"':
		return json.Unmarshal(b, &p.AWS)
	case '{':
		var dest struct {
			AWS     OneOrManyString `json:"AWS"`
			Service OneOrManyString `json:"Service"`
		}
		err := json.Unmarshal(b, &dest)
		// Assigned even when err != nil, so a partial decode is still reflected in p.
		p.AWS = dest.AWS
		p.Service = dest.Service
		return err
	}

	return errors.New("unrecognized value for Principal")
}

// ContainsSID reports whether any statement in the policy has a SID equal to
// sid.
func (p *Policy) ContainsSID(sid string) bool {
	found := false
	for i := 0; i < len(p.Statement) && !found; i++ {
		found = p.Statement[i].SID == sid
	}
	return found
}

// FindStatement looks for the first statement in p whose SID equals sid.
// It returns that statement and its index in p.Statement, or (nil, -1) when
// no statement matches.
func (p *Policy) FindStatement(sid string) (*PolicyStatement, int) {
	for idx := 0; idx < len(p.Statement); idx++ {
		if stmt := p.Statement[idx]; stmt.SID == sid {
			return stmt, idx
		}
	}
	return nil, -1
}

// RemoveStatement removes the first statement whose SID equals sid from the
// policy, preserving the order of the remaining statements. It returns true
// if such a statement was found and removed, and false otherwise.
func (p *Policy) RemoveStatement(sid string) bool {
	_, idx := p.FindStatement(sid)
	if idx < 0 {
		return false
	}

	last := len(p.Statement) - 1
	// Shift the tail left over the removed element, then nil out the now
	// unused final slot. Otherwise the backing array would keep a stale
	// *PolicyStatement reachable and prevent it from being garbage-collected.
	copy(p.Statement[idx:], p.Statement[idx+1:])
	p.Statement[last] = nil
	p.Statement = p.Statement[:last]
	return true
}

// DefaultQueuePolicy returns a new SQS queue policy containing one "Allow"
// statement with Sid sqsDefaultPolicyStatementSid. The statement lets any AWS
// principal ("*") perform sqs:SendMessage on any resource ("*"), subject to
// an ArnEquals condition: aws:SourceArn must match
// arn:aws:*:*:<awsAccountID>:*.
//
// The policy uses Version "2008-10-17" and an Id from uuid.NewUUIDMust.
func DefaultQueuePolicy(awsAccountID string) *Policy {
	// Any service, region and resource, but only within awsAccountID.
	accountSourceARN := fmt.Sprintf("arn:aws:*:*:%s:*", awsAccountID)
	condition := &PolicyStatementCondition{
		ArnEquals: map[string]*OneOrManyString{
			"aws:SourceArn": {accountSourceARN},
		},
	}

	return &Policy{
		Version: "2008-10-17",
		ID:      uuid.NewUUIDMust(),
		Statement: []*PolicyStatement{{
			SID:       sqsDefaultPolicyStatementSid,
			Effect:    "Allow",
			Principal: &PolicyStatementPrincipal{AWS: OneOrManyString{"*"}},
			Action:    OneOrManyString{"sqs:SendMessage"},
			Resource:  "*",
			Condition: condition,
		}},
	}
}

// GetQueuePolicy fetches the access policy attribute of the SQS queue at
// queueURL and decodes it into a Policy. If the attribute is empty (the
// queue's default), it returns an empty, non-nil *Policy rather than an
// error.
func GetQueuePolicy(ctx context.Context, client sqsiface.SQSAPI, queueURL string) (*Policy, error) {
	// Request the same attribute key that is read back below (and that
	// SetQueuePolicy writes), so the requested and looked-up names cannot diverge.
	queueAttrs, err := client.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{
		QueueUrl:       aws.String(queueURL),
		AttributeNames: aws.StringSlice([]string{sqsutil.KeyPolicy}),
	})
	if err != nil {
		return nil, errors.Wrap(err, "could not fetch sqs target attributes")
	}
	queuePolicy := &Policy{}
	policyJSON := aws.StringValue(queueAttrs.Attributes[sqsutil.KeyPolicy])

	// The default policy is the empty string, so skip unmarshaling in that case.
	if policyJSON != "" {
		err = json.Unmarshal([]byte(policyJSON), queuePolicy)
		if err != nil {
			return nil, errors.Wrap(err, "could not unmarshal queue policy")
		}
	}
	return queuePolicy, nil
}

// SetQueuePolicy serializes policy to JSON and stores it as the policy
// attribute (sqsutil.KeyPolicy) of the SQS queue at queueURL. If the
// SetQueueAttributes call fails, the error is logged at warn level with the
// JSON policy attached and returned wrapped.
func SetQueuePolicy(ctx context.Context, client sqsiface.SQSAPI, queueURL string, policy *Policy) error {
	buf, err := json.Marshal(policy)
	if err != nil {
		return errors.Wrapf(err, "could not marshal queue policy %#v to JSON", policy)
	}
	strPolicy := string(buf)

	log := logger.FromContext(ctx).With(zap.String("jsonPolicy", strPolicy))

	_, err = client.SetQueueAttributesWithContext(ctx, &sqs.SetQueueAttributesInput{
		QueueUrl: aws.String(queueURL),
		Attributes: map[string]*string{
			sqsutil.KeyPolicy: aws.String(strPolicy),
		},
	})

	if err != nil {
		log.Warn("error setting SQS queue policy", zap.Error(err))
		return errors.Wrap(err, "could not set SQS queue policy")
	}

	return nil
}

// GetTopicPolicy fetches the attributes of the SNS topic identified by
// topicArn and decodes its "Policy" attribute into a Policy.
//
// An error from GetTopicAttributes is returned unwrapped, so callers can
// still inspect the underlying AWS error directly. A policy that fails to
// decode, including a missing or empty "Policy" attribute, yields an error
// that names the topic.
func GetTopicPolicy(ctx context.Context, client snsiface.SNSAPI, topicArn string) (*Policy, error) {
	// TODO: there is opportunity to share logic with `topics.go` here. ASYNC-127
	topicAttrs, err := client.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{
		TopicArn: aws.String(topicArn),
	})
	if err != nil {
		return nil, err
	}

	topicPolicy := &Policy{}
	policyJSON := aws.StringValue(topicAttrs.Attributes["Policy"])
	if err := json.Unmarshal([]byte(policyJSON), topicPolicy); err != nil {
		return nil, errors.Wrapf(err, "could not unmarshal topic policy for %s", topicArn)
	}
	return topicPolicy, nil
}

// SetTopicPolicy serializes policy to JSON and writes it as the "Policy"
// attribute of the SNS topic identified by topicArn. If SetTopicAttributes
// fails, the error is logged at warn level with the JSON policy attached and
// returned wrapped. On success an info-level line is logged.
func SetTopicPolicy(ctx context.Context, client snsiface.SNSAPI, topicArn string, policy *Policy) error {
	encoded, marshalErr := json.Marshal(policy)
	if marshalErr != nil {
		return errors.Wrapf(marshalErr, "could not marshal topic policy %#v to JSON", policy)
	}
	policyText := string(encoded)
	topicLog := logger.FromContext(ctx).With(zap.String("jsonPolicy", policyText))

	input := &sns.SetTopicAttributesInput{
		TopicArn:       aws.String(topicArn),
		AttributeName:  aws.String("Policy"),
		AttributeValue: aws.String(policyText),
	}
	if _, setErr := client.SetTopicAttributesWithContext(ctx, input); setErr != nil {
		topicLog.Warn("error setting SNS topic policy", zap.Error(setErr))
		return errors.Wrap(setErr, "could not set SNS topic policy")
	}

	topicLog.Info("set topic policy")
	return nil
}

// GetCMKPolicy fetches the "default" key policy of the KMS key identified by
// encryptionKeyID and decodes it into a Policy. On success the returned
// *Policy is never nil.
func GetCMKPolicy(ctx context.Context, client kmsclient.KMSAPI, encryptionKeyID *string) (*Policy, error) {
	// aws.StringValue maps a nil pointer to "", so building the error messages
	// below cannot panic even if the caller passed a nil key ID.
	keyID := aws.StringValue(encryptionKeyID)

	getKeyPolicyOutput, err := client.GetKeyPolicyWithContext(ctx, &kms.GetKeyPolicyInput{
		KeyId:      encryptionKeyID,
		PolicyName: aws.String("default"),
	})
	if err != nil {
		return nil, errors.Wrap(err, "could not get key policy "+keyID)
	}

	keyPolicy := &Policy{}
	// Decode into *Policy, not **Policy: with a double pointer, a JSON "null"
	// policy would set keyPolicy to nil and the function would return (nil, nil).
	err = json.Unmarshal([]byte(aws.StringValue(getKeyPolicyOutput.Policy)), keyPolicy)
	if err != nil {
		return nil, errors.Wrap(err, "could not unmarshal key policy for "+keyID)
	}
	return keyPolicy, nil
}

// SetCMKPolicy serializes policy to JSON and stores it as the "default" key
// policy of the KMS key identified by encryptionKeyID. If PutKeyPolicy fails,
// the error is logged at warn level with the JSON policy attached and
// returned wrapped.
func SetCMKPolicy(ctx context.Context, client kmsclient.KMSAPI, encryptionKeyID *string, policy *Policy) error {
	buf, err := json.Marshal(policy)
	if err != nil {
		return errors.Wrapf(err, "could not marshal cmk policy %#v to JSON", policy)
	}
	strPolicy := string(buf)

	_, err = client.PutKeyPolicyWithContext(ctx, &kms.PutKeyPolicyInput{
		KeyId:      encryptionKeyID,
		PolicyName: aws.String("default"),
		Policy:     aws.String(strPolicy),
	})
	if err != nil {
		log := logger.FromContext(ctx).With(zap.String("jsonPolicy", strPolicy))
		log.Warn("Error setting CMK policy", zap.Error(err))
		return errors.Wrap(err, "could not set cmk policy")
	}

	return nil
}
