package dynamo

import (
	"testing"
	"time"

	"code.justin.tv/extensions/configuration/services/main/data/model"
	"code.justin.tv/extensions/configuration/services/main/data/model/bad"
	"code.justin.tv/extensions/configuration/services/main/data/model/harness"
	"code.justin.tv/extensions/configuration/services/main/data/model/memory"
	"code.justin.tv/extensions/configuration/services/main/protocol"
	"code.justin.tv/gds/gds/golibs/dynamodb/lazy"
	"code.justin.tv/gds/gds/golibs/uuid"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/dynamodb"
	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
	"github.com/stretchr/testify/assert"
)

// Settings shared by the DynamoDB integration tests in this package.
const (
	// integrationTestAccount presumably names the AWS account the integration
	// tests are meant to run against; it is not referenced in this file.
	integrationTestAccount = "twitch_ds_private"

	// integrationPrefix is the prefix argument createGenerators passes to New.
	integrationPrefix = "configuration"

	// integrationRegion is presumably the AWS region for integration runs.
	// NOTE(review): not applied to the AWS session in this file — confirm.
	integrationRegion = "us-west-2"
)

// createIntegrationGenerators returns tracker and store generators backed by a
// real DynamoDB client, so the shared harness suites can run against live
// tables. The calling test is skipped when `go test -short` is used, and fails
// immediately if no AWS session can be created.
//
// The session is built from aws.NewConfig(), which sets no explicit values, so
// credentials and region come from the SDK's default environment/shared-config
// resolution.
// NOTE(review): integrationRegion is not applied here — confirm whether this
// should be aws.NewConfig().WithRegion(integrationRegion).
func createIntegrationGenerators(t *testing.T) (harness.BlockTrackerGenerator, harness.StoreGenerator) {
	// Report Skip/Fatal at the calling subtest's line rather than here.
	t.Helper()

	if testing.Short() {
		t.Skip("Skipping integration tests in short mode")
	}

	dynamoSession, err := session.NewSession(aws.NewConfig())
	if err != nil {
		// Only the message and the error are formatted; *testing.T must not be
		// passed as a print argument.
		t.Fatalf("Couldn't create an AWS session for integration tests; use --short to skip: %v", err)
	}

	return createGenerators(t, dynamodb.New(dynamoSession))
}

// createGenerators returns harness generator functions that build a fresh
// store over db for each call, using integrationPrefix and a 1ms duration
// argument.
//
// Each generated store has data reset enabled, ResetAllData invoked once
// (presumably clearing existing rows under the prefix), and then its
// resetEnabled field set back to 0, which presumably disables further resets.
// The tracker generator's store is given no tracker; the store generator's
// store is wired to an in-memory tracker, which is returned alongside it.
// t is currently unused.
func createGenerators(t *testing.T, db dynamodbiface.DynamoDBAPI) (harness.BlockTrackerGenerator, harness.StoreGenerator) {
	// wipe performs the enable -> reset -> disarm sequence on a new store.
	wipe := func(s *store) {
		s.EnableDataReset()
		s.ResetAllData()
		s.resetEnabled = 0
	}

	trackerGen := func(src uuid.Source) model.BlockTracker {
		tracker := New(src, db, nil, integrationPrefix, time.Millisecond)
		wipe(tracker.(*store))
		return tracker
	}

	storeGen := func(src uuid.Source) (model.Store, model.BlockTracker) {
		blocks := memory.New(src)
		backing := New(src, db, blocks, integrationPrefix, time.Millisecond)
		wipe(backing.(*store))
		return backing, blocks
	}

	return trackerGen, storeGen
}

