package aws

import (
	"context"
	"fmt"
	"sort"
	"time"

	"github.com/aws/aws-sdk-go/service/secretsmanager"

	"github.com/aws/aws-sdk-go/aws/session"

	"code.justin.tv/video/clock"

	"code.justin.tv/amzn/StarfruitVault/key"

	"github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface"
)

const (
	// DefaultKeyRotationTTL is the KeyRotationTTL applied when none is
	// configured: one day.
	DefaultKeyRotationTTL = 24 * time.Hour

	// DefaultKeyDeactivationTTL is the KeyDeactivationTTL applied when none is
	// configured: keys generated more than seven days ago get deactivated.
	DefaultKeyDeactivationTTL = 7 * 24 * time.Hour

	// DefaultKeyExpirationTTL is the KeyExpirationTTL applied when none is
	// configured: roughly thirty days of keys are retained.
	DefaultKeyExpirationTTL = 30 * 24 * time.Hour

	// minActive is the number of keys that pruning and deactivation always
	// leave in place.
	minActive = 1
)

// RotatorConfig configures a rotator built by NewRotator. Zero-valued TTLs and
// a nil Clock are replaced with defaults by NewRotator.
type RotatorConfig struct {
	// Session is the AWS Session. It is only used when SecretsManager is nil,
	// to build a Secrets Manager client; if Session is also nil, a new session
	// is created.
	Session *session.Session

	// KeyRotationTTL: a new key is generated once no key in the set was
	// generated within this window (default DefaultKeyRotationTTL).
	// KeyDeactivationTTL: during a rotation, keys generated longer ago than this
	// are marked Deactivated, except the newest minActive keys
	// (default DefaultKeyDeactivationTTL).
	// KeyExpirationTTL: during a rotation, keys generated longer ago than this
	// are removed, unless that would leave fewer than minActive unexpired keys
	// (default DefaultKeyExpirationTTL).
	KeyRotationTTL     time.Duration
	KeyDeactivationTTL time.Duration
	KeyExpirationTTL   time.Duration

	// Key types used when generating new keys. NewRotator accepts only
	// KeyTypeSharedSecretBox / KeyTypeSharedAESGCM for encryption and
	// KeyTypeNoop / KeyTypePrivateEd25519 for signatures.
	EncryptionKeyType key.KeyType
	SignatureKeyType  key.KeyType

	// PrivateID identifies the secret holding the full (private) key set and is
	// required. When PublicID is non-empty, the consumer form of every key
	// (key.Key.ConsumerKey) is written there before the private set is updated.
	// NOTE(review): the error messages call these ARNs; presumably any Secrets
	// Manager secret ID is accepted — confirm.
	PrivateID string
	PublicID  string

	// mocking API: Clock defaults to clock.New(); SecretsManager defaults to a
	// client built from Session.
	Clock          clock.Clock
	SecretsManager secretsmanageriface.SecretsManagerAPI
}

// rotator rotates the key set stored in Secrets Manager under PrivateID and,
// when PublicID is set, publishes the consumer keys under PublicID.
// Construct it with NewRotator, which fills in the embedded config's defaults.
type rotator struct {
	RotatorConfig
	// client performs the Fetch/Put calls against RotatorConfig.SecretsManager.
	client client
}

