package db

import (
	"context"
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/suite"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

// BulkWriterTestSuite exercises MongoBulkWriter against the collection
// obtained from GetTestingMongoDB in SetupSuite.
type BulkWriterTestSuite struct {
	suite.Suite

	// collection is the "test" collection every case writes into; it is
	// emptied after each test by TearDownTest.
	collection *mongo.Collection

	// opts is the default writer configuration built in SetupSuite and
	// shared by the cases that do not need custom options.
	opts *MongoBulkWriterOptions
}

// SetupSuite connects to the testing database once for the whole suite,
// selects the "test" collection and prepares the default writer options.
// Both writer callbacks treat any reported error as a test failure.
func (suite *BulkWriterTestSuite) SetupSuite() {
	mongoDB, err := GetTestingMongoDB()
	suite.Require().NoError(err)
	suite.collection = mongoDB.Collection("test")

	failOnError := func(err error) {
		suite.Assert().NoError(err)
	}
	suite.opts = &MongoBulkWriterOptions{
		RPSLimit:      50,
		FlushInterval: 10 * time.Millisecond,
		Size:          1000,
		ResultHandler: func(_ *mongo.BulkWriteResult, err error) {
			failOnError(err)
		},
		ErrorCollector: failOnError,
	}
}

// TearDownTest deletes every document in the collection so each test starts
// from an empty collection.
func (suite *BulkWriterTestSuite) TearDownTest() {
	ctx := context.Background()
	if _, err := suite.collection.DeleteMany(ctx, bson.D{}); err != nil {
		suite.Require().NoError(err)
	}
}

// testDocument is the stored shape of a document in the test collection. The
// tests use it to seed fixtures, as a SetOnInsert payload and to decode query
// results for comparison. Fields that are absent in the stored document
// decode as empty strings.
type testDocument struct {
	// ID is the document key and matches BulkWriterElem.ID.
	ID string `bson:"_id"`

	// Field0..Field2 are the payload fields the tests write and check.
	Field0 string `bson:"field-0"`
	Field1 string `bson:"field-1"`
	Field2 string `bson:"field-2"`
}

// TestUpsertAndUpdate upserts three new elements and then applies partial
// updates to two of them. The updated fields must be merged into the existing
// documents: fields that are not mentioned keep their upserted values, and
// fields that are mentioned are overwritten or added.
func (suite *BulkWriterTestSuite) TestUpsertAndUpdate() {
	writer, err := NewMongoBulkWriter(suite.collection, suite.opts)
	suite.Require().NoError(err)
	go writer.Run()

	// First pass: create all three documents.
	upserts := []*BulkWriterElem{
		{ID: "elem-0", Fields: map[string]interface{}{"field-0": "value-00"}},
		{ID: "elem-1", Fields: map[string]interface{}{"field-0": "value-01", "field-1": "old-value-11"}},
		{ID: "elem-2", Fields: map[string]interface{}{"field-0": "value-02"}},
	}
	for i := range upserts {
		writer.Upsert(upserts[i])
	}

	// Second pass: partial updates that overwrite or add individual fields.
	updates := []*BulkWriterElem{
		{ID: "elem-1", Fields: map[string]interface{}{"field-1": "value-11", "field-2": "value-21"}},
		{ID: "elem-2", Fields: map[string]interface{}{"field-1": "value-12"}},
	}
	for i := range updates {
		writer.Update(updates[i])
	}

	want := []testDocument{
		{ID: "elem-0", Field0: "value-00"},
		{ID: "elem-1", Field0: "value-01", Field1: "value-11", Field2: "value-21"},
		{ID: "elem-2", Field0: "value-02", Field1: "value-12"},
	}
	suite.checkResults(writer, want)
}

// TestSetOnInsert checks that SetOnInsert values only take effect when the
// upsert creates the document:
//   - elem-0 already exists, so its field-1 is left untouched;
//   - elem-1 and elem-3 are new and receive their SetOnInsert fields (elem-3
//     uses a struct payload instead of a map);
//   - elem-2 is upserted twice, and the expected document carries the
//     SetOnInsert value from the second call. NOTE(review): this depends on
//     how the writer combines repeated upserts for one ID; confirm against
//     the writer implementation.
//
// The upserts are issued in slice order.
func (suite *BulkWriterTestSuite) TestSetOnInsert() {
	ctx := context.Background()
	_, err := suite.collection.InsertOne(ctx, &testDocument{ID: "elem-0", Field0: "old-value-00"})
	suite.Require().NoError(err)

	writer, err := NewMongoBulkWriter(suite.collection, suite.opts)
	suite.Require().NoError(err)
	go writer.Run()

	upserts := []*BulkWriterElem{
		{
			ID:          "elem-0",
			Fields:      map[string]interface{}{"field-0": "value-00"},
			SetOnInsert: map[string]interface{}{"field-1": "value-10"},
		},
		{
			ID:          "elem-1",
			Fields:      map[string]interface{}{"field-0": "value-01"},
			SetOnInsert: map[string]interface{}{"field-1": "value-11"},
		},
		{
			ID:     "elem-2",
			Fields: map[string]interface{}{"field-0": "old-value-02"},
		},
		{
			ID:          "elem-2",
			SetOnInsert: map[string]interface{}{"field-0": "value-02"},
		},
		{
			ID:     "elem-3",
			Fields: map[string]interface{}{"field-0": "value-03"},
			SetOnInsert: &testDocument{
				ID:     "elem-3",
				Field1: "value-13",
				Field2: "value-23",
			},
		},
	}
	for _, elem := range upserts {
		writer.Upsert(elem)
	}

	want := []testDocument{
		{ID: "elem-0", Field0: "value-00"},
		{ID: "elem-1", Field0: "value-01", Field1: "value-11"},
		{ID: "elem-2", Field0: "value-02"},
		{ID: "elem-3", Field0: "value-03", Field1: "value-13", Field2: "value-23"},
	}
	suite.checkResults(writer, want)
}

// TestUnset seeds one document, unsets its field-0 through the writer and
// checks that field-0 is gone while field-1 is preserved. The keys of Fields
// name the fields to remove.
func (suite *BulkWriterTestSuite) TestUnset() {
	seed := &testDocument{ID: "elem-0", Field0: "value-00", Field1: "value-10"}
	_, err := suite.collection.InsertOne(context.Background(), seed)
	suite.Require().NoError(err)

	writer, err := NewMongoBulkWriter(suite.collection, suite.opts)
	suite.Require().NoError(err)
	go writer.Run()

	writer.Unset(&BulkWriterElem{
		ID:     "elem-0",
		Fields: map[string]interface{}{"field-0": ""},
	})

	want := []testDocument{{ID: "elem-0", Field1: "value-10"}}
	suite.checkResults(writer, want)
}

// checkResults shuts the writer down and then asserts that the collection,
// sorted by _id ascending, holds exactly the expected documents. Callers
// enqueue operations right before calling this, so they rely on Shutdown
// writing out anything still pending. NOTE(review): confirm Shutdown flushes
// before returning.
func (suite *BulkWriterTestSuite) checkResults(writer BulkWriter, expected []testDocument) {
	writer.Shutdown()

	ctx := context.Background()
	byID := options.Find().SetSort(bson.M{"_id": 1})
	cur, err := suite.collection.Find(ctx, bson.D{}, byID)
	suite.Require().NoError(err)
	// Best-effort cleanup; a close error does not affect the comparison.
	defer func() { _ = cur.Close(ctx) }()

	var actual []testDocument
	err = cur.All(ctx, &actual)
	suite.Require().NoError(err)
	suite.Require().Equal(expected, actual)
}

// TestPeriodicFlush verifies that FlushInterval alone drives writes. Each
// batch of bulkSize upserts is far below Size (1000), so the test moves on to
// the next batch only after the interval-triggered flushes have reported
// bulkSize upserted-or-modified documents for the current batch.
//
// If a flush fails or never arrives, the wait is bounded by flushTimeout so
// the test fails instead of hanging.
func (suite *BulkWriterTestSuite) TestPeriodicFlush() {
	const flushTimeout = 5 * time.Second
	bulkSize := 10
	bulkNumber := 15
	// Buffered so the handler (presumably invoked from the writer's Run
	// goroutine) can signal without waiting for the receiver.
	flushed := make(chan struct{}, 1)
	// cnt is only read and written inside ResultHandler.
	var cnt int64
	opts := &MongoBulkWriterOptions{
		RPSLimit:      50,
		FlushInterval: 10 * time.Millisecond,
		Size:          1000,
		ResultHandler: func(result *mongo.BulkWriteResult, err error) {
			suite.Assert().NoError(err)
			// A failed write may be reported without a result; dereferencing
			// nil here would panic in the writer goroutine and kill the whole
			// test binary instead of failing this test.
			if result == nil {
				return
			}
			cnt += result.UpsertedCount + result.ModifiedCount
			if cnt == int64(bulkSize) {
				cnt = 0
				flushed <- struct{}{}
			}
		},
	}
	writer, err := NewMongoBulkWriter(suite.collection, opts)
	suite.Require().NoError(err)
	go writer.Run()
	for i := 0; i < bulkNumber; i++ {
		for j := 0; j < bulkSize; j++ {
			writer.Upsert(&BulkWriterElem{ID: fmt.Sprintf("elem-%d", j), Fields: map[string]interface{}{"k": i*10 + j}})
		}
		select {
		case <-flushed:
		case <-time.After(flushTimeout):
			suite.Require().FailNow("timed out waiting for periodic flush", "batch %d", i)
		}
	}
	writer.Shutdown()
}

// TestOverSizeFlush verifies that reaching Size triggers a flush. FlushInterval
// is set to 24h so the timer never fires during the test. Every reported flush
// must contain exactly writerSize upserted-or-modified documents, and the test
// ends once all bulkSize*bulkNumber upserts have been accounted for.
//
// The final wait is bounded by flushTimeout so a lost or failed flush fails
// the test instead of hanging it.
func (suite *BulkWriterTestSuite) TestOverSizeFlush() {
	const flushTimeout = 5 * time.Second
	writerSize := 5
	bulkSize := 10
	bulkNumber := 15
	// Buffered so the final handler call never blocks the writer, even if the
	// test goroutine has already given up waiting.
	last := make(chan struct{}, 1)
	// cnt is only read and written inside ResultHandler.
	var cnt int64
	opts := &MongoBulkWriterOptions{
		RPSLimit:      50,
		FlushInterval: 24 * time.Hour,
		Size:          writerSize,
		ResultHandler: func(result *mongo.BulkWriteResult, err error) {
			suite.Assert().NoError(err)
			// A failed write may be reported without a result; avoid a nil
			// dereference panic inside the writer goroutine.
			if result == nil {
				return
			}
			flushSize := result.UpsertedCount + result.ModifiedCount
			suite.Assert().Equal(int64(writerSize), flushSize)
			cnt += flushSize
			if cnt == int64(bulkSize*bulkNumber) {
				last <- struct{}{}
			}
		},
	}
	writer, err := NewMongoBulkWriter(suite.collection, opts)
	suite.Require().NoError(err)
	go writer.Run()
	for i := 0; i < bulkNumber; i++ {
		for j := 0; j < bulkSize; j++ {
			writer.Upsert(&BulkWriterElem{ID: fmt.Sprintf("elem-%d", j), Fields: map[string]interface{}{"k": i*10 + j}})
		}
	}
	select {
	case <-last:
	case <-time.After(flushTimeout):
		suite.Require().FailNow("timed out waiting for size-triggered flushes", "expected %d documents", bulkSize*bulkNumber)
	}
	writer.Shutdown()
}

// TestRateLimit enqueues bulkNumber*bulkSize upserts with unique IDs and a
// writer Size of 10. It then checks two things:
//   - at least one flush happened per opts.Size elements;
//   - the run (from before Run starts until Shutdown returns) lasted at least
//     as long as issuing that many requests at RPSLimit requests per second.
func (suite *BulkWriterTestSuite) TestRateLimit() {
	const (
		bulkSize   = 10
		bulkNumber = 15
	)
	flushes := 0
	opts := &MongoBulkWriterOptions{
		RPSLimit:      50,
		FlushInterval: 500 * time.Millisecond,
		Size:          10,
		ResultHandler: func(_ *mongo.BulkWriteResult, err error) {
			suite.Assert().NoError(err)
			flushes++
		},
	}
	writer, err := NewMongoBulkWriter(suite.collection, opts)
	suite.Require().NoError(err)

	start := time.Now()
	go writer.Run()
	for batch := 0; batch < bulkNumber; batch++ {
		for item := 0; item < bulkSize; item++ {
			writer.Upsert(&BulkWriterElem{
				ID:     fmt.Sprintf("elem-%d-%d", batch, item),
				Fields: map[string]interface{}{"k": batch*bulkNumber + item},
			})
		}
	}
	writer.Shutdown()
	elapsed := time.Since(start)

	// NOTE(review): flushes is written by the handler and read here; this
	// assumes Shutdown waits for the writer goroutine to finish.
	minFlushes := bulkNumber * bulkSize / opts.Size
	suite.Require().GreaterOrEqual(flushes, minFlushes)

	minElapsed := time.Duration(float64(minFlushes)*1000/opts.RPSLimit) * time.Millisecond
	suite.Require().GreaterOrEqual(elapsed.Microseconds(), minElapsed.Microseconds())
}

func TestBulkWriter(t *testing.T) {
	suite.Run(t, new(BulkWriterTestSuite))
}
