package protoqueue

import (
	"encoding/binary"
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
	"unsafe"

	"a.yandex-team.ru/security/gideon/gideon/pkg/events"
)

var ErrFullQueue = errors.New("queue is full")

// eventsQueue is a queue of fixed-size byte buffers. Writers append records to
// the current buffer concurrently; the consumer takes whole buffers with
// PopBuffer and returns them with PushBuffer.
//
//   - maxSize: capacity of each buffer in bytes; WriteEvent rejects records
//     larger than this.
//   - freeBufs: empty buffers waiting to become current; guarded by mu.
//   - fullBufs: retired buffers waiting for PopBuffer; guarded by mu.
//   - curBuf: the *bytesBuf writers append to. It is loaded atomically and
//     replaced by compare-and-swap while mu is held; nil means no buffer is
//     available.
//   - mu: guards freeBufs and fullBufs and serializes swaps of curBuf.
type eventsQueue struct {
	maxSize            int
	freeBufs, fullBufs []*bytesBuf
	curBuf             unsafe.Pointer
	mu                 sync.Mutex
}

// bytesBuf is a fixed-size byte buffer in which writers reserve disjoint
// regions with Alloc.
//
//   - buf: backing storage (len(buf) == maxSize when built by NewBytesQueue).
//   - off: end of the reserved prefix; reservations advance it with an atomic
//     compare-and-swap.
//   - maxSize: capacity used for the bounds check in Alloc.
//   - mu: writers hold it shared while filling their region (WriteAcquire);
//     the consumer takes it exclusively to wait for them (ReadAcquire).
type bytesBuf struct {
	buf          []byte
	off, maxSize uint32
	mu           sync.RWMutex
}

// NewBytesQueue creates a queue backed by capacity buffers of bufSize bytes
// each. The first buffer becomes the current write target and the rest are put
// on the free list. bufSize also bounds the size of a single record; it must be
// non-negative and fit in a uint32, since per-buffer sizes are kept as uint32.
//
// A capacity below 1 yields a queue without buffers, on which writes always
// fail, instead of panicking on the missing first buffer.
func NewBytesQueue(capacity, bufSize int) *eventsQueue {
	q := &eventsQueue{
		maxSize: bufSize,
	}
	if capacity < 1 {
		// No buffer can become current: curBuf stays nil and freeBufs empty.
		return q
	}

	bufs := make([]*bytesBuf, capacity)
	for i := range bufs {
		bufs[i] = &bytesBuf{
			buf:     make([]byte, bufSize),
			maxSize: uint32(bufSize),
		}
	}

	atomic.StorePointer(&q.curBuf, unsafe.Pointer(bufs[0]))
	q.freeBufs = bufs[1:]

	return q
}

// WriteAcquire registers the caller as a writer of the buffer. Writers take
// the shared side of mu, so any number of them may fill their own regions at
// once; they only exclude ReadAcquire.
func (b *bytesBuf) WriteAcquire() { b.mu.RLock() }

// WriteRelease ends a write section started with WriteAcquire.
func (b *bytesBuf) WriteRelease() { b.mu.RUnlock() }

// ReadAcquire takes the exclusive side of mu. It blocks until every writer
// inside a WriteAcquire/WriteRelease section has left, and holds off new
// writers until ReadRelease.
func (b *bytesBuf) ReadAcquire() { b.mu.Lock() }

// ReadRelease releases the exclusive lock taken by ReadAcquire.
func (b *bytesBuf) ReadRelease() { b.mu.Unlock() }

// Alloc reserves size bytes at the end of the buffer and returns them as a
// slice for the caller to fill. The reservation is claimed with a CAS loop on
// off, so concurrent callers always receive disjoint regions. It returns nil
// when fewer than size bytes remain.
func (b *bytesBuf) Alloc(size uint32) []byte {
	for {
		off := b.loadOff()
		// Compare with the remaining space rather than computing off+size,
		// which can wrap around for large uint32 values. off never exceeds
		// maxSize: it only grows through the CAS below.
		if size > b.maxSize-off {
			return nil
		}

		end := off + size
		if b.casOff(off, end) {
			// Cap the slice so an append by the caller cannot spill into a
			// region reserved by another writer.
			return b.buf[off:end:end]
		}
	}
}

// Bytes returns the reserved prefix of the buffer: every byte handed out by
// Alloc since the offset was last reset. The slice aliases the buffer's
// storage. Callers should make sure no writer is still filling its
// reservation (see ReadAcquire) before reading it.
func (b *bytesBuf) Bytes() []byte {
	// off is advanced by writers with atomic CAS; read it atomically too so
	// this is not a data race with concurrent writers.
	return b.buf[:b.loadOff()]
}

// Len reports how many bytes of the buffer have been reserved by Alloc.
func (b *bytesBuf) Len() int64 {
	// off is advanced by writers with atomic CAS; read it atomically too so
	// this is not a data race with concurrent writers.
	return int64(b.loadOff())
}

// loadOff atomically reads the end of the reserved prefix.
func (b *bytesBuf) loadOff() uint32 { return atomic.LoadUint32(&b.off) }

// casOff atomically moves the offset from prev to next. It reports false,
// leaving the offset untouched, when the offset no longer equals prev.
func (b *bytesBuf) casOff(prev, next uint32) bool {
	swapped := atomic.CompareAndSwapUint32(&b.off, prev, next)
	return swapped
}