// NewRotator validates config, applies defaults and returns a rotator.
//
// config itself is not modified; defaults are applied to a copy:
//   - zero KeyRotationTTL, KeyDeactivationTTL and KeyExpirationTTL fall back to
//     the Default* constants;
//   - a nil Clock falls back to clock.New();
//   - a nil SecretsManager is built from Session, creating a new AWS session
//     when Session is also nil.
//
// An error is returned if any of the following hold:
//   - config is nil or PrivateID is empty;
//   - the encryption or signature key type is not supported for rotation;
//   - SignatureKeyType is KeyTypePrivateEd25519 but PublicID is empty;
//   - a new AWS session cannot be created.
func NewRotator(config *RotatorConfig) (*rotator, error) {
	if config == nil {
		return nil, fmt.Errorf("config cannot be nil")
	}
	// Work on a copy so the caller's struct is not filled in behind its back.
	cfg := *config

	// Validate first: these checks are cheap and must not depend on being able
	// to reach AWS.
	if cfg.PrivateID == "" {
		return nil, fmt.Errorf("private id arn cannot be empty")
	}

	switch cfg.EncryptionKeyType {
	case key.KeyTypeSharedSecretBox, key.KeyTypeSharedAESGCM:
	default:
		return nil, fmt.Errorf("encryption key type not support for rotation %s", cfg.EncryptionKeyType)
	}

	switch cfg.SignatureKeyType {
	case key.KeyTypeNoop:
	case key.KeyTypePrivateEd25519:
		// Signing keys need a place to publish their consumer (public) form,
		// which rotation writes under PublicID.
		if cfg.PublicID == "" {
			return nil, fmt.Errorf("public id arn cannot be empty for public/private keys")
		}
	default:
		return nil, fmt.Errorf("signature key type not support for rotation %s", cfg.SignatureKeyType)
	}

	if cfg.KeyRotationTTL == 0 {
		cfg.KeyRotationTTL = DefaultKeyRotationTTL
	}
	if cfg.KeyDeactivationTTL == 0 {
		cfg.KeyDeactivationTTL = DefaultKeyDeactivationTTL
	}
	if cfg.KeyExpirationTTL == 0 {
		cfg.KeyExpirationTTL = DefaultKeyExpirationTTL
	}
	if cfg.Clock == nil {
		cfg.Clock = clock.New()
	}

	if cfg.SecretsManager == nil {
		if cfg.Session == nil {
			// Return the error rather than panicking (session.Must) so callers
			// of this library can handle misconfigured environments.
			sess, err := session.NewSession()
			if err != nil {
				return nil, fmt.Errorf("failed to create aws session: %w", err)
			}
			cfg.Session = sess
		}
		cfg.SecretsManager = secretsmanager.New(cfg.Session)
	}

	return &rotator{cfg, client{cfg.SecretsManager}}, nil
}

// Rotate fetches the private key set stored under PrivateID. If no key in it
// was generated within the last KeyRotationTTL, it then:
//   - removes keys older than KeyExpirationTTL;
//   - marks keys older than KeyDeactivationTTL as deactivated;
//   - appends a freshly generated key and writes the set back.
//
// When PublicID is set, the consumer keys are published there before the
// private set is written. ctx governs the fetch and every write.
func (r *rotator) Rotate(ctx context.Context) error {
	var set key.KeySet
	if err := r.client.Fetch(ctx, r.PrivateID, &set); err != nil {
		return fmt.Errorf("failed to fetch private key set: %w", err)
	}

	// pruneOverTTL and deactivateOverTTL expect keys ordered oldest first.
	sort.Slice(set.Keys, func(i, j int) bool {
		return set.Keys[i].GeneratedAt.Before(set.Keys[j].GeneratedAt)
	})

	if !r.rotationNeeded(&set) {
		return nil
	}

	// One timestamp so pruning and deactivation agree on what "now" is.
	now := r.Clock.Now()
	pruneOverTTL(now, r.KeyExpirationTTL, &set)
	deactivateOverTTL(now, r.KeyDeactivationTTL, &set)

	if err := r.rotateContext(ctx, &set); err != nil {
		return fmt.Errorf("failed to rotate keys: %w", err)
	}
	return nil
}

// rotate is rotateContext with context.Background(); kept for callers that
// have no context to pass.
func (r *rotator) rotate(privateSet *key.KeySet) error {
	return r.rotateContext(context.Background(), privateSet)
}

