package cachetoken

import (
	"context"
	"sync"

	"code.justin.tv/amzn/TwitchS2S2/internal/token"
)

// New returns a Cache that wraps inner.
//
// cacheKey maps request options to the key a token is cached under.
// cacheValueStale decides whether a cached token should be refreshed from
// inner. cacheValueExpired decides whether a cached token may still be served
// when that refresh fails. The functions are stored as given and invoked by
// Token without nil checks.
func New(
	inner token.Tokens,
	cacheKey func(*token.Options) string,
	cacheValueStale func(*token.Options, *token.Token) bool,
	cacheValueExpired func(*token.Options, *token.Token) bool,
) *Cache {
	c := new(Cache)
	c.inner = inner
	c.cacheKey = cacheKey
	c.cacheValueStale = cacheValueStale
	c.cacheValueExpired = cacheValueExpired
	// Start with an empty cache; entries are added by Token on refresh.
	c.values = map[string]*token.Token{}
	c.lock = newShardedLock()
	return c
}

// Cache is a wrapper around token.Tokens that caches tokens based on a cache
// key, a cache staleness policy, and a cache expiration policy.
//
// If a cached token is stale but the underlying client returns an error, the
// stale token will be returned.
//
// If a cached token is expired (or absent) and the underlying client returns
// an error, the error will be surfaced.
//
// Construct a Cache with New; the zero value has no map or shard lock. A
// Cache contains a mutex and must not be copied after first use.
type Cache struct {
	// inner is the token source consulted when the cached token is missing,
	// stale, or expired.
	inner token.Tokens

	// cacheKey maps request options to the key a token is cached under.
	cacheKey func(*token.Options) string

	// cacheValueStale reports whether a cached token should be refreshed.
	cacheValueStale func(*token.Options, *token.Token) bool

	// cacheValueExpired reports whether a cached token may no longer be
	// served, even as a fallback when a refresh fails.
	cacheValueExpired func(*token.Options, *token.Token) bool

	// valuesMu guards the values map itself. The shard locks only serialize
	// work for keys within one shard; callers on different shards still touch
	// the same map, and Go maps are not safe for concurrent read/write.
	valuesMu sync.RWMutex

	// values holds the most recently fetched token per cache key.
	values map[string]*token.Token

	// lock hands out the per-key shard lock that serializes refreshes for
	// keys mapping to the same shard.
	lock *shardedLock
}

// Token implements token.Tokens.
//
// A cached token that is neither stale nor expired is returned directly.
// Otherwise Token takes the write lock of the key's shard and re-checks the
// cache, because another caller may have refreshed it while this one waited.
// Only then does it ask inner for a new token. If inner fails, the cached
// token is returned as long as it is not expired; otherwise inner's error is
// returned.
func (c *Cache) Token(ctx context.Context, options *token.Options) (*token.Token, error) {
	key := c.cacheKey(options)
	shard := c.lock.Shard(key)

	// Fast path under the shard read lock. The deferred unlock keeps the
	// shard usable even if a policy function panics.
	cached, stale, expired := func() (*token.Token, bool, bool) {
		shard.RLock()
		defer shard.RUnlock()
		return c.loadEntry(options, key)
	}()
	if !stale && !expired {
		return cached, nil
	}

	shard.Lock()
	defer shard.Unlock()

	// Re-check under the write lock: a caller that held the lock before us may
	// already have stored a fresh token, making another fetch redundant.
	cached, stale, expired = c.loadEntry(options, key)
	if !stale && !expired {
		return cached, nil
	}

	fresh, err := c.inner.Token(ctx, options)
	if err != nil {
		if !expired {
			// Return the cached value on error since it's likely still valid
			// unless expired.
			return cached, nil
		}
		return nil, err
	}
	c.storeEntry(key, fresh)
	return fresh, nil
}

// loadEntry returns the token cached under key along with whether it is stale
// and whether it is expired according to the configured policies. A missing or
// nil entry is reported as both stale and expired. The caller must hold key's
// shard lock (read or write); valuesMu is held only for the map read so the
// policy functions do not block other shards.
func (c *Cache) loadEntry(options *token.Options, key string) (cached *token.Token, stale, expired bool) {
	c.valuesMu.RLock()
	cached = c.values[key]
	c.valuesMu.RUnlock()

	if cached == nil {
		return nil, true, true
	}
	return cached, c.cacheValueStale(options, cached), c.cacheValueExpired(options, cached)
}

// storeEntry records fresh as the cached token for key. The caller must hold
// key's shard write lock.
func (c *Cache) storeEntry(key string, fresh *token.Token) {
	c.valuesMu.Lock()
	defer c.valuesMu.Unlock()
	c.values[key] = fresh
}
