package logbroker

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"math/rand"
	"sync/atomic"
	"time"

	"github.com/klauspost/compress/zstd"

	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/security/csp-report/internal/logbroker/message"
	"a.yandex-team.ru/security/libs/go/lblight"
	"a.yandex-team.ru/security/libs/go/ydbtvm"
)

// flushInterval is the period of the time-based flush in writerLoop. It is
// computed once at package initialization as two seconds plus a random jitter
// of 0-49 whole milliseconds.
//
// As a side effect, the initializer seeds the global math/rand source with the
// current time.
var flushInterval = func() time.Duration {
	const base = 2 * time.Second

	rand.Seed(time.Now().UnixNano())
	jitter := time.Duration(rand.Intn(50)) * time.Millisecond

	return jitter + base
}()

// lightWriter batches messages received from msgChan and ships them to
// Logbroker through an lblight.Writer as zstd-compressed records.
type lightWriter struct {
	// Counters accessed through sync/atomic. They are placed first so that
	// they are 64-bit aligned on 32-bit platforms too: only the first word of
	// an allocated struct is guaranteed to have that alignment.
	errCount     uint64
	writtenCount uint64
	inflightSize uint64

	// writer is the underlying Logbroker client.
	writer *lblight.Writer

	// source is the data source name this writer was created for.
	source string

	// logger carries the "data_source" field.
	logger log.Logger

	// initialSeqNo is the next sequence number reported by the session when
	// the writer was dialed; messages with a non-zero SeqNo below it are
	// skipped by writerLoop.
	initialSeqNo uint64

	// maxSize is the buffered (uncompressed) batch size in bytes above which
	// writerLoop flushes immediately; zero disables the size-based flush.
	maxSize uint64

	// msgChan is the input queue; closing it makes writerLoop flush and exit.
	msgChan <-chan *message.Message

	// exitChan is closed when writerLoop returns.
	exitChan chan struct{}

	// zstd compresses each batch before it is written.
	zstd *zstd.Encoder

	// ctx is passed to every write; cancelCtx cancels it.
	ctx       context.Context
	cancelCtx context.CancelFunc
}

// newLightWriter creates a Logbroker writer for the given data source and
// connects it.
//
// It builds a single-threaded zstd encoder at the best-compression level,
// configures an lblight writer for cfg.Topic at cfg.Endpoint with the source
// id "<cfg.SourceID>/<source>", and dials it with ctx. TVM credentials are
// attached only when cfg.TvmClient is non-nil. Messages to send are later read
// from msgChan, and cfg.WriterMem becomes the size-based flush threshold.
//
// ctx is used for dialing (and for closing the writer if dialing fails). The
// returned writer holds a separate cancellable context rooted at
// context.Background(), so canceling ctx afterwards does not cancel it.
//
// On failure every resource created so far is closed and the underlying error
// is returned wrapped, so callers can inspect it with errors.Is / errors.As.
func newLightWriter(
	ctx context.Context,
	source string,
	cfg Config,
	msgChan <-chan *message.Message,
) (*lightWriter, error) {
	encoder, err := zstd.NewWriter(nil, zstd.WithEncoderConcurrency(1), zstd.WithEncoderLevel(zstd.SpeedBestCompression))
	if err != nil {
		// Wrapped with %w like the other failure paths below, so the cause
		// stays reachable through the error chain.
		return nil, fmt.Errorf("failed to create zstd encoder: %w", err)
	}

	logger := log.With(cfg.Logger.Logger(), log.String("data_source", source))

	// cp stays nil when no TVM client is configured.
	var cp ydb.Credentials
	if cfg.TvmClient != nil {
		cp = &ydbtvm.TvmCredentials{
			DstID:     ydbtvm.LbClientID,
			TvmClient: cfg.TvmClient,
		}
	}

	writer, err := lblight.NewWriter(
		cfg.Topic,
		lblight.WriterWithEndpoint(cfg.Endpoint),
		lblight.WriterWithLogger(logger),
		lblight.WriterWithCreds(cp),
		lblight.WriterWithSourceID(fmt.Sprintf("%s/%s", cfg.SourceID, source)),
		lblight.WriterWithCodec(lblight.CodecZstd),
		lblight.WriterWithRetries(10),
	)
	if err != nil {
		_ = encoder.Close()
		return nil, fmt.Errorf("failed to create writer: %w", err)
	}

	if err := writer.Dial(ctx); err != nil {
		// Best-effort cleanup: the dial error is the one worth reporting.
		_ = encoder.Close()
		_ = writer.Close(ctx)
		return nil, fmt.Errorf("failed to initialize lb writer: %w", err)
	}

	sessionInfo := writer.SessionInfo()
	logger.Info("LB writer initialized", log.Any("init_info", sessionInfo))

	// A dedicated name avoids silently overwriting the ctx parameter.
	writerCtx, cancelFn := context.WithCancel(context.Background())
	return &lightWriter{
		writer:       writer,
		source:       source,
		logger:       logger,
		initialSeqNo: sessionInfo.NextSeqNo,
		msgChan:      msgChan,
		exitChan:     make(chan struct{}),
		zstd:         encoder,
		ctx:          writerCtx,
		cancelCtx:    cancelFn,
		maxSize:      cfg.WriterMem,
	}, nil
}

// Start launches the writer's two background goroutines: the feedback
// receiver first, then the batching write loop. It returns immediately.
func (w *lightWriter) Start() {
	workers := [...]func(){
		w.recvResponses,
		w.writerLoop,
	}
	for _, run := range workers {
		go run()
	}
}

