package cachedl2authentication

import (
	"context"
	"errors"
	"fmt"
	"time"

	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCaller/internal/authentication/authenticationiface"
	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCaller/internal/cacheitem"
	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCaller/internal/s2s2err"
	"code.justin.tv/video/metrics-middleware/v2/operation"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/service/dynamodb"
	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
)

// RFC3339_NANO_FORMAT is the time layout used for the string timestamps stored
// in the L2 cache table (creationTime, refreshTimeInRFC3399). Despite its name
// it is not RFC 3339: it is the layout used by time.Time.String, including the
// numeric offset and the zone abbreviation.
const RFC3339_NANO_FORMAT = "2006-01-02 15:04:05.999999999 -0700 MST"

// UNIX_SEC_TIME_FORMAT is the compact YYYYMMDDhhmmss layout used for the numeric
// refreshTime attribute and its comparison value. Despite its name it is not a
// Unix epoch timestamp; fixed-width values in this layout sort numerically in
// chronological order as long as they were rendered in the same time zone.
const UNIX_SEC_TIME_FORMAT = "20060102150405"

// CachedL2Authentications is an implementation of
// authentication.Authentications that keeps tokens in a DynamoDB table (the
// "L2" cache) in front of an underlying authentications client.
type CachedL2Authentications struct {
	// authentications is the wrapped client used to mint fresh tokens.
	authentications authenticationiface.AuthenticationsAPI

	// Copied onto every cache item built by this client (StaleTimeout /
	// ExpirationTimeout); presumably what IsStale / IsExpired measure against.
	staleTimeout, expirationTimeout time.Duration

	// NOTE(review): stored by New but not read in this section of the file.
	refreshRateLimiter *time.Ticker

	logger *s2s2err.Logger

	// NOTE(review): stored by New but not read in this section of the file.
	operationStarter *operation.Starter

	// dynamodbClient and cacheTableName locate the L2 cache table.
	dynamodbClient dynamodbiface.DynamoDBAPI
	cacheTableName string
}

// New builds a CachedL2Authentications that serves tokens from the DynamoDB
// table cacheTableName (through dynamodbClient) and mints new ones through
// authentications. staleTimeout and expirationTimeout are applied to every
// cache item the client creates. Arguments are stored as given: nothing is
// validated and no I/O is performed here.
func New(
	authentications authenticationiface.AuthenticationsAPI,
	staleTimeout time.Duration,
	expirationTimeout time.Duration,
	refreshRateLimiter *time.Ticker,
	logger *s2s2err.Logger,
	operationStarter *operation.Starter,
	dynamodbClient dynamodbiface.DynamoDBAPI,
	cacheTableName string,
) *CachedL2Authentications {
	c := new(CachedL2Authentications)

	// Token sources: the DynamoDB cache first, the wrapped client second.
	c.dynamodbClient = dynamodbClient
	c.cacheTableName = cacheTableName
	c.authentications = authentications

	// Cache item lifetimes.
	c.staleTimeout = staleTimeout
	c.expirationTimeout = expirationTimeout

	// Remaining collaborators.
	c.refreshRateLimiter = refreshRateLimiter
	c.logger = logger
	c.operationStarter = operationStarter

	return c
}

// Item is the decoded form of a row in the L2 cache table, as read by
// fetchFromL2Cache. The dynamodbav tags pin each field to the attribute name
// that updateL2Cache writes ("token", "creationTime", "refreshTime"), so
// decoding does not depend on the decoder matching Go field names
// case-insensitively. The other attributes written by updateL2Cache
// (host_hash, refreshTimeInRFC3399) are not decoded.
type Item struct {
	Token        string `dynamodbav:"token"`        // cached token
	CreationTime string `dynamodbav:"creationTime"` // mint time, RFC3339_NANO_FORMAT layout
	RefreshTime  int    `dynamodbav:"refreshTime"`  // YYYYMMDDhhmmss, UNIX_SEC_TIME_FORMAT layout
}

// Authenticate authenticates a request from DynamoDB cache if possible
func (a *CachedL2Authentications) Authenticate(ctx context.Context, audienceHost string) ([]byte, error) {
	cachedValue, err := a.fetchFromL2Cache(audienceHost)
	if err != nil {
		freshValue, err := a.fetchFromKMS(ctx, audienceHost)
		if err != nil {
			return nil, fmt.Errorf("token request failed with error %w", err)
		}
		updateErr := a.updateL2Cache(audienceHost, freshValue)
		if updateErr != nil {
			// If err is due to condition expression, it means another client already did the put
			// So we need to ignore the fresh token and use the one in L2 by fetching it again
			if putErr, ok := updateErr.(awserr.RequestFailure); ok && putErr.Code() == "ConditionalCheckFailedException" {
				a.logger.LogError(updateErr)
				// Need to fetch token from L2 again
				existingValue, err := a.fetchFromL2Cache(audienceHost)
				if err != nil {
					return nil, fmt.Errorf("token request failed with error %w", err)
				} else {
					return existingValue.Token, nil
				}
			} else {
				return nil, fmt.Errorf("token request failed with error %w", updateErr)
			}
		}
		return freshValue.Token, nil
	}
	if !cachedValue.IsStale() {
		return cachedValue.Token, nil
	}
	// Cached value is stale, however we still can use it but also need to renew the cache
	// keep the old value and try to read from KMS. If success use that and update L2 with the fresh value
	newValue, err := a.fetchFromKMS(ctx, audienceHost)
	if err != nil {
		// Call to KMS failed, use the stale value
		return cachedValue.Token, nil
	}
	// fresh token is exist, trying to update L2
	if err = a.updateL2Cache(audienceHost, newValue); err != nil {
		// In this path we already have one stale token. We fetched a new one, tried to update L2 and it failed.
		// Using the stale value.
		return cachedValue.Token, nil
	}
	return newValue.Token, nil
}

