package validator

import (
	"context"
	"encoding/json"
	"fmt"
	"net/url"
	"strings"

	"github.com/aws/aws-sdk-go/aws/arn"
	"github.com/pkg/errors"

	"code.justin.tv/eventbus/controlplane/internal/policy"
	"code.justin.tv/eventbus/controlplane/internal/sqsutil"
)

// Sentinel errors produced by QueueValidator. ValidateQueue and
// ValidateDeadletterQueue return these values directly (not wrapped), so they
// can be compared by identity.
var (
	// Problems with the queue being validated.
	ErrInvalidQueueAttributesInaccessible    = errors.New("queue attributes inaccessible")
	ErrInvalidQueueMissingKMSKey             = errors.New("queue is missing kms encryption key id")
	ErrInvalidQueueMissingPublishPermissions = errors.New("queue is missing publish permissions")
	ErrInvalidQueueMissingEventBusInName     = errors.New("queue does not contain eventbus or EventBus in queue name")

	// Problems with a deadletter queue reached through a redrive policy.
	ErrInvalidDeadletterQueueAttributesInaccessible = errors.New("queue's configured deadletter queue attributes inaccessible")
	ErrInvalidDeadletterQueueKMSKey                 = errors.New("queue's configured deadletter queue is missing kms encryption key id")
	ErrInvalidDeadletterQueueMissingEventBusInName  = errors.New("queue's configured deadletter queue is missing eventbus in its name")
)

// SQSManager is the subset of SQS operations QueueValidator depends on.
type SQSManager interface {
	// GetQueueURL resolves a queue name owned by awsAccountID to its queue URL.
	GetQueueURL(ctx context.Context, queueName, awsAccountID string) (string, error)
	// GetQueueAttributes returns the attributes of the queue at queueURL,
	// keyed by attribute name.
	GetQueueAttributes(ctx context.Context, queueURL string) (map[string]string, error)
}

// AllQueueErrors lists the seven sentinel validation errors declared in this
// file, queue-level errors first and deadletter-queue errors second.
// NOTE(review): presumably used by callers to tell validation failures apart
// from operational errors; confirm against callers before reordering.
var AllQueueErrors = []error{
	// The queue being validated.
	ErrInvalidQueueAttributesInaccessible,
	ErrInvalidQueueMissingKMSKey,
	ErrInvalidQueueMissingPublishPermissions,
	ErrInvalidQueueMissingEventBusInName,

	// Deadletter queues in its redrive chain.
	ErrInvalidDeadletterQueueAttributesInaccessible,
	ErrInvalidDeadletterQueueKMSKey,
	ErrInvalidDeadletterQueueMissingEventBusInName,
}

// QueueValidator checks SQS queues, and the deadletter queues reachable
// through their redrive policies, against EventBus requirements.
// NewQueueValidator builds one.
type QueueValidator struct {
	// EventBusAWSAccountID is the account ID expected in the queue policy's
	// aws:SourceArn condition; EncryptionAtRestKeyARN is the KMS key every
	// validated queue's KMS key attribute must equal.
	EventBusAWSAccountID, EncryptionAtRestKeyARN string

	// SQSManager performs the SQS lookups used during validation.
	SQSManager SQSManager
}

// NewQueueValidator returns a QueueValidator that uses sqsManager for SQS
// calls, eventbusAWSAccountID for the queue-policy source check, and
// encryptionAtRestKeyARN as the required KMS key.
func NewQueueValidator(sqsManager SQSManager, eventbusAWSAccountID, encryptionAtRestKeyARN string) *QueueValidator {
	qv := new(QueueValidator)
	qv.SQSManager = sqsManager
	qv.EventBusAWSAccountID = eventbusAWSAccountID
	qv.EncryptionAtRestKeyARN = encryptionAtRestKeyARN
	return qv
}

