package artifact

import (
	"errors"
	"fmt"
	"io"
	"sync"
	"time"
)

// writeQueueEntry is one blocked write request in a rateLimitWriterAt queue:
// capacity is the number of tokens the write asked for, and wc is the
// unbuffered channel the waiting writer receives on until it is released.
type writeQueueEntry struct {
	capacity int
	wc       chan struct{}
}

// writeQueueEntryPool recycles writeQueueEntry values, together with their
// channels, between WriteAt calls instead of allocating a fresh one each time.
var writeQueueEntryPool = sync.Pool{
	New: func() interface{} {
		entry := new(writeQueueEntry)
		entry.wc = make(chan struct{})
		return entry
	},
}

// Token-bucket tuning shared by the rate-limited reader and writer.
const (
	// rateLimitTickRate is how many times per second a bucket is refilled.
	rateLimitTickRate = 4
	// rateLimitBufferFactor multiplies the transfer rate to give the bucket
	// ceiling, which lets short bursts exceed the steady rate.
	rateLimitBufferFactor = 2
)

// rateLimitReader wraps an io.ReadCloser and throttles reads to roughly
// maxTransferRate bytes per second. A background goroutine adds tokens to a
// bucket rateLimitTickRate times per second; each Read consumes at most as
// many tokens (one per byte) as are currently available.
type rateLimitReader struct {
	wr          io.ReadCloser
	tokens      int // current byte budget; guarded by tokensMutex
	tokensMutex *sync.Mutex
	ticker      *time.Ticker
	done        chan struct{} // closed by Close; stops refills, wakes Read
	closeOnce   sync.Once     // makes the shutdown in Close idempotent
}

// newRateLimitReader returns a reader that limits reads from wr to about
// maxTransferRate bytes per second, with bursts of up to about
// rateLimitBufferFactor*maxTransferRate bytes.
//
// The per-tick increment is clamped to at least one token so that rates
// below rateLimitTickRate still make progress instead of stalling forever.
// The caller must call Close to stop the background refill goroutine.
func newRateLimitReader(maxTransferRate int, wr io.ReadCloser) *rateLimitReader {
	tokensIncrement := maxTransferRate / rateLimitTickRate
	if tokensIncrement < 1 {
		tokensIncrement = 1
	}
	maxTokens := maxTransferRate * rateLimitBufferFactor
	if maxTokens < tokensIncrement {
		maxTokens = tokensIncrement
	}
	r := &rateLimitReader{
		wr:          wr,
		tokensMutex: new(sync.Mutex),
		ticker:      time.NewTicker(time.Second / rateLimitTickRate),
		done:        make(chan struct{}),
	}
	go func() {
		// Ticker.Stop does not close ticker.C, so ranging over it would never
		// terminate; done is the goroutine's exit signal.
		for {
			select {
			case <-r.done:
				return
			case <-r.ticker.C:
				r.addTokens(tokensIncrement, maxTokens)
			}
		}
	}()
	return r
}

// Read reads up to len(p) bytes, limited by the tokens currently available.
//
// When the bucket is empty, Read waits (re-checking every half tick) until
// tokens arrive or the reader is closed. It does not return (0, nil), which
// io.Reader discourages and some consumers (e.g. bufio.Reader) eventually
// treat as an error. Tokens reserved for bytes the wrapped reader did not
// deliver are credited back so short reads do not shrink the effective rate.
// After Close, a Read waiting for tokens returns a non-nil error.
func (r *rateLimitReader) Read(p []byte) (n int, err error) {
	if len(p) == 0 {
		return 0, nil
	}
	for {
		if readLength := r.getTokens(len(p)); readLength > 0 {
			n, err = r.wr.Read(p[:readLength])
			if n < readLength {
				r.returnTokens(readLength - n)
			}
			return n, err
		}
		select {
		case <-r.done:
			return 0, errors.New("reader is closed")
		case <-time.After(time.Second / rateLimitTickRate / 2):
		}
	}
}

// addTokens adds tokensIncrement tokens if the bucket is below maxTokens.
// The check happens before the addition, so the balance may end up to
// tokensIncrement-1 above maxTokens.
func (r *rateLimitReader) addTokens(tokensIncrement, maxTokens int) {
	r.tokensMutex.Lock()
	if r.tokens < maxTokens {
		r.tokens += tokensIncrement
	}
	r.tokensMutex.Unlock()
}

// getTokens reserves up to capacity tokens and returns how many were
// reserved: capacity if enough are available, otherwise whatever remains
// (possibly zero).
func (r *rateLimitReader) getTokens(capacity int) int {
	r.tokensMutex.Lock()
	var ret int
	if r.tokens > capacity {
		ret = capacity
		r.tokens -= capacity
	} else {
		ret = r.tokens
		r.tokens = 0
	}
	r.tokensMutex.Unlock()
	return ret
}

// returnTokens credits back n tokens that were reserved but not consumed.
func (r *rateLimitReader) returnTokens(n int) {
	r.tokensMutex.Lock()
	r.tokens += n
	r.tokensMutex.Unlock()
}

// Close stops the refill goroutine and wakes any Read waiting for tokens,
// then closes the wrapped reader. The shutdown runs once; the wrapped
// reader's Close is invoked (and its error returned) on every call.
func (r *rateLimitReader) Close() error {
	r.closeOnce.Do(func() {
		r.ticker.Stop()
		close(r.done)
	})
	return r.wr.Close()
}

