package repos

import (
	"context"
	"errors"
	"math"
	"time"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/mongo/readpref"

	"a.yandex-team.ru/infra/walle/server/go/internal/lib/db"
	"a.yandex-team.ru/infra/walle/server/go/internal/lib/juggler"
	"a.yandex-team.ru/infra/walle/server/go/internal/lib/monitoring"
)

// healthCollectionName is the MongoDB collection that holds host health checks.
const healthCollectionName = "health_checks"

// HealthRepo reads and writes juggler host checks stored in the
// health_checks collection. Create it with NewHealthRepo.
type HealthRepo struct {
	// collection is the health checks collection, configured with the read
	// preference given to NewHealthRepo.
	collection *mongo.Collection
}

// NewHealthRepo returns a HealthRepo bound to the health checks collection of
// database. Reads from that collection use the read preference pref.
//
// The parameter is named database rather than db so that it does not shadow
// the imported db package inside this function.
func NewHealthRepo(database *mongo.Database, pref *readpref.ReadPref) *HealthRepo {
	collOpts := options.Collection().SetReadPreference(pref)
	return &HealthRepo{
		collection: database.Collection(healthCollectionName, collOpts),
	}
}

func (repo *HealthRepo) BulkUpsert(ctx context.Context, checks []*juggler.HostCheck) (*mongo.BulkWriteResult, error) {
	models := make([]mongo.WriteModel, len(checks))
	for i, check := range checks {
		model := mongo.NewUpdateOneModel().
			SetFilter(bson.D{
				{Key: "_id", Value: check.ID},
				{Key: "timestamp", Value: bson.D{{Key: "$lt", Value: check.Timestamp}}},
			}).
			SetUpdate(bson.D{{Key: "$set", Value: check}}).
			SetUpsert(true)
		models[i] = model
	}
	opts := options.BulkWrite().SetOrdered(false)
	result, err := repo.collection.BulkWrite(ctx, models, opts)
	if err != nil && !mongo.IsDuplicateKeyError(err) {
		return nil, err
	}
	return result, nil
}

// FindIDMap returns the checks matching filter, keyed by check ID. A nil
// filter matches every check.
func (repo *HealthRepo) FindIDMap(
	ctx context.Context,
	filter *HostCheckFilter,
) (map[juggler.WalleCheckKey]*juggler.HostCheck, error) {
	checks := make(map[juggler.WalleCheckKey]*juggler.HostCheck)
	err := repo.forEachCheck(ctx, filter, func(check *juggler.HostCheck) {
		checks[check.ID] = check
	})
	if err != nil {
		return nil, err
	}
	return checks, nil
}

// Find returns the checks matching filter in the order the cursor yields
// them. A nil filter matches every check. The result is nil when nothing
// matches.
func (repo *HealthRepo) Find(
	ctx context.Context,
	filter *HostCheckFilter,
) ([]*juggler.HostCheck, error) {
	var checks []*juggler.HostCheck
	err := repo.forEachCheck(ctx, filter, func(check *juggler.HostCheck) {
		checks = append(checks, check)
	})
	if err != nil {
		return nil, err
	}
	return checks, nil
}

// forEachCheck queries the collection with filter (nil matches everything).
// It decodes each document into a fresh juggler.HostCheck and passes it to fn.
// It stops at the first decode or cursor error. The cursor is always closed.
func (repo *HealthRepo) forEachCheck(
	ctx context.Context,
	filter *HostCheckFilter,
	fn func(*juggler.HostCheck),
) error {
	query := bson.D{}
	if filter != nil {
		query = filter.getBSON()
	}
	cursor, err := repo.collection.Find(ctx, query, options.Find())
	if err != nil {
		return err
	}
	defer func() {
		// A close failure is not actionable here. Iteration failures are
		// reported through cursor.Err below.
		_ = cursor.Close(ctx)
	}()
	for cursor.Next(ctx) {
		check := &juggler.HostCheck{}
		if err := cursor.Decode(check); err != nil {
			return err
		}
		fn(check)
	}
	return cursor.Err()
}

