package oidccache

import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"errors"
	"math/big"
	"testing"
	"time"

	"code.justin.tv/amzn/TwitchS2S2/internal/oidc"
	"code.justin.tv/amzn/TwitchS2S2/internal/oidc/oidccache/mocks"
	"code.justin.tv/amzn/TwitchS2SJWTAlgorithms/es256"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestCache(t *testing.T) {
	ctx := context.Background()

	t.Run("Configuration", func(t *testing.T) {
		ct := newCacheTest()
		defer ct.Teardown(t)

		cfg := &oidc.Configuration{
			IssuerID:              "ISSUERID",
			AuthorizationEndpoint: "AUTHORIZATIONENDPOINT",
			TokenEndpoint:         "TOKENENDPOINT",
			JwksURI:               "JWKSURI",
		}

		ct.OIDC.On("Configuration").Return(cfg)

		assert.Equal(t, cfg, ct.Cache.Configuration())
	})

	t.Run("ValidationKeys", func(t *testing.T) {
		ct := newCacheTest()
		defer ct.Teardown(t)

		const keyID = "KEYID"
		key := &es256.PublicKey{PublicKey: &ecdsa.PublicKey{Curve: elliptic.P256(), X: big.NewInt(11), Y: big.NewInt(13)}, KeyID: keyID}
		keys := map[string]oidc.ValidationKey{keyID: key}
		exp := time.Hour

		ct.OIDC.On("ValidationKeys", ctx).Return(keys, exp, nil)

		res, dt, err := ct.Cache.ValidationKeys(ctx)
		require.NoError(t, err)
		assert.Equal(t, exp, dt)
		assert.Equal(t, keys, res)
	})

	t.Run("ValidationKey", func(t *testing.T) {
		const keyID = "KEYID"
		key := &es256.PublicKey{PublicKey: &ecdsa.PublicKey{Curve: elliptic.P256(), X: big.NewInt(11), Y: big.NewInt(13)}, KeyID: keyID}

		t.Run("success with refresh and cache", func(t *testing.T) {
			ct := newCacheTest()
			defer ct.Teardown(t)

			ct.OIDC.On("ValidationKeys", ctx).
				Return(map[string]oidc.ValidationKey{keyID: key}, time.Hour, nil).
				Once()

			res, err := ct.Cache.ValidationKey(ctx, keyID)
			require.NoError(t, err)
			assert.Equal(t, key, res)

			res, err = ct.Cache.ValidationKey(ctx, keyID)
			require.NoError(t, err)
			assert.Equal(t, key, res, "from cache")
			assert.InDelta(t, time.Hour, time.Until(ct.Cache.validationKeysExpired), float64(time.Minute))
			assert.InDelta(t, staleDuration, time.Until(ct.Cache.validationKeysStale), float64(time.Minute))
		})

		t.Run("error cache miss", func(t *testing.T) {
			ct := newCacheTest()
			defer ct.Teardown(t)

			ct.OIDC.On("ValidationKeys", ctx).
				Return(map[string]oidc.ValidationKey{"somethingelse": key}, time.Hour, nil).
				Once()

			_, err := ct.Cache.ValidationKey(ctx, keyID)
			assert.Equal(t, &oidc.ErrUnknownSigningKey{KID: keyID}, err)

			_, err = ct.Cache.ValidationKey(ctx, keyID)
			assert.Equal(t, &oidc.ErrUnknownSigningKey{KID: keyID}, err, "from cache")
		})

		t.Run("error on refresh when stale", func(t *testing.T) {
			ct := newCacheTest()
			defer ct.Teardown(t)

			refreshErr := errors.New("REFRESHERR")

			ct.Cache.validationKeys = map[string]oidc.ValidationKey{keyID: key}
			ct.Cache.validationKeysStale = time.Now().Add(-time.Hour)
			ct.Cache.validationKeysExpired = time.Now().Add(time.Hour)

			ct.OIDC.On("ValidationKeys", ctx).
				Return(nil, time.Duration(0), refreshErr).
				Once()

			res, err := ct.Cache.ValidationKey(ctx, keyID)
			require.NoError(t, err)
			assert.Equal(t, key, res)
		})

		t.Run("error on refresh when expired", func(t *testing.T) {
			ct := newCacheTest()
			defer ct.Teardown(t)

			ct.Cache.validationKeys = map[string]oidc.ValidationKey{keyID: key}
			ct.Cache.validationKeysStale = time.Now().Add(-time.Hour)
			ct.Cache.validationKeysExpired = time.Now().Add(-time.Hour)
			refreshErr := errors.New("REFRESHERR")

			ct.OIDC.On("ValidationKeys", ctx).
				Return(nil, time.Duration(0), refreshErr).
				Once()

			_, err := ct.Cache.ValidationKey(ctx, keyID)
			assert.Equal(t, refreshErr, err, "from cache")
		})
	})
}

// newCacheTest returns a cacheTest whose Cache wraps a fresh mocks.OIDCAPI.
// The same mock is exposed as cacheTest.OIDC so each test can register its
// own expectations.
func newCacheTest() *cacheTest {
	// Named api rather than oidc to avoid shadowing the imported oidc package.
	api := new(mocks.OIDCAPI)
	return &cacheTest{
		Cache: &Cache{OIDC: api},
		OIDC:  api,
	}
}

// cacheTest bundles the Cache under test (Cache) with the mock OIDC client
// (OIDC) on which tests register expected calls. When built by newCacheTest,
// Cache.OIDC is that same mock.
type cacheTest struct {
	Cache *Cache
	OIDC  *mocks.OIDCAPI
}

// Teardown fails t if any expectation registered on ct.OIDC was not met.
// Intended to be deferred immediately after newCacheTest.
func (ct *cacheTest) Teardown(t *testing.T) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()
	ct.OIDC.AssertExpectations(t)
}