// rateLimitWriterAt wraps an io.WriterAt and throttles writes to roughly
// maxTransferRate bytes per second. Each WriteAt reserves len(p) tokens.
// If no earlier write is waiting and the bucket has enough tokens, the
// write proceeds immediately. Otherwise it joins a FIFO queue that a
// background goroutine drains as it refills the bucket rateLimitTickRate
// times per second.
type rateLimitWriterAt struct {
	ww         io.WriterAt
	tokens     int                // byte budget, negative after an oversized write; guarded by queueMutex
	maxTokens  int                // bucket ceiling; immutable after construction
	writeQueue []*writeQueueEntry // blocked writers in arrival order; guarded by queueMutex
	queueMutex *sync.Mutex
	ticker     *time.Ticker
	closed     bool          // set by Close; guarded by queueMutex
	done       chan struct{} // closed by Close; stops the refill goroutine
}

// newRateLimitWriterAt returns a writer that limits writes to ww to about
// maxTransferRate bytes per second, with bursts of up to about
// rateLimitBufferFactor*maxTransferRate bytes.
//
// The per-tick increment is clamped to at least one token so that rates
// below rateLimitTickRate still make progress. The caller must call Close
// to stop the background refill goroutine.
func newRateLimitWriterAt(maxTransferRate int, ww io.WriterAt) *rateLimitWriterAt {
	tokensIncrement := maxTransferRate / rateLimitTickRate
	if tokensIncrement < 1 {
		tokensIncrement = 1
	}
	maxTokens := maxTransferRate * rateLimitBufferFactor
	if maxTokens < tokensIncrement {
		maxTokens = tokensIncrement
	}
	w := &rateLimitWriterAt{
		ww:         ww,
		maxTokens:  maxTokens,
		queueMutex: new(sync.Mutex),
		ticker:     time.NewTicker(time.Second / rateLimitTickRate),
		done:       make(chan struct{}),
	}
	go func() {
		for {
			// Ticker.Stop does not close ticker.C, so done is the exit signal.
			select {
			case <-w.done:
				return
			case <-w.ticker.C:
			}
			w.queueMutex.Lock()
			if w.tokens < w.maxTokens {
				w.tokens += tokensIncrement
			}
			for len(w.writeQueue) > 0 && w.hasTokensFor(w.writeQueue[0].capacity) {
				head := w.writeQueue[0]
				w.writeQueue[0] = nil // don't keep the pooled entry reachable
				w.writeQueue = w.writeQueue[1:]
				w.tokens -= head.capacity
				// The waiter goes straight from unlocking queueMutex to
				// receiving on wc, so this send completes while we hold it.
				head.wc <- struct{}{}
			}
			w.queueMutex.Unlock()
		}
	}()
	return w
}

// WriteAt waits until len(p) tokens are available, then writes p to the
// wrapped WriterAt at off. If the writer is closed before or while waiting,
// it returns 0 and a wrapped "writer is closed" error.
func (w *rateLimitWriterAt) WriteAt(p []byte, off int64) (n int, err error) {
	if err = w.waitForTokens(len(p)); err != nil {
		return 0, fmt.Errorf("unable to write: %w", err)
	}
	return w.ww.WriteAt(p, off)
}

// hasTokensFor reports whether a request for capacity tokens can be granted.
// A request larger than the bucket ceiling is granted once the bucket is
// full, because otherwise it could never be satisfied and would block every
// later write. The resulting negative balance is repaid by later refills.
// The caller must hold queueMutex.
func (w *rateLimitWriterAt) hasTokensFor(capacity int) bool {
	if capacity > w.maxTokens {
		capacity = w.maxTokens
	}
	return w.tokens >= capacity
}

// waitForTokens blocks until capacity tokens have been granted to the
// caller. It returns an error if the writer is already closed, or if it is
// closed while the caller is waiting.
func (w *rateLimitWriterAt) waitForTokens(capacity int) error {
	w.queueMutex.Lock()
	if w.closed {
		w.queueMutex.Unlock()
		return errors.New("writer is closed")
	}
	// Fast path: nobody is queued ahead of us and the budget covers the write,
	// so there is no reason to wait for the next tick.
	if len(w.writeQueue) == 0 && w.hasTokensFor(capacity) {
		w.tokens -= capacity
		w.queueMutex.Unlock()
		return nil
	}
	queueEntry := writeQueueEntryPool.Get().(*writeQueueEntry)
	queueEntry.capacity = capacity
	w.writeQueue = append(w.writeQueue, queueEntry)
	w.queueMutex.Unlock()
	// Nothing may block between the Unlock above and this receive: the refill
	// goroutine and Close both send on wc while holding queueMutex.
	<-queueEntry.wc
	writeQueueEntryPool.Put(queueEntry)
	// Close closes done before releasing queued entries, so a waiter released
	// by Close (or granted concurrently with it) observes done as closed.
	select {
	case <-w.done:
		return errors.New("writer is closed")
	default:
		return nil
	}
}

// Close stops the refill goroutine and releases every queued writer. Each
// released writer's WriteAt fails with a "writer is closed" error, as do
// later WriteAt calls. Close is idempotent and always returns nil.
func (w *rateLimitWriterAt) Close() error {
	w.queueMutex.Lock()
	defer w.queueMutex.Unlock()
	if w.closed {
		return nil
	}
	w.closed = true
	w.ticker.Stop()
	close(w.done) // must precede the releases below so woken writers see it
	for i, entry := range w.writeQueue {
		w.writeQueue[i] = nil
		// Blocking send: every queued writer is parked on (or about to reach)
		// its receive, so none can be skipped and left waiting forever.
		entry.wc <- struct{}{}
	}
	w.writeQueue = nil
	return nil
}
