package kms

import (
	"context"

	"code.justin.tv/eventbus/controlplane/internal/arn"
	"code.justin.tv/eventbus/schema/pkg/eventbus/authorization"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/client"
	"github.com/aws/aws-sdk-go/service/kms"
	"github.com/pkg/errors"
)

// KMS grant operation sets handed out by Manager.
//
// encryptOperations lists the encrypt-side operations (Encrypt and the
// GenerateDataKey* family). decryptOperations lists the decrypt-side
// operations. fullAccessOperations is the union of the two, encrypt-side first.
var (
	encryptOperations = []string{
		"Encrypt",
		"GenerateDataKey",
		"GenerateDataKeyWithoutPlaintext",
		"GenerateDataKeyPair",
		"GenerateDataKeyPairWithoutPlaintext",
	}
	decryptOperations = []string{
		"Decrypt",
	}

	// fullAccessOperations is built in its own freshly allocated slice so it
	// can never share a backing array with encryptOperations. A bare
	// append(encryptOperations, ...) would alias it whenever encryptOperations
	// happened to have spare capacity.
	fullAccessOperations = func() []string {
		ops := make([]string, 0, len(encryptOperations)+len(decryptOperations))
		ops = append(ops, encryptOperations...)
		return append(ops, decryptOperations...)
	}()
)

// Manager creates, revokes, and looks up grants on a single KMS key.
type Manager struct {
	// keyID is the KMS key passed as KeyId on every grant API call.
	keyID  string
	// client issues the KMS API calls. KMSAPI is declared elsewhere in this
	// package; NewManager fills it with kms.New(sess). Presumably it is an
	// interface so tests can substitute a fake — confirm.
	client KMSAPI
}

// NewManager returns a Manager that operates on the KMS key keyID, issuing
// requests through a KMS client constructed from sess.
func NewManager(sess client.ConfigProvider, keyID string) *Manager {
	kmsClient := kms.New(sess)
	return &Manager{client: kmsClient, keyID: keyID}
}

// GrantEncryptionAtRest creates an unconstrained grant for every operation in
// fullAccessOperations (encrypt and decrypt) on the managed key.
//
// The grantee is not the role itself. It is the principal built by
// arn.IAMRootARN from the account ID parsed out of iamRoleARN. The role ARN is
// used only as the grant name. The new grant's ID is returned. Failure to
// parse the account ID is wrapped; CreateGrant errors are returned as-is.
func (km *Manager) GrantEncryptionAtRest(ctx context.Context, iamRoleARN string) (string, error) {
	accountID, err := arn.AccountID(iamRoleARN)
	if err != nil {
		return "", errors.Wrap(err, "could not determine aws account id")
	}
	rootPrincipal := arn.IAMRootARN(accountID)
	return km.grantARN(ctx, iamRoleARN, rootPrincipal, fullAccessOperations, nil)
}

// GrantAuthorizedFieldPublisher grants iamRoleARN the encrypt-side operations
// (encryptOperations) on the managed key. The grant carries an
// EncryptionContextSubset constraint built from eventType and environment.
// The role ARN is both the grantee and the grant name. The new grant's ID is
// returned.
func (km *Manager) GrantAuthorizedFieldPublisher(ctx context.Context, iamRoleARN, eventType, environment string) (string, error) {
	encryptionContext := map[string]string{
		authorization.EventType:   eventType,
		authorization.Environment: environment,
	}
	return km.grantARN(ctx, iamRoleARN, iamRoleARN, encryptOperations, &kms.GrantConstraints{
		EncryptionContextSubset: aws.StringMap(encryptionContext),
	})
}

// GrantAuthorizedFieldSubscriber grants iamRoleARN the decrypt-side operations
// (decryptOperations) on the managed key. The grant carries an
// EncryptionContextSubset constraint built from eventType, environment,
// messageName and fieldName, so it is scoped more narrowly than the publisher
// grant. The role ARN is both the grantee and the grant name. The new grant's
// ID is returned.
func (km *Manager) GrantAuthorizedFieldSubscriber(ctx context.Context, iamRoleARN, eventType, environment, messageName, fieldName string) (string, error) {
	encryptionContext := map[string]string{
		authorization.EventType:   eventType,
		authorization.Environment: environment,
		authorization.MessageName: messageName,
		authorization.FieldName:   fieldName,
	}
	return km.grantARN(ctx, iamRoleARN, iamRoleARN, decryptOperations, &kms.GrantConstraints{
		EncryptionContextSubset: aws.StringMap(encryptionContext),
	})
}

// grantARN calls CreateGrant on the managed key. The grant is named grantName,
// has iamRoleARN as its grantee principal, and allows the listed operations.
// constraints is passed through verbatim; nil leaves the Constraints field
// unset. It returns the new grant's ID. CreateGrant errors are returned
// unwrapped so callers can still inspect the SDK error type.
func (km *Manager) grantARN(ctx context.Context, grantName, iamRoleARN string, operations []string, constraints *kms.GrantConstraints) (string, error) {
	req := &kms.CreateGrantInput{
		KeyId:            aws.String(km.keyID),
		Name:             aws.String(grantName),
		GranteePrincipal: aws.String(iamRoleARN),
		Operations:       aws.StringSlice(operations),
		Constraints:      constraints,
	}

	resp, err := km.client.CreateGrantWithContext(ctx, req)
	if err != nil {
		return "", err
	}
	return aws.StringValue(resp.GrantId), nil
}

// Revoke revokes the grant identified by grantID on the managed key. Any
// error from RevokeGrant is returned unchanged.
func (km *Manager) Revoke(ctx context.Context, grantID string) error {
	input := &kms.RevokeGrantInput{
		KeyId:   aws.String(km.keyID),
		GrantId: aws.String(grantID),
	}
	if _, err := km.client.RevokeGrantWithContext(ctx, input); err != nil {
		return err
	}
	return nil
}

// AllGrants pages through ListGrants for the managed key and returns every
// grant it reports. On success the result is non-nil, and empty when the key
// has no grants. On failure it returns a nil slice and the error unchanged.
func (km *Manager) AllGrants(ctx context.Context) ([]*kms.GrantListEntry, error) {
	input := &kms.ListGrantsInput{KeyId: aws.String(km.keyID)}

	grants := make([]*kms.GrantListEntry, 0)
	collect := func(page *kms.ListGrantsResponse, lastPage bool) bool {
		grants = append(grants, page.Grants...)
		return !lastPage // keep paging until the SDK reports the final page
	}

	if err := km.client.ListGrantsPagesWithContext(ctx, input, collect); err != nil {
		return nil, err
	}
	return grants, nil
}

// FindByGrantID returns the grant on the managed key whose GrantId equals
// grantID. It pages through ListGrants and stops requesting pages as soon as
// a match is found.
//
// If no grant matches, it returns (nil, nil). If listing fails, it returns a
// nil grant together with the error, unchanged.
func (km *Manager) FindByGrantID(ctx context.Context, grantID string) (*kms.GrantListEntry, error) {
	var found *kms.GrantListEntry
	search := func(page *kms.ListGrantsResponse, lastPage bool) bool {
		for _, g := range page.Grants {
			if aws.StringValue(g.GrantId) == grantID {
				found = g
				return false // match found; stop paging
			}
		}
		return !lastPage
	}

	err := km.client.ListGrantsPagesWithContext(ctx, &kms.ListGrantsInput{
		KeyId: aws.String(km.keyID),
	}, search)
	if err != nil {
		// Never hand back a partially-populated result alongside an error.
		return nil, err
	}
	return found, nil
}
