package harness

import (
	"strconv"
	"testing"
	"time"

	"code.justin.tv/gds/gds/golibs/event"
	"code.justin.tv/gds/gds/golibs/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"code.justin.tv/extensions/fulton-configuration/data/model"
	"code.justin.tv/extensions/fulton-configuration/protocol"
	"code.justin.tv/extensions/fulton-configuration/protocol/messages"
)

// fakeUUID is a deterministic stand-in for a real UUID generator: it hands
// out the successive values of a counter. It performs no synchronization.
type fakeUUID struct {
	counter int64 // most recent value handed out; zero before the first call
}

// Next implements uuid.Source. It advances the counter by one and returns
// the new value formatted in base 16. The returned error is always nil.
func (f *fakeUUID) Next() (string, error) {
	next := f.counter + 1
	f.counter = next
	return strconv.FormatInt(next, 16), nil
}

// StoreGenerator is an input for the test harness. Each call should generate
// a store, together with its block tracker, that doesn't collide with the
// others. It is expected that each call to this function returns a clean
// slate; if RunTests is called with allowParallel=false this function can
// clean up and return the same instance each time.
//
// The uuid argument is the ID source supplied by the harness.
// NOTE(review): presumably the generated store should draw its concurrency
// UUIDs from this source - confirm against the store implementations.
type StoreGenerator func(uuid uuid.Source) (model.Store, model.BlockTracker)

// Store is the receiver for the store compliance checks; RunTests is its
// entry point. It has no fields and carries no state.
type Store struct{}

// RunTests checks a store implementation for compliance with the reference
// implementation.
//
// gen is invoked once per check, immediately before that check runs, and is
// always handed the same fake UUID source. Each check runs as its own named
// subtest through t.Run, so a fatal (require) failure ends only the check it
// occurred in and the remaining checks still execute.
//
// allowParallel is currently not consulted: the checks always run
// sequentially, in the order listed below.
func (s Store) RunTests(gen StoreGenerator, allowParallel bool, t *testing.T) {
	src := new(fakeUUID)

	checks := []struct {
		name  string
		build func(model.Store, model.BlockTracker) func(t *testing.T)
	}{
		{"InitialState", s.checkInitialState},
		{"SaveCommon", s.checkSaveCommon},
		{"SaveChannel", s.checkSaveChannel},
		{"DeleteChannel", s.checkDeleteChannel},
		{"ResetAllData", s.checkResetAllData},
		{"ChannelWhenDeleted", s.checkChannelWhenDeleted},
	}

	for _, c := range checks {
		// gen must be called right before the check it feeds: a generator is
		// allowed to wipe and reuse a single store instance between calls.
		t.Run(c.name, c.build(gen(src)))
	}
}

// checkInitialState returns a test asserting that a fresh store yields a nil
// record and no error for both the common and the channel lookups, through
// the synchronous loaders as well as the asynchronous ones. trk is not used.
func (s *Store) checkInitialState(store model.Store, trk model.BlockTracker) func(t *testing.T) {
	const (
		env = "env"
		ext = "ext"
		ch  = "ch"
	)
	return func(t *testing.T) {
		// Synchronous lookups.
		gotCommon, commonErr := store.LoadCommon(env, ext)
		assert.Nil(t, gotCommon)
		assert.NoError(t, commonErr)

		gotChannel, channelErr := store.LoadChannel(env, ext, ch)
		assert.Nil(t, gotChannel)
		assert.NoError(t, channelErr)

		// Asynchronous lookups, resolved immediately.
		gotCommon, commonErr = store.AsyncLoadCommon(env, ext).Common()
		assert.Nil(t, gotCommon)
		assert.NoError(t, commonErr)

		gotChannel, channelErr = store.AsyncLoadChannel(env, ext, ch).Channel()
		assert.Nil(t, gotChannel)
		assert.NoError(t, channelErr)
	}
}