// ValidateQueue validates one queue and its deadletters.
//
// The checks run in this order, and the first failure is returned:
//   - the queue name parsed from queueURL contains "eventbus" or "EventBus";
//   - the queue's attributes can be fetched;
//   - the queue name parsed from its ARN attribute passes the same name check;
//   - the queue's KMS key attribute equals qv.EncryptionAtRestKeyARN;
//   - the queue policy is non-empty, decodes as JSON and grants EventBus
//     publish permissions;
//   - the redrive policy parses and, if it names a deadletter target, that
//     target passes ValidateDeadletterQueue.
func (qv *QueueValidator) ValidateQueue(ctx context.Context, queueURL string) error {
	nameFromURL, err := queueNameFromURL(queueURL)
	if err != nil {
		return errors.Wrapf(err, "could not determine queue name from url %s", queueURL)
	}
	if !validQueueName(nameFromURL) {
		return ErrInvalidQueueMissingEventBusInName
	}

	attrs, err := qv.SQSManager.GetQueueAttributes(ctx, queueURL)
	if err != nil {
		// The underlying SQS error is dropped; callers only see the sentinel.
		return ErrInvalidQueueAttributesInaccessible
	}

	// Re-check the name against the ARN that SQS reports for the queue.
	queueARN := attrs[sqsutil.KeyQueueARN]
	nameFromARN, err := queueNameFromARN(queueARN)
	if err != nil {
		return errors.Wrapf(err, "could not determine queue name from arn %s", queueARN)
	}
	if !validQueueName(nameFromARN) {
		return ErrInvalidQueueMissingEventBusInName
	}

	if !validKMSMasterKeyAttached(attrs[sqsutil.KeyKMSKeyID], qv.EncryptionAtRestKeyARN) {
		return ErrInvalidQueueMissingKMSKey
	}

	policyJSON := attrs[sqsutil.KeyPolicy]
	if policyJSON == "" {
		return ErrInvalidQueueMissingPublishPermissions
	}
	var queuePolicy policy.Policy
	if err := json.Unmarshal([]byte(policyJSON), &queuePolicy); err != nil {
		return errors.Wrapf(err, "could not unmarshal current queue policy for %s", queueURL)
	}
	if !validEventBusQueuePolicyAttached(&queuePolicy, qv.EventBusAWSAccountID, queueARN) {
		return ErrInvalidQueueMissingPublishPermissions
	}

	redrivePolicy, err := sqsutil.ParseRedrivePolicy(attrs[sqsutil.KeyRedrivePolicy])
	if err != nil {
		return errors.Wrapf(err, "could not parse redrive policy for %s", queueURL)
	}
	if redrivePolicy == nil || redrivePolicy.DeadletterTargetARN == "" {
		return nil
	}
	return qv.ValidateDeadletterQueue(ctx, redrivePolicy.DeadletterTargetARN)
}

// ValidateDeadletterQueue validates the deadletter queue identified by
// deadletterQueueARN and every queue reachable from it by following redrive
// policies, since deadletter queues can be chained.
//
// Each queue in the chain must have "eventbus" or "EventBus" in its name, must
// have readable attributes, and must have a KMS key attribute equal to
// qv.EncryptionAtRestKeyARN. The first failure is returned.
//
// Each ARN is validated at most once. If the chain loops back onto a queue
// that was already checked (for example A -> B -> A), the walk stops there
// instead of recursing forever.
func (qv *QueueValidator) ValidateDeadletterQueue(ctx context.Context, deadletterQueueARN string) error {
	visited := make(map[string]bool)
	currentARN := deadletterQueueARN
	for {
		visited[currentARN] = true

		nextARN, err := qv.validateDeadletterQueueLink(ctx, currentARN)
		if err != nil {
			return err
		}

		// Stop at the end of the chain, or when it cycles back to a queue
		// validated earlier in this walk.
		if nextARN == "" || visited[nextARN] {
			return nil
		}
		currentARN = nextARN
	}
}

// validateDeadletterQueueLink validates a single deadletter queue. It returns
// the ARN of that queue's own deadletter target, or "" if it has none.
func (qv *QueueValidator) validateDeadletterQueueLink(ctx context.Context, deadletterQueueARN string) (string, error) {
	parsedARN, err := arn.Parse(deadletterQueueARN)
	if err != nil {
		return "", errors.Wrap(err, "could not parse deadletter queue arn")
	}
	queueName := parsedARN.Resource

	if !validQueueName(queueName) {
		return "", ErrInvalidDeadletterQueueMissingEventBusInName
	}

	queueURL, err := qv.SQSManager.GetQueueURL(ctx, queueName, parsedARN.AccountID)
	if err != nil {
		return "", errors.Wrapf(err, "could not get queue url from deadletter queue %s", queueName)
	}

	attributes, err := qv.SQSManager.GetQueueAttributes(ctx, queueURL)
	if err != nil {
		return "", ErrInvalidDeadletterQueueAttributesInaccessible
	}

	if !validKMSMasterKeyAttached(attributes[sqsutil.KeyKMSKeyID], qv.EncryptionAtRestKeyARN) {
		return "", ErrInvalidDeadletterQueueKMSKey
	}

	redrivePolicy, err := sqsutil.ParseRedrivePolicy(attributes[sqsutil.KeyRedrivePolicy])
	if err != nil {
		return "", errors.Wrapf(err, "could not parse redrive policy for %s", queueURL)
	}
	if redrivePolicy == nil {
		return "", nil
	}
	return redrivePolicy.DeadletterTargetARN, nil
}

