package logbroker

import (
	"bytes"
	"context"
	"sync"
	"time"

	"github.com/cenkalti/backoff/v4"

	"a.yandex-team.ru/kikimr/public/sdk/go/persqueue"
	"a.yandex-team.ru/kikimr/public/sdk/go/persqueue/log/corelogadapter"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/travel/avia/library/go/logbroker"
	"a.yandex-team.ru/travel/library/go/metrics"
)

// SingleEndpointConsumer reads a single Logbroker topic from one endpoint and
// hands every non-empty newline-delimited record of each received message to
// a MessageProcessor.
type SingleEndpointConsumer struct {
	// baseReader is the underlying reader; Read drives it with retries.
	baseReader *logbroker.Reader
	// logger receives per-message debug records and read-failure errors.
	logger log.Logger
	// name labels this consumer: it is appended to the reader's logger name
	// and used as the "name" tag of the received-messages metric.
	name string
	// processor is invoked once per non-empty record.
	processor MessageProcessor
}

// TopicConfig describes which topic to read and how much to fetch per read.
// NOTE(review): NewSingleEndpointConsumer reads these fields through
// ConsumerConfig.Topic, which is presumably of this type — confirm.
type TopicConfig struct {
	// Path is the topic path passed as the single persqueue.TopicInfo.
	Path string `yaml:"path"`
	// Consumer is the Logbroker consumer name used for reading.
	Consumer string `yaml:"consumer"`
	// MaxReadSize is forwarded as persqueue.ReaderOptions.MaxReadSize.
	MaxReadSize uint32 `yaml:"max_read_size"`
	// MaxReadMessagesCount is forwarded as persqueue.ReaderOptions.MaxReadMessagesCount.
	MaxReadMessagesCount uint32 `yaml:"max_read_messages_count"`
}

// NewSingleEndpointConsumer creates a consumer that reads config.Topic.Path
// from endpoint, authenticating with token. The reader logs under
// logger.WithName(name) and delivers each batch to the consumer's
// onBatchReceived callback. Any error from logbroker.NewReader is returned
// unchanged together with a nil consumer.
func NewSingleEndpointConsumer(
	ctx context.Context,
	config ConsumerConfig,
	token string,
	endpoint string,
	logger log.Logger,
	name string,
	processor MessageProcessor,
) (*SingleEndpointConsumer, error) {
	topic := config.Topic
	consumer := &SingleEndpointConsumer{
		logger:    logger,
		name:      name,
		processor: processor,
	}

	reader, err := logbroker.NewReader(
		ctx,
		persqueue.ReaderOptions{
			Endpoint:             endpoint,
			Credentials:          ydb.AuthTokenCredentials{AuthToken: token},
			Consumer:             topic.Consumer,
			Logger:               corelogadapter.New(logger.WithName(name)),
			Topics:               []persqueue.TopicInfo{{Topic: topic.Path}},
			MaxReadSize:          topic.MaxReadSize,
			MaxReadMessagesCount: topic.MaxReadMessagesCount,
			// Payloads are decompressed by the SDK before reaching us.
			DecompressionDisabled: false,
			RetryOnFailure:        true,
		},
		logbroker.WithOnBatchReceived(consumer.onBatchReceived),
	)
	if err != nil {
		return nil, err
	}

	consumer.baseReader = reader
	return consumer, nil
}

// onBatchReceived handles one batch delivered by the underlying reader.
//
// It works in two phases:
//  1. Each message is logged via onMessage. Its Data is split on '\n', and
//     empty fragments (blank lines, a trailing newline) are discarded.
//  2. For every remaining record it bumps the "total" metric and runs
//     processor.OnMessage in a separate goroutine.
//
// The call returns only after every goroutine has finished. Records are
// processed concurrently, so no ordering between them is guaranteed.
func (c *SingleEndpointConsumer) onBatchReceived(batch persqueue.MessageBatch) {
	separator := []byte("\n")

	var records [][]byte
	for _, msg := range batch.Messages {
		c.onMessage(msg)
		for _, record := range bytes.Split(msg.Data, separator) {
			if len(record) == 0 {
				continue
			}
			records = append(records, record)
		}
	}

	var inFlight sync.WaitGroup
	inFlight.Add(len(records))
	for i := range records {
		c.countMessageMetric("total")
		record := records[i] // fresh variable per iteration for the goroutine
		go func() {
			defer inFlight.Done()
			c.processor.OnMessage(record)
		}()
	}
	inFlight.Wait()
}

// onMessage writes a debug record with the metadata of one raw message:
// sequence number, source id, creation and write times, and codec. The
// payload (Data) is not logged.
func (c *SingleEndpointConsumer) onMessage(message persqueue.ReadMessage) {
	sourceID := string(message.SourceID)
	c.logger.Debug("Received message",
		log.Any("seqNo", message.SeqNo),
		log.Any("source-id", sourceID),
		log.Any("created", message.CreateTime),
		log.Any("written", message.WriteTime),
		log.Any("codec", message.Codec),
	)
}

// countMessageMetric increments the received-messages counter in the global
// app metrics registry. The counter is tagged with metricType ("type") and
// this consumer's name ("name"), and is created on first use.
func (c *SingleEndpointConsumer) countMessageMetric(metricType string) {
	tags := map[string]string{
		"name": c.name,
		"type": metricType,
	}
	counter := metrics.GlobalAppMetrics().GetOrCreateCounter(metricsPrefix, tags, receievedMetricName)
	counter.Inc()
}

// Read runs the underlying reader's Read and retries it on failure.
//
// The policy is a constant 10s pause, with at most 3 retries, so up to four
// attempts in total. Each failure that is about to be retried is logged at
// error level. Read returns nil on the first success; otherwise it returns the
// error from the final attempt.
func (c *SingleEndpointConsumer) Read() error {
	const (
		retryInterval = 10 * time.Second
		maxRetries    = 3
	)

	policy := backoff.WithMaxRetries(backoff.NewConstantBackOff(retryInterval), maxRetries)
	logFailure := func(err error, _ time.Duration) {
		c.logger.Error("failed to read", log.Error(err))
	}

	return backoff.RetryNotify(c.baseReader.Read, policy, logFailure)
}