// checkSaveCommon returns a test that walks SaveCommon and
// MarkCommonPublished through their concurrency-UUID handling: an update with
// no stored record, a create, a create blocked by an existing record, an
// update with the wrong UUID, a successful update, and an update carrying
// messages. After each step it confirms which record the store holds by
// comparing concurrency UUIDs. trk is not used.
//
// The steps rely on SaveCommon updating the ConcurrencyUUID of the record it
// is given on success; verifyCommonUUID compares that field against what the
// store returns.
func (s *Store) checkSaveCommon(store model.Store, trk model.BlockTracker) func(t *testing.T) {
	return func(t *testing.T) {
		common := model.NewCommon("env", "id")
		common2 := model.NewCommon("env", "id")
		content := "record"
		common2.Global = protocol.NewRecord("3", &content)

		// update - old record not found
		common.ConcurrencyUUID = "old"
		assert.Equal(t, protocol.ErrConcurrency, store.SaveCommon(common))
		assert.Equal(t, protocol.ErrConcurrency, store.MarkCommonPublished(common))
		found, err := store.LoadCommon(common.Environment, common.ExtensionID)
		assert.Nil(t, found)
		assert.NoError(t, err)

		// create - success; marking published is still expected to be rejected
		// (NOTE(review): presumably because the record carries no messages)
		common.ConcurrencyUUID = ""
		assert.NoError(t, store.SaveCommon(common))
		assert.Equal(t, protocol.ErrConcurrency, store.MarkCommonPublished(common))
		s.verifyCommonUUID(store, common, t)

		// create - blocked by old record, which remains
		assert.Equal(t, protocol.ErrConcurrency, store.SaveCommon(common2))
		s.verifyCommonUUID(store, common, t)

		// update - wrong uuid, no change
		common2.ConcurrencyUUID = "wrong"
		assert.Equal(t, protocol.ErrConcurrency, store.SaveCommon(common2))
		s.verifyCommonUUID(store, common, t)

		// update - success
		common2.ConcurrencyUUID = common.ConcurrencyUUID
		assert.NoError(t, store.SaveCommon(common2))
		assert.Equal(t, protocol.ErrConcurrency, store.MarkCommonPublished(common2))
		s.verifyCommonUUID(store, common2, t)

		// update with messages - success; publishing works exactly once
		common2.Messages = []event.Message{messages.NewConfigMessage(1, messages.OnSet)}
		assert.NoError(t, store.SaveCommon(common2))
		assert.NoError(t, store.MarkCommonPublished(common2))
		assert.Equal(t, protocol.ErrConcurrency, store.MarkCommonPublished(common2))
		s.verifyCommonUUID(store, common2, t)

		// verify data was recorded; stop here if nothing came back so the
		// field access below cannot panic on a nil record
		out, err := store.LoadCommon(common2.Environment, common2.ExtensionID)
		assert.NoError(t, err)
		require.NotNil(t, out)
		assert.Equal(t, common2.Global, out.Global)
	}
}