// validQueueName reports whether name contains one of the accepted EventBus
// markers, "eventbus" or "EventBus". The match is case-sensitive.
func validQueueName(name string) bool {
	for _, marker := range [...]string{"eventbus", "EventBus"} {
		if strings.Contains(name, marker) {
			return true
		}
	}
	return false
}

// validKMSMasterKeyAttached reports whether the queue's KMS key attribute
// value (queueKMSMasterKey) is exactly the expected key kmsMasterKey.
//
// An empty queue key never counts as attached. Otherwise a queue with no KMS
// key would pass validation whenever the expected key is configured as empty.
func validKMSMasterKeyAttached(queueKMSMasterKey, kmsMasterKey string) bool {
	return queueKMSMasterKey != "" && queueKMSMasterKey == kmsMasterKey
}

// validEventBusQueuePolicyAttached reports whether at least one statement of
// queuePolicy grants EventBus publish permissions, as judged by
// containsEventBusPublishPerms. A nil policy never qualifies.
func validEventBusQueuePolicyAttached(queuePolicy *policy.Policy, awsAccountID, queueARN string) bool {
	if queuePolicy == nil {
		return false
	}

	found := false
	for _, stmt := range queuePolicy.Statement {
		if containsEventBusPublishPerms(stmt, awsAccountID, queueARN) {
			found = true
			break
		}
	}
	return found
}

// containsEventBusPublishPerms reports whether statement grants EventBus
// permission to publish to the queue identified by queueARN. Every one of
// these must hold:
//   - Effect is exactly "Allow";
//   - statement.Principal.AWS contains "*";
//   - Action contains "sqs:SendMessage" (compared with ContainsEqualFold) or "*";
//   - Resource is queueARN or "*";
//   - the ArnEquals condition for "aws:SourceArn" contains
//     "arn:aws:*:*:<awsAccountID>:*".
//
// Nil statements, and statements with no Principal, no Condition or no
// aws:SourceArn ArnEquals entry, never match.
func containsEventBusPublishPerms(statement *policy.PolicyStatement, awsAccountID, queueARN string) bool {
	if statement == nil || statement.Principal == nil || statement.Condition == nil {
		return false
	}

	sourceARNCondition := statement.Condition.ArnEquals["aws:SourceArn"]
	if sourceARNCondition == nil {
		return false
	}

	// All five checks are evaluated before being combined.
	isAllow := statement.Effect == "Allow"
	anyPrincipal := statement.Principal.AWS.Contains("*")
	canSend := statement.Action.ContainsEqualFold("sqs:SendMessage") || statement.Action.Contains("*")
	targetsQueue := statement.Resource == queueARN || statement.Resource == "*"
	fromEventBus := sourceARNCondition.Contains(fmt.Sprintf("arn:aws:*:*:%s:*", awsAccountID))

	return isAllow && anyPrincipal && canSend && targetsQueue && fromEventBus
}

// queueNameFromURL extracts the queue name from a queue URL. The URL path must
// split on "/" into exactly three segments (for example
// "/123456789012/my-queue" -> "", "123456789012", "my-queue"), and the last
// segment is returned. Other shapes, including a trailing slash, are rejected.
func queueNameFromURL(queueURL string) (string, error) {
	parsed, parseErr := url.Parse(queueURL)
	if parseErr != nil {
		return "", errors.Wrapf(parseErr, "failed to parse url %q", queueURL)
	}

	segments := strings.Split(parsed.Path, "/")
	if n := len(segments); n != 3 {
		return "", errors.Errorf("unexpected path length of %d for url %q", n, queueURL)
	}
	return segments[len(segments)-1], nil
}

// queueNameFromARN returns the resource (queue name) segment of an SQS ARN of
// the form "arn:partition:sqs:region:account:name". The ARN must split on ":"
// into exactly six segments and its service segment must be "sqs"; anything
// else yields an "invalid queue arn" error.
func queueNameFromARN(queueARN string) (string, error) {
	const (
		segmentCount  = 6
		serviceIndex  = 2
		resourceIndex = 5
	)

	parts := strings.Split(queueARN, ":")
	if len(parts) == segmentCount && parts[serviceIndex] == "sqs" {
		return parts[resourceIndex], nil
	}
	return "", errors.New("invalid queue arn")
}