// TestStore exercises the DynamoDB-backed store. It runs the shared harness
// suites against live tables (skipped with -short), then uses mockDynamo plus
// the bad and memory trackers to check how DynamoDB errors, tracker errors,
// tracker blocks and marshal errors surface through the store's methods.
func TestStore(t *testing.T) {
	t.Run("passes integration tests", func(t *testing.T) {
		trackerGen, storeGen := createIntegrationGenerators(t)
		harness.BlockTracker{}.RunTests(trackerGen, false, t)
		harness.Store{}.RunTests(storeGen, false, t)
	})

	t.Run("forwards db errors", func(t *testing.T) {
		trackerGen, storeGen := createGenerators(t, &mockDynamo{err: bad.ErrExpected})
		s, _ := storeGen(uuid.NewSource())

		co, err := s.AsyncLoadCommon("env", "extID").Common()
		assert.Nil(t, co)
		assert.Equal(t, bad.ErrExpected, err)

		ch, err := s.AsyncLoadChannel("env", "extID", "chID").Channel()
		assert.Nil(t, ch)
		assert.Equal(t, bad.ErrExpected, err)

		co, err = s.LoadCommon("env", "extID")
		assert.Nil(t, co)
		assert.Equal(t, bad.ErrExpected, err)

		ch, err = s.LoadChannel("env", "extID", "chID")
		assert.Nil(t, ch)
		assert.Equal(t, bad.ErrExpected, err)

		assert.Equal(t, bad.ErrExpected, s.SaveCommon(model.NewCommon("env", "extID")))
		assert.Equal(t, bad.ErrExpected, s.SaveChannel(model.NewChannel("env", "extID", "chID")))

		// this call uses the tracker to persist its change
		assert.NoError(t, s.DeleteChannel("chID"))

		assert.Equal(t, bad.ErrExpected, s.MarkCommonPublished(model.NewCommon("env", "extID")))
		assert.Equal(t, bad.ErrExpected, s.MarkChannelPublished(model.NewChannel("env", "extID", "chID")))

		tracker := trackerGen(uuid.NewSource())
		assert.Equal(t, bad.ErrExpected, tracker.Block("chID"))
		assert.Equal(t, bad.ErrExpected, tracker.Unblock("chID"))
		assert.Equal(t, bad.ErrExpected, tracker.OnDeletionFinished("chID"))

		active, err := tracker.IsBlocked("chID").Active()
		assert.False(t, active)
		assert.Equal(t, bad.ErrExpected, err)

		list, err := tracker.DeletionInProgress()
		assert.Empty(t, list)
		assert.Equal(t, bad.ErrExpected, err)
	})

	t.Run("forwards tracker errors", func(t *testing.T) {
		// The mock fails with ErrUnimplemented, so expecting bad.ErrExpected
		// shows these errors come from the bad tracker, not from DynamoDB.
		s := New(uuid.NewSource(), &mockDynamo{err: protocol.ErrUnimplemented}, bad.New(), "test_prefix", time.Millisecond)

		ch, err := s.AsyncLoadChannel("env", "extID", "chID").Channel()
		assert.Nil(t, ch)
		assert.Equal(t, bad.ErrExpected, err)

		ch, err = s.LoadChannel("env", "extID", "chID")
		assert.Nil(t, ch)
		assert.Equal(t, bad.ErrExpected, err)

		assert.Equal(t, bad.ErrExpected, s.DeleteChannel("chID"))
		assert.Equal(t, bad.ErrExpected, s.SaveChannel(model.NewChannel("env", "extID", "chID")))
	})

	t.Run("respects tracker blocking", func(t *testing.T) {
		tr := memory.New(uuid.NewSource())
		// Every assertion below assumes chID is blocked; stop if that setup fails.
		if err := tr.Block("chID"); err != nil {
			t.Fatalf("blocking chID in the memory tracker: %v", err)
		}
		s := New(uuid.NewSource(), &mockDynamo{err: protocol.ErrUnimplemented}, tr, "test_prefix", time.Millisecond)

		// Loads of a blocked channel yield no channel and no error.
		ch, err := s.AsyncLoadChannel("env", "extID", "chID").Channel()
		assert.Nil(t, ch)
		assert.NoError(t, err)

		ch, err = s.LoadChannel("env", "extID", "chID")
		assert.Nil(t, ch)
		assert.NoError(t, err)

		// Saves to a blocked channel are rejected outright.
		assert.Equal(t, protocol.ErrForbiddenByBroadcaster, s.SaveChannel(model.NewChannel("env", "extID", "chID")))
	})

	t.Run("forwards marshal errors", func(t *testing.T) {
		s := New(uuid.NewSource(), &mockDynamo{err: protocol.ErrUnimplemented}, bad.New(), "test_prefix", time.Millisecond).(*store)
		// Force every marshal call to fail so the marshal error is what surfaces.
		s.marshal = func(interface{}) (map[string]*dynamodb.AttributeValue, error) {
			return nil, lazy.ErrMigrationCorrupt
		}
		assert.Equal(t, lazy.ErrMigrationCorrupt, s.saveBlockRecord(&blockRecord{}))
		assert.Equal(t, lazy.ErrMigrationCorrupt, s.SaveCommon(model.NewCommon("env", "extID")))
		assert.Equal(t, lazy.ErrMigrationCorrupt, s.SaveChannel(model.NewChannel("env", "extID", "chID")))
	})
}