// checkSaveChannel returns a test that walks SaveChannel and
// MarkChannelPublished through their concurrency-UUID handling: an update
// with no stored record, a create, a create blocked by an existing record, an
// update with the wrong UUID, a successful update, and an update carrying
// messages. After each step it confirms which record the store holds by
// comparing concurrency UUIDs. trk is not used.
//
// The steps rely on SaveChannel updating the ConcurrencyUUID of the record it
// is given on success; verifyChannelUUID compares that field against what the
// store returns.
func (s *Store) checkSaveChannel(store model.Store, trk model.BlockTracker) func(t *testing.T) {
	return func(t *testing.T) {
		channel := model.NewChannel("env", "id", "ch")
		channel2 := model.NewChannel("env", "id", "ch")
		content := "record"
		channel2.Broadcaster = protocol.NewRecord("4", &content)
		channel2.Developer = protocol.NewRecord("7", &content)

		// update - old record not found
		channel.ConcurrencyUUID = "old"
		assert.Equal(t, protocol.ErrConcurrency, store.SaveChannel(channel))
		assert.Equal(t, protocol.ErrConcurrency, store.MarkChannelPublished(channel))
		found, err := store.LoadChannel(channel.Environment, channel.ExtensionID, channel.ChannelID)
		assert.Nil(t, found)
		assert.NoError(t, err)

		// create - success; marking published is still expected to be rejected
		// (NOTE(review): presumably because the record carries no messages)
		channel.ConcurrencyUUID = ""
		assert.NoError(t, store.SaveChannel(channel))
		assert.Equal(t, protocol.ErrConcurrency, store.MarkChannelPublished(channel))
		s.verifyChannelUUID(store, channel, t)

		// create - blocked by old record, which remains
		assert.Equal(t, protocol.ErrConcurrency, store.SaveChannel(channel2))
		s.verifyChannelUUID(store, channel, t)

		// update - wrong uuid, no change
		channel2.ConcurrencyUUID = "wrong"
		assert.Equal(t, protocol.ErrConcurrency, store.SaveChannel(channel2))
		s.verifyChannelUUID(store, channel, t)

		// update - success
		channel2.ConcurrencyUUID = channel.ConcurrencyUUID
		assert.NoError(t, store.SaveChannel(channel2))
		assert.Equal(t, protocol.ErrConcurrency, store.MarkChannelPublished(channel2))
		s.verifyChannelUUID(store, channel2, t)

		// update with messages - success; publishing works exactly once
		channel2.Messages = []event.Message{messages.NewConfigMessage(1, messages.OnSet)}
		assert.NoError(t, store.SaveChannel(channel2))
		assert.NoError(t, store.MarkChannelPublished(channel2))
		assert.Equal(t, protocol.ErrConcurrency, store.MarkChannelPublished(channel2))
		s.verifyChannelUUID(store, channel2, t)

		// verify data was recorded; stop here if nothing came back so the
		// field accesses below cannot panic on a nil record
		out, err := store.LoadChannel(channel2.Environment, channel2.ExtensionID, channel2.ChannelID)
		assert.NoError(t, err)
		require.NotNil(t, out)
		assert.Equal(t, channel2.Broadcaster, out.Broadcaster)
		assert.Equal(t, channel2.Developer, out.Developer)
	}
}

// checkDeleteChannel returns a test that stores two records sharing one
// channel ID under different extension IDs, deletes that channel ID, and then
// expects both records to load as nil without error and any further save of
// either record to return protocol.ErrForbiddenByBroadcaster. trk is not used.
func (s *Store) checkDeleteChannel(store model.Store, trk model.BlockTracker) func(t *testing.T) {
	return func(t *testing.T) {
		first := model.NewChannel("env", "id", "ch")
		second := model.NewChannel("env", "id2", "ch")
		records := []*model.Channel{first, second}

		for _, rec := range records {
			require.NoError(t, store.SaveChannel(rec))
		}

		assert.NoError(t, store.DeleteChannel(first.ChannelID))

		// Neither record may be visible any more.
		for _, rec := range records {
			loaded, err := store.LoadChannel(rec.Environment, rec.ExtensionID, rec.ChannelID)
			assert.Nil(t, loaded)
			assert.NoError(t, err)
		}

		// Writing to the deleted channel is refused for both extensions.
		for _, rec := range records {
			assert.Equal(t, protocol.ErrForbiddenByBroadcaster, store.SaveChannel(rec))
		}
	}
}

// checkResetAllData returns a test covering the data-reset switch: with one
// channel and one common record stored, ResetAllData is expected to return
// protocol.ErrUnavailable while resets are disabled, and to succeed and leave
// both records unloadable (nil, no error) once EnableDataReset has been
// called. trk is not used.
func (s *Store) checkResetAllData(store model.Store, trk model.BlockTracker) func(t *testing.T) {
	return func(t *testing.T) {
		const env, ext, chID = "env", "id", "ch"

		savedChannel := model.NewChannel(env, ext, chID)
		savedCommon := model.NewCommon(env, ext)
		require.NoError(t, store.SaveChannel(savedChannel))
		require.NoError(t, store.SaveCommon(savedCommon))

		// Resets start out disabled and are refused.
		enabled := store.IsResetEnabled()
		assert.False(t, enabled)
		resetErr := store.ResetAllData()
		assert.Equal(t, protocol.ErrUnavailable, resetErr)

		// After opting in, the reset goes through.
		store.EnableDataReset()
		enabled = store.IsResetEnabled()
		assert.True(t, enabled)
		resetErr = store.ResetAllData()
		assert.NoError(t, resetErr)

		// Both previously saved records are gone.
		gotChannel, channelErr := store.LoadChannel(env, ext, chID)
		assert.Nil(t, gotChannel)
		assert.NoError(t, channelErr)

		gotCommon, commonErr := store.LoadCommon(env, ext)
		assert.Nil(t, gotCommon)
		assert.NoError(t, commonErr)
	}
}