// fetchFromKMS mints a new token for calleeHost through the wrapped
// authentications client and packages it as a cache item. The item is stamped
// with the current time and with this client's stale and expiration timeouts.
// Errors from the wrapped client are returned unchanged.
func (a *CachedL2Authentications) fetchFromKMS(ctx context.Context, calleeHost string) (*cacheitem.CacheItem, error) {
	token, authErr := a.authentications.Authenticate(ctx, calleeHost)
	if authErr != nil {
		return nil, authErr
	}

	item := &cacheitem.CacheItem{
		Token:             token,
		StaleTimeout:      a.staleTimeout,
		ExpirationTimeout: a.expirationTimeout,
	}
	item.CreationTime = time.Now()
	return item, nil
}

func (a *CachedL2Authentications) fetchFromL2Cache(calleeHost string) (*cacheitem.CacheItem, error) {
	dynamoGetInput := &dynamodb.GetItemInput{
		Key: map[string]*dynamodb.AttributeValue{
			"host_hash": {
				S: aws.String(calleeHost),
			},
		},
		TableName: aws.String(a.cacheTableName),
	}
	result, err := a.dynamodbClient.GetItem(dynamoGetInput)
	if err != nil {
		return nil, err
	}
	if result.Item == nil {
		msg := "l2 cache miss, no data in cache"
		a.logger.LogError(errors.New(msg))
		return nil, errors.New(msg)
	}
	item := Item{}
	err = dynamodbattribute.UnmarshalMap(result.Item, &item)
	if err != nil {
		msg := "l2 cache miss, fail unmarshaling"
		a.logger.LogError(errors.New(msg))
		return nil, errors.New(msg)
	}
	creationTime, _ := time.Parse(RFC3339_NANO_FORMAT, item.CreationTime)
	value := &cacheitem.CacheItem{
		Token:             []byte(item.Token),
		CreationTime:      creationTime,
		StaleTimeout:      a.staleTimeout,
		ExpirationTimeout: a.expirationTimeout,
	}
	if !value.IsExpired() {
		return value, nil
	}
	msg := "l2 cache miss, token is expired"
	a.logger.LogError(errors.New(msg))
	return nil, errors.New(msg)
}

// updateL2Cache writes value as the cached token for calleeHost (partition key
// "host_hash") in the table a.cacheTableName.
//
// The put is conditional. It succeeds only when the existing row has no
// refreshTime, or its refreshTime is already earlier than now, so a client
// cannot overwrite a token that another client refreshed more recently. A
// failed condition surfaces as the error DynamoDB returns (callers check for
// ConditionalCheckFailedException). Any other PutItem error is returned
// unchanged.
//
// refreshTime (now + staleTimeout) and the :current_time comparison value are
// both derived from a single time.Now() reading. They are rendered in UTC, so
// hosts in different time zones, or crossing a DST change, compare the
// YYYYMMDDhhmmss numbers consistently.
// NOTE(review): rows written before this change used host-local time; this is
// a no-op on hosts that already run in UTC.
func (a *CachedL2Authentications) updateL2Cache(calleeHost string, value *cacheitem.CacheItem) (err error) {
	now := time.Now().UTC()
	refreshAt := now.Add(a.staleTimeout)

	dynamoPutInput := &dynamodb.PutItemInput{
		Item: map[string]*dynamodb.AttributeValue{
			"host_hash": {
				S: aws.String(calleeHost),
			},
			"refreshTime": {
				N: aws.String(refreshAt.Format(UNIX_SEC_TIME_FORMAT)),
			},
			"refreshTimeInRFC3399": {
				S: aws.String(refreshAt.Format(RFC3339_NANO_FORMAT)),
			},
			"token": {
				S: aws.String(string(value.Token)),
			},
			"creationTime": {
				S: aws.String(value.CreationTime.Format(RFC3339_NANO_FORMAT)),
			},
		},
		TableName:    aws.String(a.cacheTableName),
		ReturnValues: aws.String("ALL_OLD"),
		// AND binds tighter than OR in DynamoDB condition expressions, so this reads
		// "no refreshTime, or (refreshTime exists and is in the past)".
		ConditionExpression: aws.String("attribute_not_exists(refreshTime) or attribute_exists(refreshTime) and refreshTime < :current_time"),
		ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
			":current_time": {
				N: aws.String(now.Format(UNIX_SEC_TIME_FORMAT)),
			},
		},
	}

	_, err = a.dynamodbClient.PutItem(dynamoPutInput)
	return err
}