func (repo *HealthRepo) Select(ctx context.Context, filter *HostCheckFilter, keys []string) (*db.MongoSelection, error) {
	filters := bson.D{}
	if filter != nil {
		filters = filter.getBSON()
	}
	return db.NewMongoSelection(ctx, repo.collection, filters, options.Find(), keys)
}

// AggregateMetrics runs one aggregation over the checks whose type is in
// juggler.AllCheckTypes(). It computes three facets:
//   - "statuses": number of checks per (type, status);
//   - "ages": number of checks per (type, age bucket), with age derived from
//     the "timestamp" field;
//   - "agesEffective": the same as "ages", but derived from "effective_timestamp".
//
// ageBuckets is passed to healthCheckAgeAggregationPipeline for both age
// facets. An error is returned if the aggregation yields no document.
func (repo *HealthRepo) AggregateMetrics(ctx context.Context, ageBuckets []float64) (*HealthMetricAggregation, error) {
	matchKnownTypes := bson.D{{
		Key:   "$match",
		Value: bson.M{"type": bson.M{"$in": juggler.AllCheckTypes()}},
	}}

	groupByTypeAndStatus := bson.M{
		"$group": bson.M{
			"_id": bson.M{
				"type":   "$type",
				"status": "$status",
			},
			"total": bson.M{"$sum": 1},
		},
	}
	flattenGroupKey := bson.M{
		"$project": bson.M{
			"_id":    0,
			"type":   "$_id.type",
			"status": "$_id.status",
			"total":  1,
		},
	}

	facets := bson.M{
		"statuses":      bson.A{groupByTypeAndStatus, flattenGroupKey},
		"ages":          healthCheckAgeAggregationPipeline("timestamp", ageBuckets),
		"agesEffective": healthCheckAgeAggregationPipeline("effective_timestamp", ageBuckets),
	}
	pipeline := mongo.Pipeline{
		matchKnownTypes,
		bson.D{{Key: "$facet", Value: facets}},
	}

	cursor, err := repo.collection.Aggregate(ctx, pipeline, options.Aggregate())
	if err != nil {
		return nil, err
	}
	var docs []*HealthMetricAggregation
	// cursor.All reads every document and closes the cursor.
	if err := cursor.All(ctx, &docs); err != nil {
		return nil, err
	}
	if len(docs) == 0 {
		return nil, errors.New("no data to aggregate")
	}
	return docs[0], nil
}

// HealthBulkWriterOptions configures the writer returned by
// HealthRepo.NewBulkWriter.
//
// Size is the maximum number of checks sent in one bulk upsert. While new
// checks are being added, a full buffer is flushed before more are accepted.
// Must be > 0.
//
// FlushInterval is how often the buffer is flushed regardless of how full it
// is. Must be > 0.
//
// RPSLimit caps the number of flushes (bulk writes) per second: every flush
// waits for a tick of a ticker running at this rate. Must be in (0, 50].
//
// Handler, if non-nil, is called after every flush with the bulk write
// result and error. Both may be nil, for example when the buffer was empty.
// The result is nil when the write failed.
type HealthBulkWriterOptions struct {
	Size          int
	FlushInterval time.Duration
	RPSLimit      float64
	Handler       func(*mongo.BulkWriteResult, error)
}

// HealthBulkWriter buffers host checks and upserts them in rate-limited
// batches.
type HealthBulkWriter interface {
	// Add hands a batch of checks to the writer. It may block until the
	// writer accepts them.
	Add([]*juggler.HostCheck)
	// Run is the writer's processing loop. It blocks until Shutdown is
	// called.
	Run()
	// Shutdown stops the writer, lets it flush what is still buffered, and
	// waits for it to finish.
	Shutdown()
}