// checkChannelWhenDeleted returns a test that saves a channel record,
// confirms it loads, blocks its channel ID through the block tracker, and
// then expects the record to load as nil without error and a new save to
// return protocol.ErrForbiddenByBroadcaster.
func (s *Store) checkChannelWhenDeleted(store model.Store, trk model.BlockTracker) func(t *testing.T) {
	return func(t *testing.T) {
		rec := model.NewChannel("env", "id", "ch")
		require.NoError(t, store.SaveChannel(rec))

		// Visible before the block.
		before, beforeErr := store.LoadChannel(rec.Environment, rec.ExtensionID, rec.ChannelID)
		assert.NotNil(t, before)
		assert.NoError(t, beforeErr)

		blockErr := trk.Block(rec.ChannelID)
		require.NoError(t, blockErr)

		// Hidden after the block, and writes are refused.
		after, afterErr := store.LoadChannel(rec.Environment, rec.ExtensionID, rec.ChannelID)
		assert.Nil(t, after)
		assert.NoError(t, afterErr)

		saveErr := store.SaveChannel(rec)
		assert.Equal(t, protocol.ErrForbiddenByBroadcaster, saveErr)
	}
}

// verifyChannelUUID asserts that the channel record held by store has the
// same ConcurrencyUUID as channel, reading it back through both LoadChannel
// and AsyncLoadChannel. A nil record from either loader fails the test
// immediately.
//
// It is marked as a test helper so failures are attributed to the calling
// step rather than to a line inside this function.
func (s *Store) verifyChannelUUID(store model.Store, channel *model.Channel, t *testing.T) {
	t.Helper()

	// NOTE(review): presumably lets a store-side cache keyed on nanosecond
	// timestamps expire before reading back - confirm.
	time.Sleep(2 * time.Nanosecond) // flush nanosecond cache

	out, err := store.LoadChannel(channel.Environment, channel.ExtensionID, channel.ChannelID)
	assert.NoError(t, err)
	require.NotNil(t, out)
	assert.Equal(t, channel.ConcurrencyUUID, out.ConcurrencyUUID, "Concurrency mismatch")

	out, err = store.AsyncLoadChannel(channel.Environment, channel.ExtensionID, channel.ChannelID).Channel()
	assert.NoError(t, err)
	require.NotNil(t, out)
	assert.Equal(t, channel.ConcurrencyUUID, out.ConcurrencyUUID, "Concurrency mismatch")
}

// verifyCommonUUID asserts that the common record held by store has the same
// ConcurrencyUUID as common, reading it back through both LoadCommon and
// AsyncLoadCommon. A nil record from either loader fails the test
// immediately.
//
// It is marked as a test helper so failures are attributed to the calling
// step rather than to a line inside this function.
func (s *Store) verifyCommonUUID(store model.Store, common *model.Common, t *testing.T) {
	t.Helper()

	// NOTE(review): presumably lets a store-side cache keyed on nanosecond
	// timestamps expire before reading back - confirm.
	time.Sleep(2 * time.Nanosecond) // flush nanosecond cache

	out, err := store.LoadCommon(common.Environment, common.ExtensionID)
	assert.NoError(t, err)
	require.NotNil(t, out)
	assert.Equal(t, common.ConcurrencyUUID, out.ConcurrencyUUID, "Concurrency mismatch")

	out, err = store.AsyncLoadCommon(common.Environment, common.ExtensionID).Common()
	assert.NoError(t, err)
	require.NotNil(t, out)
	assert.Equal(t, common.ConcurrencyUUID, out.ConcurrencyUUID, "Concurrency mismatch")
}
