package sandstorm

import (
	"net/http"
	"sync/atomic"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/sts"

	"code.justin.tv/common/goauthorization"
	"code.justin.tv/common/jwt"
	"code.justin.tv/systems/sandstorm/manager"
)

// Timing parameters for the Sandstorm-backed decoder.
const (
	// sandstormDuration is the assumed-role credential lifetime requested in
	// NewManager and the interval after which updateManager rebuilds the
	// manager. The manager imposes a minimum of 900 seconds (15 minutes).
	sandstormDuration = 15 * time.Minute

	// sandstormExpiryWindow is passed as the AssumeRoleProvider ExpiryWindow,
	// i.e. how early before expiry the assumed-role credentials are refreshed.
	sandstormExpiryWindow = 10 * time.Second

	// sandstormCooldown is the minimum gap between successful key refreshes
	// in updateDecoder.
	sandstormCooldown = 5 * time.Minute
)

// sandstormDecoder implements IDecoder using a ES256 PEM fetched
// over the Sandstorm API.
type sandstormDecoder struct {
	region      string
	roleArn     string
	table       string
	keyID       string
	secretName  string
	audience    string
	issuer      string
	mgr         *manager.Manager
	mgrExpires  time.Time
	decoder     goauthorization.Decoder
	failCount   int32
	isUpdating  int32
	lastUpdated time.Time
}

// NewDecoder turns a Sandstorm manager and secret description for
// Cartman into an IDecoder source.
func NewDecoder(region, roleArn, table, keyID, secretName, audience, issuer string) (goauthorization.Decoder, error) {
	decoder := &sandstormDecoder{
		region:     region,
		roleArn:    roleArn,
		table:      table,
		keyID:      keyID,
		secretName: secretName,
		audience:   audience,
		issuer:     issuer,
	}
	err := decoder.updateDecoder()
	if err != nil {
		return nil, err
	}
	return decoder, nil
}

func (s *sandstormDecoder) Decode(ser string) (*goauthorization.AuthorizationToken, error) {
	token, err := s.decoder.Decode(ser)
	s.tryUpdateDecoder(err)
	return token, err
}

func (s *sandstormDecoder) ParseToken(r *http.Request) (*goauthorization.AuthorizationToken, error) {
	token, err := s.decoder.ParseToken(r)
	s.tryUpdateDecoder(err)
	return token, err
}

func (s *sandstormDecoder) Validate(t *goauthorization.AuthorizationToken, c goauthorization.CapabilityClaims) error {
	err := s.decoder.Validate(t, c)
	s.tryUpdateDecoder(err)
	return err
}

func (s *sandstormDecoder) updateDecoder() error {
	if s.lastUpdated.Add(sandstormCooldown).After(time.Now()) || atomic.SwapInt32(&s.isUpdating, 1) == 1 {
		return nil
	}
	defer atomic.StoreInt32(&s.isUpdating, 0)

	err := s.updateManager()
	if err != nil {
		return err
	}
	pem, err := s.mgr.Get(s.secretName)
	if err != nil {
		return err
	}
	s.decoder, err = goauthorization.NewDecoder("ES256", s.audience, s.issuer, pem.Plaintext)
	if err != nil {
		return err
	}
	s.lastUpdated = time.Now()
	atomic.StoreInt32(&s.failCount, 0)
	return nil
}

// tryUpdateDecoder reacts to the error from a decode or validate call. Only
// jwt.ErrInvalidECSignature is counted. Once more than ten such errors have
// been seen since the last successful refresh, it attempts to reload the key
// from Sandstorm. While a refresh is already running, errors are not counted.
func (s *sandstormDecoder) tryUpdateDecoder(err error) {
	const failThreshold = 10

	if err != jwt.ErrInvalidECSignature {
		return
	}
	if atomic.LoadInt32(&s.isUpdating) != 0 {
		return
	}
	if atomic.AddInt32(&s.failCount, 1) <= failThreshold {
		return
	}
	// Best effort: the caller still receives its original error, and a
	// failed refresh will be retried on the next invalid signature.
	// TODO: hook this up to rollbar.
	_ = s.updateDecoder()
}

// updateManager rebuilds s.mgr via NewManager when there is no manager yet or
// when mgrExpires has passed. In this file it is only called from
// updateDecoder while the isUpdating flag is held.
//
// The new expiry is recorded before NewManager runs. If a rebuild fails while
// an older manager exists, later calls keep using that older manager until
// the new expiry passes.
func (s *sandstormDecoder) updateManager() error {
	stillValid := s.mgr != nil && time.Now().Before(s.mgrExpires)
	if stillValid {
		return nil
	}
	s.mgrExpires = time.Now().Add(sandstormDuration)

	fresh, err := NewManager(s.region, s.roleArn, s.table, s.keyID)
	if err != nil {
		return err
	}
	s.mgr = fresh
	return nil
}

// NewManager builds a Sandstorm manager in region that authenticates by
// assuming roleArn through STS. The assumed-role credentials are requested
// for sandstormDuration and refreshed sandstormExpiryWindow before they
// expire. table and keyID are passed through as the manager's TableName and
// KeyID.
//
// Only an error from creating the AWS session is returned.
func NewManager(region, roleArn, table, keyID string) (*manager.Manager, error) {
	awsConfig := &aws.Config{Region: aws.String(region)}

	// The session is created before credentials are attached to awsConfig
	// below. session.NewSession merges awsConfig into its own config, so the
	// STS client uses the SDK's default credentials, not the role it assumes.
	// Named sess so it does not shadow the imported session package.
	sess, err := session.NewSession(awsConfig)
	if err != nil {
		return nil, err
	}

	// Cross-realm credentials: the manager talks to AWS as roleArn.
	arp := &stscreds.AssumeRoleProvider{
		Duration:     sandstormDuration,
		ExpiryWindow: sandstormExpiryWindow,
		RoleARN:      roleArn,
		Client:       sts.New(sess),
	}
	awsConfig.WithCredentials(credentials.NewCredentials(arp))

	return manager.New(manager.Config{
		AWSConfig: awsConfig,
		TableName: table,
		KeyID:     keyID,
	}), nil
}
