package logbroker

import (
	"context"
	"fmt"
	"sync"

	"github.com/golang/protobuf/proto"
	"go.uber.org/atomic"

	"a.yandex-team.ru/kikimr/public/sdk/go/persqueue"
	"a.yandex-team.ru/kikimr/public/sdk/go/persqueue/log/corelogadapter"
	"a.yandex-team.ru/kikimr/public/sdk/go/persqueue/recipe"
	"a.yandex-team.ru/library/go/core/log"
)

// Fully qualified operation names. Each is used as the prefix of the error
// messages returned by the corresponding entry point in this file.
const (
	// newProducerFn prefixes errors returned by NewProducer.
	newProducerFn = "library.go.logbroker.NewProducer"
	// producerRunFn prefixes errors returned by Producer.Run.
	producerRunFn = "library.go.logbroker.Producer.Run"
	// producerWriteFn prefixes errors returned by Producer.Write.
	producerWriteFn = "library.go.logbroker.Producer.Write"
)

// Producer writes protobuf messages through a persqueue.Writer.
// Build it with NewProducer or NewProducerWithRecipe, start it with Run, and
// then send messages with Write.
type Producer struct {
	// logger receives the producer's own log records (session start, sent messages).
	logger    log.Logger
	// writer is the underlying persqueue writer; Run initializes it and Write sends through it.
	writer    persqueue.Writer
	// seqNo is the sequence-number counter: seeded in Run from the MaxSeqNo
	// returned by writer.Init and incremented by writeWithSeqNo before each message.
	seqNo     uint64
	// ctx is the context passed to Run; Write refuses to send once it is done.
	// NOTE(review): it is nil until Run has been called.
	ctx       context.Context
	// mutex serializes writeWithSeqNo so sequence numbers are assigned and
	// handed to the writer in order.
	mutex     sync.Mutex
	// closed is true until Run succeeds, and is set back to true by loop once
	// the writer's Closed channel fires.
	closed    *atomic.Bool
	// withSeqNo selects whether Write stamps explicit sequence numbers.
	// NewProducer sets it to true; NewProducerWithRecipe leaves it false.
	withSeqNo bool
}

func NewProducer(
	topic string,
	endpoint string,
	sourceID string,
	credentialsProvider CredentialsProvider,
	logger log.Logger,
	options ...ProducerOption,
) (*Producer, error) {
	var writerOptions persqueue.WriterOptions
	credentials, err := credentialsProvider.Credentials()
	if err != nil {
		return nil, fmt.Errorf("%s: credentials provider error: %w", newProducerFn, err)
	}
	writerOptions = persqueue.WriterOptions{
		Endpoint:       endpoint,
		Logger:         corelogadapter.New(logger),
		Topic:          topic,
		SourceID:       []byte(sourceID),
		Credentials:    credentials,
		RetryOnFailure: true,
		Codec:          persqueue.Raw,
	}
	p := Producer{
		logger:    logger,
		writer:    persqueue.NewWriter(writerOptions),
		closed:    atomic.NewBool(true),
		withSeqNo: true,
	}

	for _, option := range options {
		option(&p)
	}

	return &p, nil
}

// NewProducerWithRecipe creates a Producer from the writer options supplied by
// a recipe environment (presumably a local test Logbroker — confirm). The
// topic is cluster and topic concatenated with no separator. The producer is
// not started — call Run before Write.
//
// NOTE(review): unlike NewProducer, withSeqNo is left false here, so Write
// sends messages without explicit sequence numbers; confirm this is intended.
func NewProducerWithRecipe(env *recipe.Env, cluster string, topic string, logger log.Logger) *Producer {
	opts := env.ProducerOptions()
	opts.Topic = cluster + topic
	writer := persqueue.NewWriter(opts)
	return &Producer{
		logger: logger,
		writer: writer,
		closed: atomic.NewBool(true),
	}
}

// Run starts the producer. It initializes the writer session with
// writer.Init(ctx) and seeds the sequence-number counter from the MaxSeqNo that
// Init reports. It then clears the closed flag, starts loop in a background
// goroutine and logs the session details.
//
// ctx is retained: Write checks it before every send and fails once it is done.
// An error from Init is wrapped and returned, and the producer stays closed.
func (p *Producer) Run(ctx context.Context) error {
	p.ctx = ctx
	session, err := p.writer.Init(p.ctx)
	if err != nil {
		return fmt.Errorf("%s, producer run error: %w", producerRunFn, err)
	}
	// writeWithSeqNo increments seqNo *before* stamping a message, so seed it
	// with MaxSeqNo itself. The first message then carries MaxSeqNo+1; seeding
	// with MaxSeqNo+1 would have skipped a sequence number.
	p.seqNo = session.MaxSeqNo
	p.closed.Store(false)
	go p.loop()
	p.logger.Info(
		"Producer started",
		log.String("topic", session.Topic),
		log.String("cluster", session.Cluster),
		log.UInt64("max_seq_no", session.MaxSeqNo),
		log.String("session_id", session.SessionID),
		log.UInt64("partition", session.Partition),
	)
	return nil
}