// NewBulkWriter validates opts and returns a HealthBulkWriter that upserts
// through repo. Added checks are only processed while the writer's Run
// method is running.
//
// On error the returned writer is a true nil interface value, so a
// writer == nil check behaves as expected.
func (repo *HealthRepo) NewBulkWriter(opts *HealthBulkWriterOptions) (HealthBulkWriter, error) {
	writer, err := newHealthBulkWriter(opts, repo)
	if err != nil {
		// Returning writer directly would wrap a nil *healthBulkWriter in a
		// non-nil HealthBulkWriter interface.
		return nil, err
	}
	return writer, nil
}

// healthBulkWriter implements HealthBulkWriter.
//
// buffer and limiter are only touched by Run and the flush/write methods it
// calls. Other goroutines interact with the writer only through add, ctx and
// stopped.
//
// Fields:
//   - repo: destination of the bulk upserts.
//   - opts: options validated by newHealthBulkWriter.
//   - add: unbuffered channel carrying batches from Add to Run.
//   - limiter: ticker created by Run. Each flush waits for one tick, which
//     enforces opts.RPSLimit.
//   - ctx, cancel: cancelled by Shutdown to make Run drain the buffer and
//     exit. ctx is also passed to the bulk upserts issued by write.
//   - stopped: Run signals on it after draining. Shutdown waits for that
//     signal.
//   - buffer: checks accepted from Add but not yet written.
type healthBulkWriter struct {
	repo    *HealthRepo
	opts    *HealthBulkWriterOptions
	add     chan []*juggler.HostCheck
	limiter *time.Ticker
	ctx     context.Context
	cancel  context.CancelFunc
	stopped chan struct{}
	buffer  []*juggler.HostCheck
}

// newHealthBulkWriter validates opts and builds a writer bound to repo.
//
// It returns an error when opts is nil, FlushInterval is not positive,
// RPSLimit is outside (0, 50] (NaN included), or Size is not positive.
//
// The options are copied. Run reads them from its own goroutine, so later
// changes to *opts by the caller must not reach the writer.
func newHealthBulkWriter(opts *HealthBulkWriterOptions, repo *HealthRepo) (*healthBulkWriter, error) {
	if opts == nil {
		return nil, errors.New("nil bulk writer options")
	}
	if opts.FlushInterval <= 0 {
		return nil, errors.New("invalid flush interval")
	}
	// Expressed as a negated range check so that NaN, which compares false
	// against everything, is rejected too. A NaN limit would otherwise
	// produce an invalid ticker interval in Run.
	if !(opts.RPSLimit > 0 && opts.RPSLimit <= 50) {
		return nil, errors.New("RPSLimit: should be in (0, 50]")
	}
	if opts.Size <= 0 {
		return nil, errors.New("invalid buffer size")
	}

	optsCopy := *opts
	writer := &healthBulkWriter{
		repo:    repo,
		opts:    &optsCopy,
		add:     make(chan []*juggler.HostCheck),
		stopped: make(chan struct{}),
	}
	writer.ctx, writer.cancel = context.WithCancel(context.Background())
	return writer, nil
}

