package protoqueue

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/klauspost/compress/zstd"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/security/gideon/gideon/internal/sensors"
	"a.yandex-team.ru/security/gideon/gideon/pkg/events"
)

var (
	// protoMagic is the 32-byte magic sequence of the logfeller "protoseq"
	// framing format. NOTE(review): presumably written after each record as a
	// sync marker — confirm against the format spec:
	// https://wiki.yandex-team.ru/logfeller/splitter/protoseq/
	protoMagic = []byte{
		0x1F, 0xF7, 0xF7, 0x7E, 0xBE, 0xA6, 0x5E, 0x9E, 0x37, 0xA6, 0xF6, 0x2E, 0xFE, 0xAE, 0x47, 0xA7, 0xB7, 0x6E,
		0xBF, 0xAF, 0x16, 0x9E, 0x9F, 0x37, 0xF6, 0x57, 0xF7, 0x66, 0xA7, 0x06, 0xAF, 0xF7,
	}

	// protoMetaLen is 4 bytes plus the length of protoMagic. NOTE(review): this
	// looks like the per-record framing overhead, with the 4 bytes presumably
	// being a length prefix — confirm against the code that writes records.
	protoMetaLen = 4 + len(protoMagic)
)

// Sender delivers one batch produced by ProtoQueue (zstd-compressed unless an
// option installs a different compressor). ProtoQueue calls it synchronously
// from its background goroutine; the context it passes is cancelled when the
// queue is stopped. The byte slice is a reused scratch buffer, so it is only
// valid for the duration of the call. A non-nil error is counted and logged
// by the queue, and the batch is not retried.
type Sender func(context.Context, []byte) error

// ProtoQueue buffers events in memory and, from a background goroutine,
// periodically compresses the filled buffers and passes them to a Sender.
// Create it with NewProtoQueue and shut it down with Stop.
type ProtoQueue struct {
	// q receives events from EnqueueEvent (WriteEvent) and is drained by the
	// background loop (PopBuffer / PushBuffer).
	q *eventsQueue

	// flushInterval is the period of the background flush ticker.
	flushInterval time.Duration
	// writeBufferSize is the initial capacity of outBuf and the per-buffer
	// size argument passed to NewBytesQueue.
	writeBufferSize int
	// maxSize divided by writeBufferSize is passed to NewBytesQueue
	// (presumably the number of buffers — confirm in NewBytesQueue).
	maxSize int
	// sender receives each (compressed) buffer.
	sender Sender

	// outBuf is the reusable scratch output of the compress functions; the
	// slices they return alias it and are overwritten on the next call.
	outBuf []byte
	// zstd is the encoder used by zstdCompression.
	zstd *zstd.Encoder
	// compress transforms a buffer before sending; defaults to zstdCompression.
	compress func([]byte) []byte

	// ctx is passed to sender and watched by the background loop; ctxCancel
	// is invoked by Stop.
	ctx       context.Context
	ctxCancel context.CancelFunc
	// close is closed by Stop to ask the background loop to exit.
	close chan struct{}
	// closed is closed by the background loop when it has exited.
	closed chan struct{}

	log     log.Logger
	sensors sensors.Sensor
}

// NewProtoQueue builds a ProtoQueue that hands batched events to sender.
//
// Defaults, which opts may override: a 500ms flush interval, a 1 MiB maxSize
// split into 64 KiB buffers, zstd compression (SpeedDefault, single-threaded
// encoder), and no-op logger and sensors. The background flush goroutine is
// started before returning; call Stop to shut it down.
//
// It panics if the zstd encoder cannot be created.
func NewProtoQueue(sender Sender, opts ...Option) *ProtoQueue {
	enc, encErr := zstd.NewWriter(
		nil,
		zstd.WithEncoderConcurrency(1),
		zstd.WithEncoderLevel(zstd.SpeedDefault),
	)
	if encErr != nil {
		panic(fmt.Sprintf("failed to create zstd encoder: %s", encErr))
	}

	runCtx, stopRun := context.WithCancel(context.Background())
	pq := &ProtoQueue{
		sender:          sender,
		flushInterval:   500 * time.Millisecond,
		writeBufferSize: 1 << 16,
		maxSize:         1 << 20,
		zstd:            enc,
		ctx:             runCtx,
		ctxCancel:       stopRun,
		close:           make(chan struct{}),
		closed:          make(chan struct{}),
		log:             &nop.Logger{},
		sensors:         &sensors.NopSensor{},
	}
	// zstd is the default compressor; an Option may install another one.
	pq.compress = pq.zstdCompression

	for _, apply := range opts {
		apply(pq)
	}

	// Buffer sizes are final only once every option has been applied.
	pq.outBuf = make([]byte, 0, pq.writeBufferSize)
	pq.q = NewBytesQueue(pq.maxSize/pq.writeBufferSize, pq.writeBufferSize)

	go pq.run()
	return pq
}