// rotateContext generates a key with the next epoch and appends it to
// privateSet. It then persists the result: first the consumer keys to PublicID
// (when set), then privateSet to PrivateID.
func (r *rotator) rotateContext(ctx context.Context, privateSet *key.KeySet) error {
	private, err := key.GenerateKey(&key.GenerateOptions{
		KeyEpoch:          findNextKeyEpoch(privateSet),
		GeneratedAt:       r.Clock.Now(),
		EncryptionKeyType: r.EncryptionKeyType,
		SignatureKeyType:  r.SignatureKeyType,
	})
	if err != nil {
		return fmt.Errorf("generating key: %w", err)
	}
	privateSet.Keys = append(privateSet.Keys, private)

	// Order matters: the public keys are written before the private keys so the
	// stored sets do not end up in an inconsistent state (the private set never
	// references a key whose public half is unpublished).
	if r.PublicID != "" {
		publicSet := &key.KeySet{
			Keys: make([]*key.Key, 0, len(privateSet.Keys)),
		}
		for i, sk := range privateSet.Keys {
			pk, err := sk.ConsumerKey()
			if err != nil {
				return fmt.Errorf("deriving consumer key %d: %w", i, err)
			}
			publicSet.Keys = append(publicSet.Keys, pk)
		}
		if err := r.client.Put(ctx, r.PublicID, publicSet); err != nil {
			return fmt.Errorf("writing public key set: %w", err)
		}
	}

	if err := r.client.Put(ctx, r.PrivateID, privateSet); err != nil {
		return fmt.Errorf("writing private key set: %w", err)
	}
	return nil
}

// pruneOverTTL removes from keys every key whose GeneratedAt is more than ttl
// before now, preserving the order of the keys that remain.
//
// Nothing is removed unless at least minActive keys are unexpired and have a
// GeneratedAt strictly before now. This keeps the set from being emptied when
// every key has expired. The Deactivated flag is not consulted.
func pruneOverTTL(now time.Time, ttl time.Duration, keys *key.KeySet) {
	expired := func(k *key.Key) bool {
		return k.GeneratedAt.Add(ttl).Before(now)
	}

	var active int
	for _, k := range keys.Keys {
		if !expired(k) && k.GeneratedAt.Before(now) {
			active++
		}
	}
	// leave some active
	if active < minActive {
		return
	}

	// Filter instead of slicing off a prefix: the result is then correct even
	// if keys are not sorted by GeneratedAt, and identical when they are.
	kept := make([]*key.Key, 0, len(keys.Keys))
	for _, k := range keys.Keys {
		if !expired(k) {
			kept = append(kept, k)
		}
	}
	keys.Keys = kept
}

// deactivateOverTTL sets Deactivated on each key whose GeneratedAt is more
// than ttl before now. The last minActive entries of keys.Keys are never
// touched; Rotate sorts keys oldest-first, so these are the newest keys.
// Keys are only ever marked deactivated here, never reactivated.
func deactivateOverTTL(now time.Time, ttl time.Duration, keys *key.KeySet) {
	// Entries at or beyond this index are protected. A non-positive bound
	// (len <= minActive) means the loop below does nothing.
	protectedFrom := len(keys.Keys) - minActive
	for i := 0; i < protectedFrom; i++ {
		k := keys.Keys[i]
		if !k.GeneratedAt.Add(ttl).Before(now) {
			continue
		}
		k.Deactivated = true
	}
}

// findNextKeyEpoch returns the epoch for a newly generated key: one more than
// the largest KeyEpoch in set, or 0 when set is nil or holds no keys.
// The increment is plain uint32 arithmetic and is not guarded against wrap-around.
func findNextKeyEpoch(set *key.KeySet) uint32 {
	var next uint32
	if set == nil {
		return next
	}
	// Ranging over a nil or empty slice leaves next at 0.
	for i := range set.Keys {
		if epoch := set.Keys[i].KeyEpoch; epoch >= next {
			next = epoch + 1
		}
	}
	return next
}

// rotationNeeded reports whether set has no key generated strictly after
// now-KeyRotationTTL, i.e. whether a new key should be generated. An empty set
// always needs rotation.
func (r *rotator) rotationNeeded(set *key.KeySet) bool {
	cutoff := r.Clock.Now().Add(-r.KeyRotationTTL)
	hasFreshKey := false
	for _, k := range set.Keys {
		if k.GeneratedAt.After(cutoff) {
			hasFreshKey = true
			break
		}
	}
	return !hasFreshKey
}