// WriteEvent appends event to the current buffer as one record: a
// little-endian uint32 holding the marshaled size, the marshaled event, then
// protoMagic. When the current buffer lacks room it is retired with
// putFullBuf and the write is retried on whichever buffer becomes current.
//
// It returns the number of bytes reserved for the record. It fails without
// reserving anything when the record can never fit in one buffer, or with
// ErrFullQueue when no buffer is available. An error from event.MarshalTo is
// returned as-is.
//
// NOTE(review): if MarshalTo fails, the reserved region is not rolled back.
// Its length prefix is already written and the rest of its contents is
// unspecified, so the consumer may encounter a corrupt record.
func (q *eventsQueue) WriteEvent(event *events.Event) (int, error) {
	eventSize := event.Size()
	msgSize := protoMetaLen + eventSize
	if msgSize > q.maxSize {
		return 0, fmt.Errorf("msg too big: %d > %d", msgSize, q.maxSize)
	}

	for {
		buf := q.loadCurBuf()
		if buf == nil {
			return 0, ErrFullQueue
		}

		written, err := q.tryWriteTo(buf, event, eventSize, msgSize)
		if written {
			return msgSize, err
		}

		// buf is either out of room or already retired by someone else.
		// putFullBuf only swaps it out if it is still current, so it is safe
		// in both cases.
		q.putFullBuf(buf)
	}
}

// tryWriteTo reserves msgSize bytes in buf and encodes event into them. It
// reports false, without writing anything, when buf is no longer the current
// buffer or has no room left. The write lock is released via defer so a
// panicking MarshalTo cannot leave ReadAcquire blocked forever.
func (q *eventsQueue) tryWriteTo(buf *bytesBuf, event *events.Event, eventSize, msgSize int) (bool, error) {
	buf.WriteAcquire()
	defer buf.WriteRelease()

	// Once buf has been swapped out of curBuf it may be handed to the
	// consumer, and writing into it afterwards would race with the consumer
	// reading it. Re-checking under the write lock keeps late writers out.
	if q.loadCurBuf() != buf {
		return false, nil
	}

	space := buf.Alloc(uint32(msgSize))
	if space == nil {
		return false, nil
	}

	// msg len
	binary.LittleEndian.PutUint32(space, uint32(eventSize))
	// proto msg
	_, err := event.MarshalTo(space[4:])
	// proto magic
	copy(space[4+eventSize:], protoMagic)

	return true, err
}

// loadCurBuf atomically reads the buffer writers currently append to. A nil
// result means no buffer is available.
func (q *eventsQueue) loadCurBuf() *bytesBuf {
	p := atomic.LoadPointer(&q.curBuf)
	return (*bytesBuf)(p)
}

// swapCurBuf atomically replaces the current buffer with next, but only if it
// is still prev. It reports whether the replacement happened.
func (q *eventsQueue) swapCurBuf(prev, next *bytesBuf) bool {
	expected, replacement := unsafe.Pointer(prev), unsafe.Pointer(next)
	return atomic.CompareAndSwapPointer(&q.curBuf, expected, replacement)
}

// putFullBuf retires buf if it is still the current buffer. The first free
// buffer, or nil when none is left, takes its place, and buf is queued on
// fullBufs for the consumer. When buf is no longer current (another caller
// already retired it), nothing changes.
func (q *eventsQueue) putFullBuf(buf *bytesBuf) {
	q.mu.Lock()
	defer q.mu.Unlock()

	var successor *bytesBuf
	if len(q.freeBufs) > 0 {
		successor = q.freeBufs[0]
	}

	if !q.swapCurBuf(buf, successor) {
		return
	}

	if successor != nil {
		q.freeBufs = q.freeBufs[1:]
	}
	q.fullBufs = append(q.fullBufs, buf)
}

// PopBuffer hands the oldest retired buffer to the consumer. When no buffer
// has been retired yet but the current one holds data, the current buffer is
// retired first, so partially filled buffers are drained too.
//
// The returned buffer is safe to read: PopBuffer waits until every writer
// that was still filling it has released it. The boolean reports whether no
// retired buffers remain after this call. It is also true, with a nil buffer,
// when there was nothing to consume. A popped buffer becomes writable again
// only after it is handed back with PushBuffer.
func (q *eventsQueue) PopBuffer() (*bytesBuf, bool) {
	q.mu.Lock()
	if len(q.fullBufs) == 0 {
		q.mu.Unlock()

		// no full buffers - retire the current one instead
		cur := q.loadCurBuf()
		if cur == nil || cur.loadOff() == 0 {
			return nil, true
		}
		q.putFullBuf(cur)

		q.mu.Lock()
		// q.mu was released above, so a concurrent PopBuffer may have drained
		// fullBufs in the meantime; never index an empty slice.
		if len(q.fullBufs) == 0 {
			q.mu.Unlock()
			return nil, true
		}
	}

	buf := q.fullBufs[0]
	q.fullBufs = q.fullBufs[1:]
	last := len(q.fullBufs) == 0
	q.mu.Unlock()

	// buf is no longer current, but writers that picked it up before it was
	// swapped out may still be filling their reservations. Taking the
	// exclusive side of its lock waits for them. q.mu is not held during the
	// wait, so writers can keep retiring buffers meanwhile.
	buf.ReadAcquire()
	buf.ReadRelease()

	return buf, last
}

// PushBuffer hands a consumed buffer back to the queue. The buffer is emptied.
// If the queue had run out of buffers (no free ones and no current one), it
// immediately becomes the current write target; otherwise it joins the free
// list.
func (q *eventsQueue) PushBuffer(buf *bytesBuf) {
	q.mu.Lock()
	defer q.mu.Unlock()

	// Writers read off with atomic loads, and one that loaded this buffer
	// earlier may still do so, so the reset must be atomic as well.
	atomic.StoreUint32(&buf.off, 0)
	if len(q.freeBufs) == 0 && q.swapCurBuf(nil, buf) {
		return
	}

	q.freeBufs = append(q.freeBufs, buf)
}
