package repos

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/suite"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo/readpref"

	"a.yandex-team.ru/infra/walle/server/go/internal/lib/db"
)

// HostNetworkTestSuite exercises HostNetworkRepo against the MongoDB instance
// returned by db.GetTestingMongoDB.
type HostNetworkTestSuite struct {
	suite.Suite
	// repo is the repository under test; it is built in SetupSuite.
	repo     *HostNetworkRepo
	// networks holds the documents the collection is expected to contain,
	// keyed by HostUUID. SetupSuite seeds it and TestGetOrCreate extends it.
	networks map[string]*HostNetwork
}

// SetupSuite connects to the testing MongoDB and builds the repository under
// test. It then seeds the collection with a single known HostNetwork document,
// which is also recorded in suite.networks.
//
// The collection is emptied before seeding. Documents left over from an earlier
// (possibly aborted) run would otherwise make the insert fail or skew the exact
// document count that TestGetOrCreate checks.
func (suite *HostNetworkTestSuite) SetupSuite() {
	mongodb, err := db.GetTestingMongoDB()
	suite.Require().NoError(err)
	suite.repo = NewHostNetworkRepo(mongodb, readpref.Primary())

	suite.networks = map[string]*HostNetwork{
		"1": {HostUUID: "1", ActiveMac: "11:11:11:11:11:11"},
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Start from a known-empty collection; the testing database may keep data
	// between runs.
	_, err = suite.repo.collection.DeleteMany(ctx, bson.D{})
	suite.Require().NoError(err)

	_, err = suite.repo.collection.InsertOne(ctx, suite.networks["1"])
	suite.Require().NoError(err)
}

// TestGetOrCreate checks three things:
//   - GetOrCreate returns the stored document for every host seeded in SetupSuite.
//   - It returns a fresh HostNetwork carrying only HostUUID for an unknown host.
//   - Afterwards the collection holds exactly the expected documents and nothing else.
func (suite *HostNetworkTestSuite) TestGetOrCreate() {
	type testcase struct {
		uuid     string
		expected *HostNetwork
	}

	cases := make([]testcase, 0, len(suite.networks)+1)
	for uuid, network := range suite.networks {
		cases = append(cases, testcase{uuid: uuid, expected: network})
	}
	cases = append(cases, testcase{uuid: "new", expected: &HostNetwork{HostUUID: "new"}})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	for _, tc := range cases {
		returned, err := suite.repo.GetOrCreate(ctx, tc.uuid)
		// Abort on error: comparing a nil result afterwards only adds noise.
		suite.Require().NoError(err, "GetOrCreate(%q)", tc.uuid)
		suite.Equal(tc.expected, returned, "GetOrCreate(%q)", tc.uuid)
		// Record the document so the collection-wide check below expects it.
		suite.networks[tc.uuid] = tc.expected
	}

	// The collection must contain exactly the expected documents.
	cursor, err := suite.repo.collection.Find(ctx, bson.D{})
	suite.Require().NoError(err)
	var found []*HostNetwork
	// Cursor.All drains and closes the cursor.
	suite.Require().NoError(cursor.All(ctx, &found))
	suite.Len(found, len(suite.networks))
	for _, item := range found {
		suite.Equal(suite.networks[item.HostUUID], item)
	}
}

func TestHostNetworkRepo(t *testing.T) {
	suite.Run(t, new(HostNetworkTestSuite))
}