// Shutdown waits for writerLoop to finish and then releases the writer's
// resources: the zstd encoder is closed, the Logbroker writer is closed with
// ctx, and finally the writer's internal context is canceled.
//
// The wait on exitChan does not observe ctx; it ends only when writerLoop
// returns. The error from closing the Logbroker writer is returned; the
// encoder's close error is ignored.
func (w *lightWriter) Shutdown(ctx context.Context) error {
	// exitChan is closed by writerLoop on return.
	<-w.exitChan

	_ = w.zstd.Close()

	// Deferred right before the return so that cancellation still happens
	// after the writer has been closed.
	defer w.cancelCtx()
	return w.writer.Close(ctx)
}

// FlushStat returns a snapshot of the writer's counters. The inflight size is
// only read, while the error and written counters are atomically swapped to
// zero, so each call reports what accumulated since the previous one.
func (w *lightWriter) FlushStat() Stat {
	var snapshot Stat

	snapshot.Inflight = atomic.LoadUint64(&w.inflightSize)
	snapshot.Errors = atomic.SwapUint64(&w.errCount, 0)
	snapshot.Written = atomic.SwapUint64(&w.writtenCount, 0)

	return snapshot
}

// writerLoop drains msgChan, accumulates messages into a newline-separated
// buffer and writes each accumulated batch to Logbroker as one
// zstd-compressed record.
//
// A batch is flushed when:
//   - the buffered size exceeds maxSize (only if maxSize is non-zero);
//   - the flushInterval ticker fires;
//   - msgChan is closed, after which the loop returns.
//
// inflightSize tracks the number of buffered, not yet flushed bytes and is
// reset to zero after every flush, including the final one. exitChan is closed
// when the loop returns.
//
// A failed write is logged and the batch is dropped; the loop keeps running.
func (w *lightWriter) writerLoop() {
	defer close(w.exitChan)

	var (
		msgBuf  bytes.Buffer
		zstdBuf = make([]byte, 0, 512<<10)
	)

	// collect appends msg to the current batch and returns the batch size in
	// bytes. The message is released in every path, since nothing else refers
	// to it once it has been received here.
	collect := func(msg *message.Message) uint64 {
		if msg.SeqNo != 0 && msg.SeqNo < w.initialSeqNo {
			// Skip old messages, probably we read them from disk.
			w.logger.Warn("ignore old message",
				log.UInt64("seq_no", msg.SeqNo),
			)
			// Release skipped messages too, the same way written ones are.
			message.ReleaseMsg(msg)
			return uint64(msgBuf.Len())
		}

		msg.WriteTo(&msgBuf)
		msgBuf.WriteByte('\n')
		message.ReleaseMsg(msg)
		return uint64(msgBuf.Len())
	}

	// send compresses and writes the current batch, if any, then empties the
	// buffer regardless of the write result.
	send := func() {
		data := msgBuf.Bytes()
		if len(data) == 0 {
			return
		}

		// Drop the trailing '\n' left by the last collected message.
		data = data[:len(data)-1]
		if cap(zstdBuf) < len(data) {
			zstdBuf = make([]byte, 0, len(data))
		}

		// NOTE(review): zstdBuf is reused across writes, which assumes Write
		// does not retain Data after it returns - confirm against lblight.
		err := w.writer.Write(w.ctx, lblight.WriteMsg{
			Data: w.zstd.EncodeAll(data, zstdBuf[:0]),
		})
		if err != nil {
			w.logger.Error("logbroker terminated: write fault", log.Error(err))
		}

		msgBuf.Reset()
	}

	// flushBatch sends the batch and publishes that nothing is buffered.
	flushBatch := func() {
		send()
		atomic.StoreUint64(&w.inflightSize, 0)
	}

	ticker := time.NewTicker(flushInterval)
	defer ticker.Stop()

	for {
		select {
		case msg, ok := <-w.msgChan:
			if !ok {
				// Input closed: flush what is left and stop. The inflight
				// gauge is reset here as well so it does not stay stale.
				flushBatch()
				return
			}

			if size := collect(msg); w.maxSize != 0 && size > w.maxSize {
				flushBatch()
			} else {
				atomic.StoreUint64(&w.inflightSize, size)
			}

		case <-ticker.C:
			flushBatch()
		}
	}
}

// recvResponses consumes feedback from the Logbroker writer until the writer
// reports that it is closed.
//
// Each single ack adds one to writtenCount, a batch ack adds the number of
// acks it carries, and an error feedback adds one to errCount and is logged.
// Other feedback kinds are ignored. Fetch errors other than ErrClosed are
// logged and the loop continues.
func (w *lightWriter) recvResponses() {
	for {
		msg, err := w.writer.FetchFeedback()
		if err != nil {
			// errors.Is also matches a wrapped ErrClosed; a plain == would
			// miss it and leave this loop running after the writer is closed.
			if errors.Is(err, lblight.ErrClosed) {
				return
			}

			w.logger.Error("logbroker feedback fetch fail", log.Error(err))
			continue
		}

		switch m := msg.(type) {
		case *lblight.FeedbackMsgAck:
			atomic.AddUint64(&w.writtenCount, 1)
		case *lblight.FeedbackMsgAckBatch:
			atomic.AddUint64(&w.writtenCount, uint64(len(m.ACKs)))
		case *lblight.FeedbackMsgError:
			atomic.AddUint64(&w.errCount, 1)
			w.logger.Error("logbroker issue respond", log.Error(m.Error))
		}
	}
}