// Run is the writer's event loop. It blocks until Shutdown is called and
// must be called only once per writer. On each event it:
//   - appends batches received from Add to the buffer, flushing whenever the
//     buffer is full;
//   - flushes on every FlushInterval tick;
//   - reports the free buffer space to HealthBulkWriterMetrics once per second.
//
// Every flush waits for a tick of the rate limiter, so at most opts.RPSLimit
// bulk writes are issued per second.
//
// After Shutdown cancels the writer context, Run keeps flushing until the
// buffer is empty or about 10 seconds have passed; a flush already in
// progress is not interrupted. It then signals Shutdown and returns. Checks
// still buffered at that point are dropped.
func (w *healthBulkWriter) Run() {
	// Compute the interval with nanosecond precision. Truncating it to whole
	// milliseconds made it slightly too short for limits such as 7 RPS
	// (142ms instead of ~142.86ms), letting the writer exceed the limit.
	w.limiter = time.NewTicker(time.Duration(float64(time.Second) / w.opts.RPSLimit))
	defer w.limiter.Stop()
	updater := time.NewTicker(w.opts.FlushInterval)
	defer updater.Stop()
	metricTicker := time.NewTicker(time.Second)
	defer metricTicker.Stop()

	for {
		select {
		case <-updater.C:
			w.flush()
		case checks := <-w.add:
			// Append in chunks that fit the remaining capacity, flushing
			// each time the buffer fills up.
			for len(checks) > 0 {
				if len(w.buffer) >= w.opts.Size {
					w.flush()
					continue
				}
				n := w.opts.Size - len(w.buffer)
				if n > len(checks) {
					n = len(checks)
				}
				w.buffer = append(w.buffer, checks[:n]...)
				checks = checks[n:]
			}
		case <-metricTicker.C:
			free := math.Max(float64(w.opts.Size-len(w.buffer)), 0)
			HealthBulkWriterMetrics.FreeBuffer.Update(free)
		case <-w.ctx.Done():
			w.drain(10 * time.Second)
			// Closing (rather than sending on) stopped lets every Shutdown
			// call return, including repeated calls. A send would be received
			// only once and block any later caller forever.
			close(w.stopped)
			return
		}
	}
}

// drain flushes until the buffer is empty or timeout has elapsed. The
// deadline is only checked between flushes.
func (w *healthBulkWriter) drain(timeout time.Duration) {
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()
	for len(w.buffer) > 0 {
		w.flush()
		select {
		case <-deadline.C:
			return
		default:
		}
	}
}

// Add hands checks to the Run loop and blocks until Run receives them.
// Once Shutdown has been called it no longer blocks, and the checks may be
// dropped.
func (w *healthBulkWriter) Add(checks []*juggler.HostCheck) {
	shutdown := w.ctx.Done()
	select {
	case <-shutdown:
	case w.add <- checks:
	}
}

// flush waits for the next rate-limiter tick and writes one batch from the
// buffer. It then passes the outcome to opts.Handler, if one is set. The
// handler is called even when the buffer was empty; in that case both
// arguments are nil.
func (w *healthBulkWriter) flush() {
	<-w.limiter.C
	result, err := w.write()
	if handler := w.opts.Handler; handler != nil {
		handler(result, err)
	}
}

// write removes up to opts.Size checks from the head of the buffer and
// upserts them with one HealthRepo.BulkUpsert call. It records the call
// duration and, when a result is available, the modified and upserted
// counts.
//
// It returns (nil, nil) without contacting MongoDB when the buffer is empty.
// Removed checks are not put back if the write fails.
func (w *healthBulkWriter) write() (*mongo.BulkWriteResult, error) {
	if len(w.buffer) == 0 {
		return nil, nil
	}
	var bulk []*juggler.HostCheck
	if len(w.buffer) > w.opts.Size {
		bulk = w.buffer[:w.opts.Size]
		w.buffer = w.buffer[w.opts.Size:]
	} else {
		bulk = w.buffer
		w.buffer = make([]*juggler.HostCheck, 0)
	}

	ctx := w.ctx
	if ctx.Err() != nil {
		// Shutdown cancels w.ctx before Run drains the buffer. Writing with
		// the cancelled context would make every drain write fail
		// immediately and silently drop the remaining checks. Give each
		// drain write its own bounded context instead.
		const drainWriteTimeout = 10 * time.Second
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(context.Background(), drainWriteTimeout)
		defer cancel()
	}

	start := time.Now()
	result, err := w.repo.BulkUpsert(ctx, bulk)
	monitoring.MeasureSecondsSince(HealthBulkWriterMetrics.FlushTime, start)
	if result != nil {
		HealthBulkWriterMetrics.ModifiedChecks.Update(float64(result.ModifiedCount))
		HealthBulkWriterMetrics.UpsertedChecks.Update(float64(result.UpsertedCount))
	}
	return result, err
}

