package db

import (
	"context"
	"errors"
	"fmt"
	"reflect"
	"time"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

// BulkWriter batches document changes and writes them in bulk.
//
// Run drives the writer and Shutdown stops it. Update, Upsert and Unset queue changes
// for the document identified by the element's ID.
type BulkWriter interface {
	// Run processes queued changes until Shutdown is called.
	Run()
	// Shutdown stops Run and waits for it to finish.
	Shutdown()

	// Update queues field assignments for an existing document.
	Update(*BulkWriterElem)
	// Upsert queues field assignments, creating the document if it does not exist.
	Upsert(*BulkWriterElem)
	// Unset queues removal of the named fields.
	Unset(*BulkWriterElem)
}

// errNoBulkWriterElemID is reported when an element without an ID is submitted.
var errNoBulkWriterElemID = errors.New("no ID in BulkWriterElem")

// errInvalidModel is returned by mergeModels when an update document does not have the
// expected bson.M-of-bson.M shape.
var errInvalidModel = errors.New("invalid model")

// BulkWriterElem describes one change to the document whose _id is ID.
type BulkWriterElem struct {
	// ID is used as the "_id" filter of the update; an empty ID is rejected.
	ID          string
	// Fields maps field names to values. Update and Upsert write them with "$set";
	// Unset removes the named fields and ignores the values. Empty keys are skipped.
	Fields      map[string]interface{}
	// SetOnInsert is used only by Upsert and is written with "$setOnInsert". It is either
	// a map[string]interface{} or a struct (or pointer to struct) whose bson-tagged
	// fields are used.
	SetOnInsert interface{}
}

// mongoBulkWriter is the MongoDB implementation of BulkWriter built by NewMongoBulkWriter.
// Update/Upsert/Unset hand elements to Run through prepare. In this file, raw and buffer
// are only accessed by Run and the helpers it calls (flush, moveRawToBuffer).
type mongoBulkWriter struct {
	// collection is the target of every BulkWrite.
	collection *mongo.Collection
	// opts holds the options checked by NewMongoBulkWriter.
	opts       *MongoBulkWriterOptions

	// prepare queues single-document models from Update/Upsert/Unset to Run (capacity 500).
	prepare chan *rawUpdateElem
	// NOTE(review): add is created in NewMongoBulkWriter but not used in the code shown here.
	add     chan []mongo.WriteModel
	// stopped is signalled by Run once it has finished shutting down; Shutdown waits on it.
	stopped chan struct{}

	// raw holds the pending, merged model per document ID.
	raw    map[string]*mongo.UpdateOneModel
	// buffer holds models moved out of raw that are waiting to be written.
	buffer []mongo.WriteModel

	// limiter spaces BulkWrite calls; it is created by Run.
	limiter *time.Ticker
	// ctx is cancelled by Shutdown (through cancel) to stop Run and the producers.
	ctx     context.Context
	cancel  context.CancelFunc
}

// MongoBulkWriterOptions configures NewMongoBulkWriter. Passing nil options selects
// RPSLimit 5, FlushInterval 2s and Size 1000.
type MongoBulkWriterOptions struct {
	// RPSLimit caps BulkWrite calls per second: Run spaces them with a ticker whose period
	// is roughly 1/RPSLimit seconds. Must be in (0, 50].
	RPSLimit       float64
	// FlushInterval is how often Run moves pending changes to the write buffer and
	// flushes it. Must be positive.
	FlushInterval  time.Duration
	// Size is both the number of distinct document IDs that may be pending before Run
	// flushes early and the maximum number of models per BulkWrite. Must be positive.
	Size           int
	// ResultHandler, if set, receives the result and error of every BulkWrite.
	ResultHandler  func(*mongo.BulkWriteResult, error)
	// ErrorCollector, if set, receives errors reported by the writer, e.g. an element
	// without an ID or a SetOnInsert value that cannot be converted.
	ErrorCollector func(err error)
}

// NewMongoBulkWriter returns a BulkWriter that batches updates for collection c.
//
// A nil opts selects the defaults RPSLimit 5, FlushInterval 2s and Size 1000. Otherwise
// *opts is copied, so later changes to it do not affect (or race with) the writer.
// RPSLimit must lie in (0, 50] (NaN is rejected); FlushInterval and Size must be positive.
//
// The returned writer does nothing until Run is started; Shutdown stops it.
func NewMongoBulkWriter(c *mongo.Collection, opts *MongoBulkWriterOptions) (BulkWriter, error) {
	cfg := MongoBulkWriterOptions{
		RPSLimit:      5,
		FlushInterval: 2 * time.Second,
		Size:          1000,
	}
	if opts != nil {
		cfg = *opts
	}
	// Phrased as "inside the range" rather than "outside it" so that NaN, which fails every
	// comparison, is rejected instead of producing an invalid ticker interval in Run.
	if !(cfg.RPSLimit > 0 && cfg.RPSLimit <= 50) {
		return nil, errors.New("RPSLimit: should be in (0, 50]")
	}
	if cfg.FlushInterval <= 0 {
		return nil, errors.New("invalid flush interval")
	}
	if cfg.Size <= 0 {
		return nil, errors.New("invalid buffer size")
	}
	ctx, cancel := context.WithCancel(context.Background())
	return &mongoBulkWriter{
		collection: c,
		opts:       &cfg,
		prepare:    make(chan *rawUpdateElem, 500),
		add:        make(chan []mongo.WriteModel),
		stopped:    make(chan struct{}),
		raw:        make(map[string]*mongo.UpdateOneModel, cfg.Size),
		buffer:     make([]mongo.WriteModel, 0, cfg.Size),
		ctx:        ctx,
		cancel:     cancel,
	}, nil
}

// Update queues a "$set" of elem.Fields on the document whose _id is elem.ID.
//
// Entries with an empty key are ignored; if nothing remains, no update is queued. A nil
// elem or one without an ID is reported through ErrorCollector and dropped. The call
// blocks while the queue to Run is full; once Shutdown has been called the element may
// be dropped instead of queued.
func (w *mongoBulkWriter) Update(elem *BulkWriterElem) {
	if elem == nil || elem.ID == "" {
		w.handleError(errNoBulkWriterElemID)
		return
	}
	set := bson.M{}
	for key, value := range elem.Fields {
		if key != "" {
			set[key] = value
		}
	}
	if len(set) == 0 {
		return
	}
	model := mongo.NewUpdateOneModel().
		SetFilter(bson.M{"_id": elem.ID}).
		SetUpdate(bson.M{"$set": set})
	select {
	case w.prepare <- &rawUpdateElem{id: elem.ID, model: model}:
	case <-w.ctx.Done():
	}
}

// Upsert queues an upsert of the document whose _id is elem.ID.
//
// Non-empty keys of elem.Fields are written with "$set". elem.SetOnInsert, when non-nil,
// is written with "$setOnInsert". It may be a bson.M, a map[string]interface{}, or a
// struct (or pointer to struct) converted by getBsonMFromStruct. Empty keys are ignored
// in both. The caller's maps are copied, so merging in Run never mutates them.
//
// MongoDB rejects an update that targets the same path with both "$set" and
// "$setOnInsert", so such paths are removed from "$setOnInsert" ("$set" wins).
//
// Errors (nil elem, missing ID, unsupported SetOnInsert) are reported through
// ErrorCollector and nothing is queued. An element yielding no operators is ignored, as
// in Update and Unset: the driver rejects an empty update document, which would fail the
// whole batch it lands in.
func (w *mongoBulkWriter) Upsert(elem *BulkWriterElem) {
	if elem == nil || elem.ID == "" {
		w.handleError(errNoBulkWriterElemID)
		return
	}

	var insertSrc map[string]interface{}
	switch setOnInsert := elem.SetOnInsert.(type) {
	case nil:
		// No insert-only fields.
	case bson.M:
		// bson.M is a distinct named type, so it needs its own case.
		insertSrc = setOnInsert
	case map[string]interface{}:
		insertSrc = setOnInsert
	default:
		fromStruct, err := getBsonMFromStruct(setOnInsert)
		if err != nil {
			w.handleError(err)
			return
		}
		insertSrc = fromStruct
	}

	update := bson.M{}
	set := bson.M{}
	for key, value := range elem.Fields {
		if key != "" {
			set[key] = value
		}
	}
	if len(set) > 0 {
		update["$set"] = set
	}
	// Copy rather than alias: the operator documents are mutated later by
	// removeEqualFieldsFromUpdate and mergeModels on the Run goroutine.
	insert := bson.M{}
	for key, value := range insertSrc {
		if key != "" {
			insert[key] = value
		}
	}
	if len(insert) > 0 {
		update["$setOnInsert"] = insert
	}

	update, err := removeEqualFieldsFromUpdate(update, "$set", "$setOnInsert")
	if err != nil {
		w.handleError(err)
		return
	}
	if len(update) == 0 {
		return
	}

	model := mongo.NewUpdateOneModel().
		SetFilter(bson.M{"_id": elem.ID}).
		SetUpdate(update).
		SetUpsert(true)
	select {
	case w.prepare <- &rawUpdateElem{id: elem.ID, model: model}:
	case <-w.ctx.Done():
	}
}

// Unset queues an "$unset" of the field names in elem.Fields on the document whose _id
// is elem.ID. The map values are ignored.
//
// Empty keys are skipped; if no field names remain, nothing is queued. A nil elem or one
// without an ID is reported through ErrorCollector and dropped. The call blocks while
// the queue to Run is full; once Shutdown has been called the element may be dropped
// instead of queued.
func (w *mongoBulkWriter) Unset(elem *BulkWriterElem) {
	if elem == nil || elem.ID == "" {
		w.handleError(errNoBulkWriterElemID)
		return
	}
	unset := bson.M{}
	for key := range elem.Fields {
		if key != "" {
			unset[key] = ""
		}
	}
	if len(unset) == 0 {
		return
	}
	model := mongo.NewUpdateOneModel().
		SetFilter(bson.M{"_id": elem.ID}).
		SetUpdate(bson.M{"$unset": unset})
	select {
	case w.prepare <- &rawUpdateElem{id: elem.ID, model: model}:
	case <-w.ctx.Done():
	}
}

// Run is the writer's event loop. It must be running (typically in its own goroutine)
// for queued elements to be consumed, and it returns only after Shutdown has been called.
//
// Elements from Update/Upsert/Unset are merged per document ID into w.raw. The pending
// set is moved to the write buffer and flushed every FlushInterval, or as soon as Size
// distinct IDs are pending. BulkWrite calls are spaced by a ticker of period
// 1s/RPSLimit.
//
// On Shutdown, the elements already queued are merged and the buffer is flushed batch by
// batch for up to 10 seconds. Whatever remains after that is dropped. Run then closes
// w.stopped.
func (w *mongoBulkWriter) Run() {
	// Full float precision: the old whole-millisecond truncation skewed non-divisor rates.
	w.limiter = time.NewTicker(time.Duration(float64(time.Second) / w.opts.RPSLimit))
	updater := time.NewTicker(w.opts.FlushInterval)
	for {
		select {
		case <-updater.C:
			w.moveRawToBuffer()
			w.flush()
		case elem := <-w.prepare:
			w.mergePrepared(elem)
			if len(w.raw) >= w.opts.Size {
				w.moveRawToBuffer()
				w.flush()
			}
		case <-w.ctx.Done():
			w.drainOnShutdown()
			w.limiter.Stop()
			updater.Stop()
			// Close rather than send: every Shutdown caller is released, so calling
			// Shutdown more than once no longer blocks forever.
			close(w.stopped)
			return
		}
	}
}

// mergePrepared folds elem into the pending model for elem.id. A failed merge is
// reported through ErrorCollector instead of being silently discarded.
func (w *mongoBulkWriter) mergePrepared(elem *rawUpdateElem) {
	merged, err := mergeModels(w.raw[elem.id], elem.model)
	if err != nil {
		w.handleError(err)
	}
	if merged == nil {
		// Nothing usable is left for this ID; keep an empty slot from counting toward Size.
		delete(w.raw, elem.id)
		return
	}
	w.raw[elem.id] = merged
}

// drainOnShutdown merges the elements queued when shutdown began and flushes the
// buffer batch by batch, giving up once 10 seconds have elapsed.
func (w *mongoBulkWriter) drainOnShutdown() {
	// Only drain what is queued right now, so a producer that keeps sending cannot stall
	// shutdown indefinitely.
	for n := len(w.prepare); n > 0; n-- {
		w.mergePrepared(<-w.prepare)
	}
	w.moveRawToBuffer()
	deadline := time.Now().Add(10 * time.Second)
	for len(w.buffer) > 0 {
		w.flush()
		if !time.Now().Before(deadline) {
			break
		}
	}
}

// Shutdown cancels the writer's context and then blocks until Run signals w.stopped,
// i.e. until Run has merged and flushed what it could. Run is what signals w.stopped,
// so Shutdown does not return unless Run is running.
func (w *mongoBulkWriter) Shutdown() {
	stop, done := w.cancel, w.stopped
	stop()
	<-done
}

// flush writes at most opts.Size models from the front of w.buffer with one unordered
// BulkWrite, waiting for the rate limiter first. Each call has its own 10-second timeout.
//
// The result goes to ResultHandler. When no ResultHandler is configured, a BulkWrite
// error is reported through ErrorCollector instead of being dropped.
func (w *mongoBulkWriter) flush() {
	// Check before waiting: an idle flush should not stall the Run loop (and thus
	// producers) for up to a limiter period, nor use up a limiter tick.
	if len(w.buffer) == 0 {
		return
	}
	<-w.limiter.C

	bulk := w.buffer
	if len(bulk) > w.opts.Size {
		bulk = w.buffer[:w.opts.Size]
		w.buffer = w.buffer[w.opts.Size:]
	} else {
		w.buffer = make([]mongo.WriteModel, 0, w.opts.Size)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	res, err := w.collection.BulkWrite(ctx, bulk, options.BulkWrite().SetOrdered(false))

	switch {
	case w.opts.ResultHandler != nil:
		w.opts.ResultHandler(res, err)
	case err != nil:
		w.handleError(err)
	}
}

// moveRawToBuffer appends every non-nil pending model to w.buffer (in map iteration
// order, which is unspecified) and replaces w.raw with a fresh, empty map.
func (w *mongoBulkWriter) moveRawToBuffer() {
	pending := w.raw
	w.raw = make(map[string]*mongo.UpdateOneModel, w.opts.Size)
	for _, model := range pending {
		if model != nil {
			w.buffer = append(w.buffer, model)
		}
	}
}

// rawUpdateElem carries one queued update model from Update/Upsert/Unset to Run.
type rawUpdateElem struct {
	// id is the document _id; Run uses it as the key when merging pending models.
	id    string
	// model is the single-document update to merge into the pending set.
	model *mongo.UpdateOneModel
}

// mergeModels folds the update model new into old, both targeting the same document, and
// returns the combined model. old is modified in place.
//
// Both updates must be bson.M documents whose operator values ("$set", "$unset", ...) are
// themselves bson.M. Every field written by new overwrites the same field under the same
// operator in old and is removed from old's other operators, so the most recent operation
// on a field wins. Operators left empty by that removal are dropped. If new is an upsert,
// the result is an upsert.
//
// Returns:
//   - new, nil if old is nil; old, nil if new is nil.
//   - nil, errInvalidModel if old's update has an unexpected shape.
//   - old (unmodified), errInvalidModel if new's update has an unexpected shape. Both
//     shapes are validated before anything is mutated, so a bad new model never leaves
//     old half-merged.
func mergeModels(old *mongo.UpdateOneModel, new *mongo.UpdateOneModel) (*mongo.UpdateOneModel, error) {
	if old == nil {
		return new, nil
	}
	if new == nil {
		return old, nil
	}
	updateOld, ok := old.Update.(bson.M)
	if !ok || !allOperatorDocs(updateOld) {
		return nil, errInvalidModel
	}
	updateNew, ok := new.Update.(bson.M)
	if !ok || !allOperatorDocs(updateNew) {
		return old, errInvalidModel
	}

	for operator, val := range updateNew {
		newDoc, _ := val.(bson.M)
		if len(newDoc) == 0 {
			// Don't give old an empty operator document; MongoDB servers before 5.0
			// reject operators with an empty operand.
			continue
		}
		oldDoc, _ := updateOld[operator].(bson.M)
		if oldDoc == nil {
			// Also covers a typed-nil bson.M, which would panic on write.
			oldDoc = bson.M{}
			updateOld[operator] = oldDoc
		}
		for field, value := range newDoc {
			oldDoc[field] = value
			// The latest write to a field wins: drop it from every other operator, and
			// drop operators left with nothing to do.
			for otherOp, otherVal := range updateOld {
				if otherOp == operator {
					continue
				}
				otherDoc, _ := otherVal.(bson.M)
				delete(otherDoc, field)
				if len(otherDoc) == 0 {
					delete(updateOld, otherOp)
				}
			}
		}
	}
	if new.Upsert != nil && *new.Upsert {
		old.SetUpsert(true)
	}
	return old, nil
}

// allOperatorDocs reports whether every operator in update maps to a bson.M document.
func allOperatorDocs(update bson.M) bool {
	for _, doc := range update {
		if _, ok := doc.(bson.M); !ok {
			return false
		}
	}
	return true
}

// getBsonMFromStruct converts a struct, or a non-nil pointer to one, into a bson.M of its
// exported fields that carry a bson tag, keyed by the tag's name part.
//
// Tag options after the name do not become part of the key: `bson:"name,omitempty"`
// yields key "name". With omitempty the field is skipped when its value is the zero value
// according to reflect.Value.IsZero, which may differ from the driver's own omitempty
// rules for some struct and array types.
//
// The following fields are skipped:
//   - fields with no bson tag;
//   - fields tagged "-";
//   - fields whose tag has an empty name part (e.g. `bson:",inline"`; inlining is not
//     supported here);
//   - unexported fields, whose values reflection cannot hand out.
//
// Any other input yields an error.
func getBsonMFromStruct(v interface{}) (bson.M, error) {
	value := reflect.ValueOf(v)
	if value.Kind() == reflect.Ptr {
		value = value.Elem()
	}
	if value.Kind() != reflect.Struct {
		return nil, fmt.Errorf("%v not struct or pointer to struct ", v)
	}
	typ := value.Type()
	res := bson.M{}
	for i := 0; i < typ.NumField(); i++ {
		field := typ.Field(i)
		if field.PkgPath != "" {
			// Unexported: value.Field(i).Interface() would panic.
			continue
		}
		name, omitEmpty := parseBSONTag(field.Tag.Get("bson"))
		if name == "" || name == "-" {
			continue
		}
		fieldValue := value.Field(i)
		if omitEmpty && fieldValue.IsZero() {
			continue
		}
		res[name] = fieldValue.Interface()
	}
	return res, nil
}

// parseBSONTag splits a bson struct tag value such as "name,omitempty" into its name
// part and whether the omitempty option is present. The split is done by hand so the
// file needs no extra import for this one-off.
func parseBSONTag(tag string) (name string, omitEmpty bool) {
	start := 0
	for i := 0; i <= len(tag); i++ {
		if i < len(tag) && tag[i] != ',' {
			continue
		}
		part := tag[start:i]
		if start == 0 {
			name = part
		} else if part == "omitempty" {
			omitEmpty = true
		}
		start = i + 1
	}
	return name, omitEmpty
}

// handleError forwards err to the configured ErrorCollector; without one, the error is
// dropped.
func (w *mongoBulkWriter) handleError(err error) {
	collect := w.opts.ErrorCollector
	if collect == nil {
		return
	}
	collect(err)
}

// removeEqualFieldsFromUpdate deletes from the op2 operator document every field that
// also appears in the op1 operator document, so op1 wins. If the op2 document ends up
// empty it is removed from update. update is modified in place and returned.
//
// Presence is tested by key, not by value, so a field mapped to nil in op2 is removed
// too. Leaving it in would make MongoDB reject the update with a path conflict.
//
// If either operator is absent, update is returned unchanged. If either operator's value
// is not a bson.M, an error is returned.
func removeEqualFieldsFromUpdate(update bson.M, op1, op2 string) (bson.M, error) {
	if update[op1] == nil || update[op2] == nil {
		return update, nil
	}
	m1, ok1 := update[op1].(bson.M)
	m2, ok2 := update[op2].(bson.M)
	if !ok1 || !ok2 {
		return nil, errors.New("invalid operation document")
	}
	for k := range m1 {
		delete(m2, k)
	}
	if len(m2) == 0 {
		delete(update, op2)
	}
	return update, nil
}
