package history

import (
	"sync/atomic"
	"time"
	"unsafe"

	"code.justin.tv/devhub/e2ml/libs/metrics"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"code.justin.tv/devhub/e2ml/libs/stream/protocol"
)

// rstats holds the metrics reported by rateLimited.
type rstats struct {
	// limited is incremented by one for every Write rejected by the rate limit.
	limited metrics.Count
}

// rateLimited is a stream.History that retains only the most recently
// accepted message and rejects writes that arrive faster than a fixed rate.
type rateLimited struct {
	// stats receives a count of rejected writes.
	stats   rstats
	address stream.Address
	// entry points at the current stream.Message (a written message or one of
	// the package-level sentinel values); it is only accessed through
	// sync/atomic loads and stores.
	entry   unsafe.Pointer
	// err is returned by Write whenever a message is rejected by the limiter.
	err     error
	// seq is an atomically incremented counter that selects the next slot in samples.
	seq     int32
	// samples is a ring of write timestamps (UnixNano), sized to the
	// per-window write count.
	samples []int64
	// minTime is the rate window in nanoseconds; the slot being overwritten
	// must be older than this for a write to be admitted.
	minTime int64
}

// NewRateLimited returns a stream.History for address that keeps only the
// most recently accepted message.
//
// Writes are throttled to roughly count per over: each write attempt stamps
// one slot of a count-sized ring of timestamps, and is rejected with
// protocol.ErrRateLimited(count, over) when the slot it overwrites was stamped
// no more than over earlier. Every rejection increments limited.
//
// The history starts out holding the package's initial sentinel (declared
// elsewhere in this package).
func NewRateLimited(address stream.Address, count uint8, over time.Duration, limited metrics.Count) stream.History {
	h := &rateLimited{
		address: address,
		entry:   unsafe.Pointer(&initial),
	}
	h.stats.limited = limited
	h.err = protocol.ErrRateLimited(int(count), over)
	h.samples = make([]int64, count)
	h.minTime = int64(over)
	return h
}

// Address reports the stream address this history was created for.
func (r *rateLimited) Address() stream.Address {
	return r.address
}

// Write makes msg the current entry if the rate limiter admits it and
// returns nil. Otherwise msg is dropped, the limited counter is incremented,
// and the preconfigured rate-limit error is returned.
//
// NOTE(review): Write does not consult the closed state; an admitted write
// after Close replaces the closed sentinel.
func (r *rateLimited) Write(msg stream.Message) error {
	if r.check(time.Now()) {
		// Publish a pointer to this call's own copy of the interface value.
		latest := msg
		atomic.StorePointer(&r.entry, unsafe.Pointer(&latest))
		return nil
	}
	r.stats.limited.Add(1)
	return r.err
}

// Close atomically replaces the current entry with the package's closed
// sentinel. Until a later Write replaces it, Next reports no message and
// Last reports (stream.None, stream.Origin). It always returns nil.
func (r *rateLimited) Close() error {
	sentinel := unsafe.Pointer(&closed)
	atomic.StorePointer(&r.entry, sentinel)
	return nil
}

// Next returns the current message and true, unless the history is closed or
// the caller has already consumed that message (src equals the message's
// source and pos is at or beyond the end of its position range); in those
// cases it returns nil, false.
func (r *rateLimited) Next(src stream.SourceID, pos stream.Position) (stream.Message, bool) {
	e := r.current()
	// Test for the closed sentinel before calling any methods on the entry,
	// matching the order used by Last, so Source/At are never consulted on
	// the sentinel.
	if e == closed {
		return nil, false
	}
	if src == e.Source() && pos >= e.At().End {
		return nil, false
	}
	return e, true
}

// Last reports the source and end position of the current message, or
// (stream.None, stream.Origin) while the history holds the closed sentinel.
func (r *rateLimited) Last() (stream.SourceID, stream.Position) {
	if e := r.current(); e != closed {
		return e.Source(), e.At().End
	}
	return stream.None, stream.Origin
}

// current atomically loads the entry pointer and returns the message it
// points at, which may be one of the package's sentinel values.
func (r *rateLimited) current() stream.Message {
	ptr := (*stream.Message)(atomic.LoadPointer(&r.entry))
	return *ptr
}

// check stamps now into the next slot of the samples ring and reports whether
// the write is admitted. The timestamp being replaced must be more than
// minTime nanoseconds older than now. The slot is overwritten whether or not
// the write is admitted, so rejected attempts also count against the limit.
//
// A ring with no slots (count == 0) admits nothing instead of panicking on a
// modulo by zero.
func (r *rateLimited) check(now time.Time) bool {
	slots := uint32(len(r.samples)) // at most 255: count is a uint8
	if slots == 0 {
		return false
	}
	unix := now.UnixNano()
	// seq is shared by concurrent callers and wraps past math.MaxInt32 on a
	// long-lived history. Reinterpreting it as unsigned keeps the index
	// non-negative; the signed form produced a negative index and panicked.
	// The wrap causes a single, harmless skip in slot order.
	index := uint32(atomic.AddInt32(&r.seq, 1)) % slots
	prev := atomic.SwapInt64(&r.samples[index], unix)
	return unix-prev > r.minTime
}
