package host

import (
	"io"
	"sync"

	"code.justin.tv/devhub/lib-lifecycle/src/lifecycle"
	"code.justin.tv/devhub/e2ml/libs/discovery"
	"code.justin.tv/devhub/e2ml/libs/discovery/protocol"
)

// loadFunction returns the current load as an unsigned sample. The run loop
// calls it once per received tick and sums the results over its window.
type loadFunction func() uint64

// RateLimitedStatus is a tick-driven status reporter. The implementation in
// this file (rateLimiter) samples a load function on every tick and forwards
// the windowed sum to a discovery.HostReporter.
type RateLimitedStatus interface {
	// Start begins processing ticks, sampling load on each one and summing the
	// samples over a window of windowLength ticks. In rateLimiter, calling
	// Start again closes the previous tick channel and installs a new one.
	Start(load loadFunction, windowLength int)
	// Tick requests one sample. In rateLimiter it is a no-op before Start and
	// after Close.
	Tick()
	// Close stops tick processing. In rateLimiter it first queues a final
	// protocol.Draining tick and always returns nil.
	io.Closer
}

// rateLimiter is the RateLimitedStatus implementation returned by
// NewRateLimitedStatus.
type rateLimiter struct {
	// reporter receives UpdateStatus calls from the run loop.
	reporter discovery.HostReporter
	// flags is the value Tick queues on next; it is set at construction.
	flags    protocol.StatusFlags
	// next carries ticks to the run loop. It is nil until Start is called and
	// is reset to nil by Close.
	next     chan protocol.StatusFlags
	// mgr is used by Start to launch the run loop.
	mgr      lifecycle.Manager
	// mutex is held by Close, Start and Tick while they read, replace or send
	// on next.
	mutex    sync.Mutex
}

// NewRateLimitedStatus builds a RateLimitedStatus that sends its reports to
// reporter and queues flags on every Tick. The tick channel is left nil, so
// Tick and Close are no-ops until Start has been called.
func NewRateLimitedStatus(reporter discovery.HostReporter, flags protocol.StatusFlags) RateLimitedStatus {
	limiter := new(rateLimiter)
	limiter.reporter = reporter
	limiter.flags = flags
	limiter.mgr = lifecycle.NewManager()
	return limiter
}

// Close queues a final protocol.Draining tick, closes the tick channel and
// clears it, so that later Tick and Close calls do nothing. When no channel is
// installed (Start never called, or Close already ran) it is a no-op. It always
// returns nil.
//
// The send is performed with the mutex held and blocks while the channel's
// buffer is full.
func (r *rateLimiter) Close() error {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	ticks := r.next
	if ticks == nil {
		return nil
	}

	ticks <- protocol.Draining
	close(ticks)
	r.next = nil
	return nil
}

// Start installs a fresh tick channel (buffered for 10 ticks) and hands it,
// together with load and windowLength, to run via the lifecycle manager. If a
// channel from an earlier Start is still installed it is closed first, which
// ends the range loop of the run that was reading from it.
//
// NOTE(review): how and when run executes depends on
// lifecycle.Manager.RunUntilComplete, which is not visible in this file.
func (r *rateLimiter) Start(load loadFunction, windowLength int) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	if previous := r.next; previous != nil {
		close(previous)
	}

	// Capture the channel in a local so the closure keeps reading from this
	// channel even after r.next is replaced or cleared.
	ticks := make(chan protocol.StatusFlags, 10)
	r.next = ticks

	r.mgr.RunUntilComplete(func() {
		r.run(load, windowLength, ticks)
	})
}

// Tick queues the limiter's configured flags on the tick channel, prompting
// run to take one load sample. It does nothing when no channel is installed
// (before Start or after Close).
//
// The send is performed with the mutex held and blocks while the channel's
// buffer is full.
func (r *rateLimiter) Tick() {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	if r.next == nil {
		return
	}
	r.next <- r.flags
}

// run consumes ticks from next until the channel is closed, keeping a
// sliding-window sum of the most recent windowLength load samples and
// forwarding it to the reporter.
//
// Each received value triggers exactly one call to load: the oldest sample in
// the window is replaced by the new one and the running sum is adjusted. The
// reporter is invoked on the first tick and afterwards whenever the sum or the
// flags differ from what was last reported. Comparing the flags as well as the
// sum matters for the protocol.Draining tick queued by Close: under a steady
// load the sum does not change, and a sum-only comparison would drop that
// final status.
//
// A windowLength below 1 is treated as 1.
func (r *rateLimiter) run(load loadFunction, windowLength int, next <-chan protocol.StatusFlags) {
	if windowLength < 1 {
		windowLength = 1
	}

	var (
		lastSum   uint64
		lastFlags protocol.StatusFlags
		reported  bool // false until the first UpdateStatus call
	)
	window := make([]uint64, windowLength)
	sum := uint64(0)
	index := 0
	for flags := range next {
		// Swap the oldest sample for a fresh one, keeping sum equal to the
		// total of the window's contents.
		sum -= window[index]
		window[index] = load()
		sum += window[index]
		if index++; index == windowLength {
			index = 0
		}

		if !reported || sum != lastSum || flags != lastFlags {
			r.reporter.UpdateStatus(protocol.LoadFactor(sum), flags)
			lastSum, lastFlags, reported = sum, flags, true
		}
	}
}