// Shutdown cancels the writer context and waits until Run signals that it
// has finished draining. It blocks forever if Run was never started.
func (w *healthBulkWriter) Shutdown() {
	stop, finished := w.cancel, w.stopped
	stop()
	<-finished
}

// HostCheckFilter selects health checks. Each non-empty field adds an "$in"
// condition: Keys on "_id", Types on "type", and FQDN on "fqdn". All
// conditions must hold (they are top-level fields of one query document). A
// filter whose fields are all empty matches every check.
type HostCheckFilter struct {
	Keys  []juggler.WalleCheckKey
	Types []juggler.CheckType
	FQDN  []HostName
}

// getBSON converts the filter into a MongoDB query document. Only non-empty
// fields contribute a condition. An all-empty filter yields an empty,
// non-nil document, which matches everything.
func (f *HostCheckFilter) getBSON() bson.D {
	in := func(field string, values interface{}) bson.E {
		return bson.E{Key: field, Value: bson.D{{Key: "$in", Value: values}}}
	}
	query := bson.D{}
	if len(f.Keys) != 0 {
		query = append(query, in("_id", f.Keys))
	}
	if len(f.Types) != 0 {
		query = append(query, in("type", f.Types))
	}
	if len(f.FQDN) != 0 {
		query = append(query, in("fqdn", f.FQDN))
	}
	return query
}

// HealthMetricAggregation is the single document produced by the $facet
// stage of HealthRepo.AggregateMetrics. It has one field per facet:
// check counts per (type, status), and per (type, age bucket) computed from
// "timestamp" and from "effective_timestamp" respectively.
type HealthMetricAggregation struct {
	Statuses      []*HealthCheckStatusAggregation `bson:"statuses"`
	Ages          []*HealthCheckAgeAggregation    `bson:"ages"`
	AgesEffective []*HealthCheckAgeAggregation    `bson:"agesEffective"`
}

// HealthCheckStatusAggregation is one row of the "statuses" facet. Total is
// the number of checks with the given Type and Status.
type HealthCheckStatusAggregation struct {
	Type   juggler.CheckType   `bson:"type"`
	Status juggler.WalleStatus `bson:"status"`
	Total  int                 `bson:"total"`
}

// HealthCheckAgeAggregation is one row of an age facet ("ages" or
// "agesEffective"). Total is the number of checks of Type that fall into the
// age bucket labelled Age.
//
// Age is one of the bucket values passed to healthCheckAgeAggregationPipeline,
// which uses the bucket's lower bound as its label. Presumably Age is in
// seconds, assuming the timestamp fields hold Unix seconds — TODO confirm.
type HealthCheckAgeAggregation struct {
	Type  juggler.CheckType `bson:"type"`
	Age   float64           `bson:"age"`
	Total int64             `bson:"total"`
}

func healthCheckAgeAggregationPipeline(ageFiled string, buckets []float64) (pipeline bson.A) {
	if len(buckets) < 2 {
		return
	}
	now := time.Now().Unix()
	var branches bson.A
	for i, bucket := range buckets[1:] {
		branches = append(branches, bson.M{
			"case": bson.M{"$lt": bson.A{"$age", bucket}},
			"then": buckets[i],
		})
	}
	pipeline = bson.A{
		bson.M{
			"$project": bson.M{
				"_id":  0,
				"type": 1,
				"age":  bson.M{"$subtract": bson.A{now, "$" + ageFiled}},
			},
		},
		bson.M{
			"$group": bson.M{
				"_id": bson.M{
					"type": "$type",
					"age": bson.M{
						"$switch": bson.M{
							"branches": branches,
							"default":  buckets[len(buckets)-1],
						},
					},
				},
				"total": bson.M{"$sum": 1},
			},
		},
		bson.M{
			"$project": bson.M{
				"_id":   0,
				"type":  "$_id.type",
				"age":   "$_id.age",
				"total": 1,
			},
		},
	}
	return pipeline
}