// loop runs in the goroutine started by Run. It drains the writer's response
// channel and marks the producer closed once writer.Closed() fires, then returns.
//
// Responses received from writer.C() are discarded.
// NOTE(review): these presumably include acks and per-message failures; consider
// inspecting them instead of dropping them.
func (p *Producer) loop() {
	responses := p.writer.C()
	for {
		select {
		case _, ok := <-responses:
			if !ok {
				// A closed channel is always ready to receive, which would make
				// this select spin. A nil channel is never selected, so from
				// here on the loop just waits for Closed().
				responses = nil
			}
		case <-p.writer.Closed():
			p.closed.Store(true)
			return
		}
	}
}

// Write marshals msg and sends it through the writer. When the producer is
// configured with withSeqNo, each message is stamped with an explicit,
// producer-assigned sequence number; otherwise no sequence number is attached.
func (p *Producer) Write(msg proto.Message) error {
	if !p.withSeqNo {
		return p.writeWithoutSeqNo(msg)
	}
	return p.writeWithSeqNo(msg)
}

// writeWithoutSeqNo marshals msg and hands it to the writer without attaching
// a sequence number.
//
// It returns a wrapped error in any of these cases:
//   - msg cannot be marshaled;
//   - the producer has not been started with Run;
//   - the context passed to Run is done;
//   - writer.Write returns an error.
func (p *Producer) writeWithoutSeqNo(msg proto.Message) error {
	data, err := proto.Marshal(msg)
	if err != nil {
		return fmt.Errorf("%s: write error: %w", producerWriteFn, err)
	}
	// p.ctx is only assigned by Run; without this guard, a Write before Run
	// would panic on the nil context.
	if p.ctx == nil {
		return fmt.Errorf("%s: producer is not running", producerWriteFn)
	}
	// Wrap the context's own error. The previous code wrapped the (nil)
	// marshal error, producing "%!w(<nil>)" and an error that could not be
	// matched with errors.Is(err, context.Canceled).
	if ctxErr := p.ctx.Err(); ctxErr != nil {
		return fmt.Errorf("%s: context error: %w", producerWriteFn, ctxErr)
	}
	lbMessage := persqueue.WriteMessage{Data: data}
	if err := p.writer.Write(&lbMessage); err != nil {
		return fmt.Errorf("%s: write error: %w", producerWriteFn, err)
	}
	return nil
}

// writeWithSeqNo marshals msg and hands it to the writer stamped with the next
// producer-assigned sequence number. p.mutex serializes callers, so sequence
// numbers are assigned and passed to the writer in increasing order.
//
// It returns a wrapped error in any of these cases:
//   - msg cannot be marshaled;
//   - the producer has not been started with Run;
//   - the context passed to Run is done;
//   - writer.Write returns an error.
//
// A sequence number is consumed only once the message actually reaches
// writer.Write.
func (p *Producer) writeWithSeqNo(msg proto.Message) error {
	// Marshaling touches no producer state, so keep it outside the critical section.
	data, err := proto.Marshal(msg)
	if err != nil {
		return fmt.Errorf("%s: write error: %w", producerWriteFn, err)
	}

	p.mutex.Lock()
	defer p.mutex.Unlock()

	// p.ctx is only assigned by Run; without this guard, a Write before Run
	// would panic on the nil context.
	if p.ctx == nil {
		return fmt.Errorf("%s: producer is not running", producerWriteFn)
	}
	// Check the context before taking a sequence number, and wrap the
	// context's own error (the previous code wrapped the nil marshal error).
	if ctxErr := p.ctx.Err(); ctxErr != nil {
		return fmt.Errorf("%s: context error: %w", producerWriteFn, ctxErr)
	}

	p.seqNo++
	lbMessage := persqueue.WriteMessage{Data: data}
	lbMessage.WithSeqNo(p.seqNo)
	if err := p.writer.Write(&lbMessage); err != nil {
		return fmt.Errorf("%s: write error: %w", producerWriteFn, err)
	}
	// Log only after the writer accepted the message, so "Sent" is accurate.
	p.logger.Debug("Sent message", log.UInt64("seq_no", p.seqNo))
	return nil
}

// Stat returns the statistics reported by the underlying writer's Stat method.
func (p *Producer) Stat() persqueue.WriterStat {
	w := p.writer
	return w.Stat()
}

// Closed reports whether the producer is closed. It is true before Run
// succeeds and after loop has observed writer.Closed() firing; otherwise it
// is false.
func (p *Producer) Closed() bool {
	isClosed := p.closed.Load()
	return isClosed
}

// Close closes the underlying writer and returns whatever error it reports.
// It does not change the Closed flag itself. That flag is updated by loop
// (when Run has started it) once writer.Closed() fires.
func (p *Producer) Close() error {
	if err := p.writer.Close(); err != nil {
		return err
	}
	return nil
}