// onError handles a failed send: it increments the collector-error metric
// and then logs sendErr at error level.
func (pq *ProtoQueue) onError(sendErr error) {
	const msg = "protoqueue sender fail"
	pq.sensors.CollectorErrors(1)
	pq.log.Error(msg, log.Error(sendErr))
}

// run is the background loop started by NewProtoQueue. On every tick of
// flushInterval it sends all pending buffers. When Stop closes pq.close, it
// performs one final drain so events enqueued since the last tick are not
// silently discarded, and then exits. Stop bounds that drain with its own
// context: once Stop gives up waiting it cancels pq.ctx, which both stops the
// drain and is visible to the in-flight sender. pq.closed is closed on exit.
func (pq *ProtoQueue) run() {
	defer close(pq.closed)

	ticker := time.NewTicker(pq.flushInterval)
	defer ticker.Stop()

	for {
		select {
		case <-pq.ctx.Done():
			return
		case <-pq.close:
			// Graceful shutdown: flush whatever is still buffered.
			pq.sendPending()
			return
		case <-ticker.C:
			pq.sendPending()
		}
	}
}

// sendPending pops buffers from the queue, compresses and sends each one, and
// pushes the buffer back for reuse. It stops when the queue yields no buffer,
// an empty buffer, or a buffer PopBuffer reports as last, or when pq.ctx has
// been cancelled. Send errors are reported via onError; the batch is not
// retried.
func (pq *ProtoQueue) sendPending() {
	for pq.ctx.Err() == nil {
		buf, isLast := pq.q.PopBuffer()
		if buf == nil {
			return
		}

		size := buf.Len()
		if size == 0 {
			pq.q.PushBuffer(buf)
			return
		}

		pq.sensors.CollectorBatch(1)
		if err := pq.sender(pq.ctx, pq.compress(buf.Bytes())); err != nil {
			pq.onError(err)
		}

		pq.sensors.CollectorSize(-size)
		pq.q.PushBuffer(buf)

		if isLast {
			return
		}
	}
}

// EnqueueEvent writes event into the underlying queue and adds the written
// size to the collector-size metric.
//
// If the queue is full (an error matching ErrFullQueue, possibly wrapped),
// the event is dropped: the drop is counted and nil is returned, so callers
// are not failed by back-pressure. Any other error is returned unchanged.
func (pq *ProtoQueue) EnqueueEvent(event *events.Event) error {
	size, err := pq.q.WriteEvent(event)
	switch {
	case err == nil:
		pq.sensors.CollectorSize(int64(size))
		return nil
	case errors.Is(err, ErrFullQueue):
		pq.sensors.CollectorDrops(1)
		return nil
	default:
		return err
	}
}

// Stop asks the background loop to finish and waits until it has exited or
// ctx is done, whichever comes first. In either case the queue's internal
// context (the one passed to the Sender) is cancelled before Stop returns.
// Stop must be called at most once: a second call panics when closing the
// already-closed stop channel.
func (pq *ProtoQueue) Stop(ctx context.Context) {
	defer pq.ctxCancel()
	close(pq.close)

	select {
	case <-ctx.Done():
		// Caller's deadline hit first; the deferred cancel aborts the loop.
	case <-pq.closed:
		// The background loop exited on its own.
	}
}

// nopCompression is the pass-through compressor: it copies in into the
// reusable pq.outBuf and returns that copy. The returned slice aliases
// pq.outBuf and is overwritten by the next compression call, so the sender
// must not retain it past its own return.
func (pq *ProtoQueue) nopCompression(in []byte) []byte {
	// Appending onto a zero-length view reuses existing capacity and, when in
	// is larger, grows the buffer with the data already in it (and keeps the
	// grown buffer for next time). Growing via make([]byte, 0, n) instead would
	// leave the length at 0, so a subsequent copy would transfer nothing.
	pq.outBuf = append(pq.outBuf[:0], in...)
	return pq.outBuf
}

// zstdCompression compresses in as a single zstd frame using pq.zstd and
// returns the result. The output is built in the reusable pq.outBuf, so the
// returned slice is overwritten by the next compression call.
func (pq *ProtoQueue) zstdCompression(in []byte) []byte {
	// Pre-size for the input length, which usually bounds the compressed size.
	if cap(pq.outBuf) < len(in) {
		pq.outBuf = make([]byte, 0, len(in))
	}

	// EncodeAll appends to dst and may reallocate (e.g. incompressible input);
	// store the result back so a grown buffer is reused on the next call
	// rather than reallocated every time.
	pq.outBuf = pq.zstd.EncodeAll(in, pq.outBuf[:0])
	return pq.outBuf
}
