package screening

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/suite"

	"a.yandex-team.ru/infra/walle/server/go/internal/lib"
	lockrepo "a.yandex-team.ru/infra/walle/server/go/internal/utilities/repository"
	"a.yandex-team.ru/infra/walle/server/go/internal/utilities/repository/mocks"
)

// ShardManagerSuite exercises shardManager against a mocked lock repository.
type ShardManagerSuite struct {
	suite.Suite

	// manager is the shardManager under test, built in SetupSuite.
	manager *shardManager

	// instances is the owner list returned by the mocked FindOwners; tests
	// overwrite it to simulate instances joining and leaving.
	instances []string

	// hostname is assigned by the mocked TryLock callback.
	hostname string
}

// SetupSuite builds the shardManager under test on top of a mocked LockRepo.
//
// Mock behavior:
//   - TryLock always reports success and, as a side effect, sets
//     suite.hostname to "this_host" and appends "this_host" to suite.instances.
//   - FindOwners returns the current suite.instances, so tests control the
//     owner list by assigning that field.
//   - Unlock and Delete always succeed.
func (suite *ShardManagerSuite) SetupSuite() {
	loggers, err := lib.NewLoggers(
		map[string]lib.LoggerConfig{"test": {Level: "debug", Format: "console", Paths: []string{"stdout"}}},
	)
	// Fail fast on a bad logger config instead of passing a nil logger on.
	suite.Require().NoError(err)

	lockRepo := &mocks.LockRepo{}
	lockRepo.On("TryLock", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(
		func(ctx context.Context, id string, owner string, until time.Time) bool {
			// NOTE(review): presumably invoked once while newShardManager
			// registers this host; repeated calls would append duplicates.
			suite.hostname = "this_host"
			suite.instances = append(suite.instances, "this_host")
			return true
		},
		nil,
	)
	lockRepo.On("FindOwners", mock.Anything, mock.Anything).
		Return(func(ctx context.Context, filter *lockrepo.LockFilter) []string { return suite.instances }, nil)
	lockRepo.On("Unlock", mock.Anything, mock.Anything, mock.Anything).Return(nil)
	lockRepo.On("Delete", mock.Anything, mock.Anything).Return(nil)

	// NOTE(review): 12 appears to be the shard count (the tests expect shard
	// IDs 0..11); the meaning of 2 is not visible here — confirm.
	suite.manager, err = newShardManager(2, 12, lockRepo, loggers["test"])
	suite.Require().NoError(err)
	suite.manager.instance = "this_host"
}

// TearDownSuite runs after every test in the suite and calls clean on the
// manager created in SetupSuite.
func (suite *ShardManagerSuite) TearDownSuite() {
	mgr := suite.manager
	mgr.clean()
}

// TestGetShards checks that, when this host is the only lock owner, getShards
// returns every shard from 0 to shardsNum-1 in order.
func (suite *ShardManagerSuite) TestGetShards() {
	// Pin the owner list explicitly rather than relying on whatever state
	// SetupSuite or previously run tests left behind.
	suite.instances = []string{suite.hostname}

	shards, err := suite.manager.getShards(context.Background())
	suite.Require().NoError(err)

	expected := make([]*shard, suite.manager.shardsNum)
	for i := range expected {
		expected[i] = &shard{i}
	}
	suite.Assert().Equal(expected, shards)
}

// TestGetShardsWithManyInstances checks which shards this host receives when
// two other instances are also lock owners.
func (suite *ShardManagerSuite) TestGetShardsWithManyInstances() {
	// Restore the single-owner list even if Require aborts the test early
	// (FailNow still runs deferred calls), so extra instances never leak into
	// later tests.
	defer func() { suite.instances = []string{suite.hostname} }()

	suite.instances = []string{suite.hostname, "instance-2", "instance-3"}
	shards, err := suite.manager.getShards(context.Background())
	suite.Require().NoError(err)
	suite.Assert().Equal([]*shard{{1}, {4}, {5}, {9}}, shards)
}

// TestShardRedistributing checks that this host's shard set grows when an
// instance drops out of the owner list and shrinks back when it returns.
// The steps run in order against the same manager.
func (suite *ShardManagerSuite) TestShardRedistributing() {
	// Restore the single-owner list even if Require aborts the test early
	// (FailNow still runs deferred calls), so extra instances never leak into
	// later tests.
	defer func() { suite.instances = []string{suite.hostname} }()

	steps := []struct {
		name      string
		instances []string
		expected  []*shard
	}{
		{
			name:      "three instances",
			instances: []string{suite.hostname, "instance-2", "instance-3"},
			expected:  []*shard{{1}, {4}, {5}, {9}},
		},
		{
			name:      "one instance is gone",
			instances: []string{suite.hostname, "instance-2"},
			expected:  []*shard{{1}, {2}, {4}, {5}, {7}, {9}},
		},
		{
			name:      "the gone instance returns",
			instances: []string{suite.hostname, "instance-2", "instance-3"},
			expected:  []*shard{{1}, {4}, {5}, {9}},
		},
	}

	for _, step := range steps {
		suite.instances = step.instances
		shards, err := suite.manager.getShards(context.Background())
		suite.Require().NoError(err, step.name)
		suite.Assert().Equal(step.expected, shards, step.name)
	}
}

func TestShardManager(t *testing.T) {
	suite.Run(t, new(ShardManagerSuite))
}
