package logbroker

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/golang/protobuf/proto"
	"go.uber.org/atomic"

	"a.yandex-team.ru/kikimr/public/sdk/go/persqueue"
	"a.yandex-team.ru/kikimr/public/sdk/go/persqueue/log/corelogadapter"
	"a.yandex-team.ru/kikimr/public/sdk/go/persqueue/recipe"
	"a.yandex-team.ru/kikimr/public/sdk/go/ydb"
	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/metrics/mock"
	"a.yandex-team.ru/travel/library/go/metrics"
)

type Channel struct {
	logger      log.Logger
	rawMessages chan []byte
	consumer    *Consumer
}

// Consumer reads a Logbroker topic from one or more clusters and fans every
// received message out to the Channels created with NewChannel.
//
// Fields:
//   - readerOptionsList holds one persqueue.ReaderOptions per cluster; the
//     reader loop records each message's CreateTime as that entry's
//     ReadTimestamp, and a restarted reader re-reads the entry.
//   - ctx is the context given to Run; it stays nil until Run is called.
//   - channels is guarded by channelsMutex.
//   - activeReadersCnt counts currently started readers and backs Err.
//   - appMetricsPrefix/appMetrics receive reader metrics (a mock registry
//     unless RegisterMetrics is called).
//   - lockGroup, when non-nil, is used by Run to choose a consumer-name suffix.
type Consumer struct {
	logger            log.Logger
	readerOptionsList []persqueue.ReaderOptions
	ctx               context.Context
	channels          []*Channel
	channelsMutex     sync.RWMutex
	activeReadersCnt  atomic.Int64
	appMetricsPrefix  string
	appMetrics        *metrics.AppMetrics
	lockGroup         *YtLockGroup
}

// CredentialsProvider supplies the YDB credentials used by every reader a
// Consumer creates. It is queried once, while the consumer is being built.
type CredentialsProvider interface {
	// Credentials returns the credentials to authenticate with, or an error
	// if they cannot be obtained.
	Credentials() (creds ydb.Credentials, err error)
}

// NewConsumer builds a Consumer that reads topic from every cluster in
// clusters under the fixed consumer name lbConsumer. It is
// NewConsumerWithLbConsumerGenerating without a lock group, so no numeric
// suffix is ever appended to the consumer name.
func NewConsumer(
	clusters []string,
	endpoint string,
	topic string,
	lbConsumer string,
	readTimestamp time.Time,
	credentialsProvider CredentialsProvider,
	logger log.Logger,
) (*Consumer, error) {
	var noLockGroup *YtLockGroup
	return NewConsumerWithLbConsumerGenerating(
		clusters,
		endpoint,
		topic,
		lbConsumer,
		readTimestamp,
		credentialsProvider,
		logger,
		noLockGroup,
	)
}

// Identification and sizing.
const (
	// libraryPath prefixes function names in log and error messages.
	libraryPath = "library.go.logbroker"
	// readerChannelCapacity is the buffer size of every Channel and the total
	// MaxReadMessagesCount that is split across the per-cluster readers.
	readerChannelCapacity = 10000
)

// Back-off delays.
const (
	// lockAcquireWait is the pause between attempts to re-acquire a lost lock.
	lockAcquireWait = 5 * time.Second
	// readerOverloadWait throttles a reader after a channel was found full.
	readerOverloadWait = 15 * time.Second
	// readerReconnectWait is the pause before restarting a failed reader.
	readerReconnectWait = 15 * time.Second
)

// NewConsumerWithLbConsumerGenerating builds a Consumer that reads topic from
// every cluster in clusters, connecting to "<cluster>.<endpoint>" for each one.
//
// Readers start at readTimestamp, receive compressed payloads as-is
// (DecompressionDisabled), never commit, and read only the local cluster's
// partitions. readerChannelCapacity is split evenly between the per-cluster
// readers as their MaxReadMessagesCount.
//
// When lockGroup is non-nil, Run acquires a lock from it and uses the locked
// number to derive the consumer name; with a nil lockGroup lbConsumer is used
// unchanged.
//
// It returns an error if clusters is empty, credentialsProvider is nil, or the
// credentials cannot be obtained.
func NewConsumerWithLbConsumerGenerating(
	clusters []string,
	endpoint string,
	topic string,
	lbConsumer string,
	readTimestamp time.Time,
	credentialsProvider CredentialsProvider,
	logger log.Logger,
	lockGroup *YtLockGroup,
) (*Consumer, error) {
	const funcName = libraryPath + ".NewConsumerWithLbConsumerGenerating"

	// Validate the cheap arguments before asking for credentials, which may
	// involve I/O in the provider.
	if len(clusters) == 0 {
		return nil, fmt.Errorf("%s: no clusters provided", funcName)
	}
	if credentialsProvider == nil {
		return nil, fmt.Errorf("%s: no credentials provider", funcName)
	}

	credentials, err := credentialsProvider.Credentials()
	if err != nil {
		return nil, fmt.Errorf("%s: can not get credentials: %w", funcName, err)
	}

	topicInfo := persqueue.TopicInfo{Topic: topic}
	// Every reader gets an equal share of the per-channel buffer size.
	maxReadMessages := readerChannelCapacity / uint32(len(clusters))

	readerOptionsList := make([]persqueue.ReaderOptions, 0, len(clusters))
	for _, cluster := range clusters {
		readerOptionsList = append(readerOptionsList, persqueue.ReaderOptions{
			Endpoint:              fmt.Sprintf("%s.%s", cluster, endpoint),
			Logger:                corelogadapter.New(logger),
			Consumer:              lbConsumer,
			Topics:                []persqueue.TopicInfo{topicInfo},
			Credentials:           credentials,
			ReadTimestamp:         readTimestamp,
			DecompressionDisabled: true,
			RetryOnFailure:        false,
			CommitsDisabled:       true,
			ReadOnlyLocal:         true,
			MaxReadMessagesCount:  maxReadMessages,
		})
	}

	consumer := Consumer{
		logger:            logger,
		readerOptionsList: readerOptionsList,
		lockGroup:         lockGroup,
		// Metrics go to a mock registry until RegisterMetrics is called.
		appMetrics: metrics.NewAppMetrics(mock.NewRegistry(mock.NewRegistryOpts())),
	}
	return &consumer, nil
}

// NewConsumerWithRecipe builds a Consumer against a local Logbroker recipe
// environment (for tests). Every per-cluster reader is pointed at the
// recipe's endpoint and port, reads the topic named clusters[i]+topic, and
// uses the recipe's consumer name and decompression setting.
func NewConsumerWithRecipe(env *recipe.Env, clusters []string, topic string, logger log.Logger) (*Consumer, error) {
	opts := env.ConsumerOptions()
	consumer, err := NewConsumer(
		clusters,
		opts.Endpoint,
		topic,
		opts.Consumer,
		time.Time{},
		NewTestCredentialsProvider(env),
		logger,
	)
	if err != nil {
		return nil, err
	}

	for idx := range consumer.readerOptionsList {
		ro := &consumer.readerOptionsList[idx]
		ro.Endpoint = opts.Endpoint
		ro.Port = opts.Port
		ro.Topics = []persqueue.TopicInfo{{Topic: clusters[idx] + topic}}
		ro.DecompressionDisabled = opts.DecompressionDisabled
		ro.Consumer = opts.Consumer
	}
	return consumer, nil
}

// RegisterMetrics makes the consumer report its reader metrics to appMetrics
// under prefix, replacing the mock registry set up by the constructor. It
// should be called before Run.
func (c *Consumer) RegisterMetrics(appMetrics *metrics.AppMetrics, prefix string) {
	c.appMetrics, c.appMetricsPrefix = appMetrics, prefix
}

// newReader creates a persqueue reader for readerOptions and starts its
// session with the consumer's run context (c.ctx). It returns the started
// reader, or the error reported by Start.
func (c *Consumer) newReader(readerOptions persqueue.ReaderOptions) (persqueue.Reader, error) {
	r := persqueue.NewReader(readerOptions)
	c.logger.Info(
		"consumer starting",
		log.String("Consumer", readerOptions.Consumer),
		log.String("Endpoint", readerOptions.Endpoint),
		log.String("Topic", readerOptions.Topics[0].Topic),
	)

	readerInit, startErr := r.Start(c.ctx)
	if startErr != nil {
		return nil, startErr
	}

	c.logger.Info(
		"consumer started",
		log.String("SessionID", readerInit.SessionID),
		log.String("Consumer", readerOptions.Consumer),
	)
	return r, nil
}

// readerLoop keeps a reader running for c.readerOptionsList[optionsIndex]
// until ctx is cancelled.
//
// A non-zero consumerNumber is appended to the configured consumer name as
// "<consumer>-<consumerNumber>". When the reader fails to start or is closed,
// the loop waits readerReconnectWait and starts a new one. The options are
// re-read on every attempt, so the ReadTimestamp recorded by the previous
// session is used. The wait is interrupted by cancellation, so a cancelled
// loop never starts another reader.
func (c *Consumer) readerLoop(ctx context.Context, optionsIndex int, consumerNumber uint) {
	const funcName = libraryPath + ".readerLoop"

	for c.runReaderSession(ctx, optionsIndex, consumerNumber) {
		if !sleepWithContext(ctx, readerReconnectWait) {
			c.logger.Errorf("%s: context closed: %s", funcName, ctx.Err().Error())
			return
		}
	}
}

// runReaderSession runs a single reader until it is closed or ctx is
// cancelled. It returns true when the caller should reconnect (the reader
// failed to start or was closed). It returns false when ctx was cancelled; in
// that case a started reader has already been shut down.
func (c *Consumer) runReaderSession(ctx context.Context, optionsIndex int, consumerNumber uint) bool {
	const funcName = libraryPath + ".readerLoop"

	if ctx.Err() != nil {
		return false
	}

	readerOptions := c.readerOptionsList[optionsIndex]
	if consumerNumber != 0 {
		readerOptions.Consumer = fmt.Sprintf("%s-%d", readerOptions.Consumer, consumerNumber)
	}
	reader, err := c.newReader(readerOptions)
	if err != nil {
		c.logger.Errorf("reader run error: %s. Topic='%s'. Going to restart", err.Error(), readerOptions.Topics[0].Topic)
		return true
	}

	metricsTags := map[string]string{
		"endpoint": readerOptions.Endpoint,
		"topic":    readerOptions.Topics[0].Topic,
	}
	setActive := func(active bool) {
		var cnt float64
		if active {
			cnt = 1
			c.activeReadersCnt.Inc()
		} else {
			c.activeReadersCnt.Dec()
		}
		c.appMetrics.GetOrCreateGauge(c.appMetricsPrefix, metricsTags, "active_readers").Set(cnt)
	}
	setActive(true)
	defer setActive(false)

	for {
		select {
		case <-ctx.Done():
			reader.Shutdown()
			return false
		case <-reader.Closed():
			// %v rather than Err().Error(): Err may be nil and must not panic here.
			c.logger.Errorf("%s: reader closed: %v", funcName, reader.Err())
			return true
		case lbMessage := <-reader.C():
			data, isData := lbMessage.(*persqueue.Data)
			if !isData {
				continue
			}
			if !c.forwardData(ctx, optionsIndex, data, readerOptions.Topics, metricsTags) {
				reader.Shutdown()
				return false
			}
		}
	}
}

// forwardData delivers every message of a Data event to the registered
// channels. It records each message's CreateTime as the ReadTimestamp used on
// the next reconnect. If a channel was full, it throttles for
// readerOverloadWait without holding channelsMutex. It returns false if ctx
// was cancelled before all messages were handled.
func (c *Consumer) forwardData(
	ctx context.Context,
	optionsIndex int,
	data *persqueue.Data,
	topics []persqueue.TopicInfo,
	metricsTags map[string]string,
) bool {
	for _, batch := range data.Batches() {
		for _, m := range batch.Messages {
			if ctx.Err() != nil {
				return false
			}
			c.readerOptionsList[optionsIndex].ReadTimestamp = m.CreateTime
			if c.dispatchToChannels(m.Data, topics, metricsTags) && !sleepWithContext(ctx, readerOverloadWait) {
				return false
			}
		}
	}
	return true
}

// dispatchToChannels offers a private copy of payload to every registered
// channel without blocking; a channel whose buffer is full skips the message.
// It reports whether any channel skipped it. The read lock is held only for
// the non-blocking sends, which also prevents CloseChannel from closing a
// channel mid-send.
func (c *Consumer) dispatchToChannels(payload []byte, topics []persqueue.TopicInfo, metricsTags map[string]string) (skipped bool) {
	const funcName = libraryPath + ".readerLoop"

	c.channelsMutex.RLock()
	defer c.channelsMutex.RUnlock()

	for _, ch := range c.channels {
		msg := make([]byte, len(payload))
		copy(msg, payload)
		select {
		case ch.rawMessages <- msg:
			c.appMetrics.GetOrCreateCounter(
				c.appMetricsPrefix, metricsTags, "received_messages").Add(1)
		default:
			c.logger.Errorf("%s: channel is full, message skipped for %v", funcName, topics)
			c.appMetrics.GetOrCreateCounter(
				c.appMetricsPrefix, metricsTags, "skipped_messages").Add(1)
			skipped = true
		}
	}
	return skipped
}

// sleepWithContext blocks for d or until ctx is cancelled, whichever comes
// first. It reports true only if the full duration elapsed.
func sleepWithContext(ctx context.Context, d time.Duration) bool {
	if ctx.Err() != nil {
		return false
	}
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}

// Run starts the consumer's readers in background goroutines and returns
// immediately; the readers stop when ctx is cancelled. Run may be called only
// once per Consumer ("secondary run" error otherwise).
//
// Without a lock group, readers use consumer number 0 (the configured name).
// With a lock group, Run first acquires a lock and spawns readers with its
// LockedNumber. When the lock is lost, the current readers are cancelled, the
// lock is re-acquired (retrying every lockAcquireWait until ctx is cancelled)
// and readers are respawned with the new number.
func (c *Consumer) Run(ctx context.Context) error {
	const funcName = libraryPath + ".Run"

	if c.ctx != nil {
		return fmt.Errorf("%s: secondary run", funcName)
	}
	c.ctx = ctx

	if c.lockGroup == nil {
		// Readers stop with ctx, so the cancel func is not needed here.
		_ = c.spawnReaders(0)
		return nil
	}

	var (
		// mu serializes the initial spawn with the lost-lock callback, which
		// runs asynchronously, and guards shutdownReaders (nil until spawned).
		mu              sync.Mutex
		shutdownReaders context.CancelFunc
	)

	WithOnLostCallbacks(func() {
		c.logger.Infof("%s: lock is lost. updating", funcName)

		mu.Lock()
		defer mu.Unlock()

		if shutdownReaders != nil {
			shutdownReaders()
			shutdownReaders = nil
		}

		for {
			err := c.lockGroup.Acquire(c.ctx)
			if err == nil {
				break
			}
			c.logger.Errorf("%s: can not update lock: %v", funcName, err)
			// Give up once the consumer is stopped instead of retrying forever.
			select {
			case <-c.ctx.Done():
				c.logger.Errorf("%s: context closed, lock is not re-acquired: %v", funcName, c.ctx.Err())
				return
			case <-time.After(lockAcquireWait):
			}
		}

		shutdownReaders = c.spawnReaders(c.lockGroup.LockedNumber())
	})(c.lockGroup)

	if err := c.lockGroup.Acquire(c.ctx); err != nil {
		return fmt.Errorf("%s: can not acquire lock: %w", funcName, err)
	}

	mu.Lock()
	defer mu.Unlock()
	// The lost-lock callback may already have re-acquired the lock and spawned readers.
	if shutdownReaders == nil {
		shutdownReaders = c.spawnReaders(c.lockGroup.LockedNumber())
	}
	return nil
}

// spawnReaders starts one readerLoop goroutine per entry of
// c.readerOptionsList. All of them share a new child context of c.ctx, and
// spawnReaders returns the function that cancels it.
func (c *Consumer) spawnReaders(consumerNumber uint) context.CancelFunc {
	readersCtx, cancel := context.WithCancel(c.ctx)
	for idx := 0; idx < len(c.readerOptionsList); idx++ {
		go c.readerLoop(readersCtx, idx, consumerNumber)
	}
	return cancel
}

// Err reports the consumer's health: it returns an error when no reader is
// currently started, and nil otherwise.
func (c *Consumer) Err() error {
	if active := c.activeReadersCnt.Load(); active >= 1 {
		return nil
	}
	return fmt.Errorf("Consumer.Err: no active readers")
}

// NewChannel registers and returns a new subscriber Channel with a buffer of
// readerChannelCapacity messages. Release it with CloseChannel.
func (c *Consumer) NewChannel() *Channel {
	subscriber := &Channel{
		rawMessages: make(chan []byte, readerChannelCapacity),
		consumer:    c,
		logger:      c.logger,
	}

	c.channelsMutex.Lock()
	c.channels = append(c.channels, subscriber)
	c.channelsMutex.Unlock()

	return subscriber
}

// CloseChannel unregisters channel and closes its message buffer. The order
// of the remaining channels is not preserved. It returns an error if channel
// was not created by this consumer or is already closed.
func (c *Consumer) CloseChannel(channel *Channel) error {
	c.channelsMutex.Lock()
	defer c.channelsMutex.Unlock()

	idx := -1
	for i, ch := range c.channels {
		if ch == channel {
			idx = i
			break
		}
	}
	if idx == -1 {
		return fmt.Errorf("unknown channel")
	}

	// Swap-remove; clear the vacated tail slot so the backing array does not
	// keep the closed Channel (and its buffered messages) reachable.
	lastIdx := len(c.channels) - 1
	c.channels[idx] = c.channels[lastIdx]
	c.channels[lastIdx] = nil
	c.channels = c.channels[:lastIdx]

	// Closing under the write lock guarantees no reader is sending to it.
	channel.close()
	return nil
}

// close closes the channel's message buffer. It must be called at most once,
// and only while no sender can still write to it.
func (ch *Channel) close() {
	buffer := ch.rawMessages
	close(buffer)
}

// TimeoutError is returned by Channel.ReadWithDeadline when the deadline
// passes before a message could be read.
type TimeoutError struct{}

// Error implements the error interface.
func (e TimeoutError) Error() string {
	return "execution deadline exceeded"
}

// ReadWithDeadline decodes the next message from the channel into msg.
//
// A zero deadline waits without a time limit. Messages that fail to unmarshal
// are logged and skipped. It returns:
//   - TimeoutError if the deadline has passed or passes while waiting;
//   - an error wrapping the consumer context's error if the consumer stops;
//   - an error if the channel has been closed with CloseChannel;
//   - an error if the consumer has not been started with Run.
func (ch *Channel) ReadWithDeadline(msg proto.Message, deadline time.Time) error {
	const funcName = libraryPath + ".ReadWithDeadline"

	if ch.consumer.ctx == nil {
		return fmt.Errorf("Channel.ReadWithDeadline fails: consumer not started")
	}

	var timeoutChan <-chan time.Time // nil (blocks forever) when there is no deadline
	if !deadline.IsZero() {
		now := time.Now()
		if now.After(deadline) {
			return TimeoutError{}
		}
		timer := time.NewTimer(deadline.Sub(now))
		defer timer.Stop()
		timeoutChan = timer.C
	}

	for {
		select {
		case <-ch.consumer.ctx.Done():
			return fmt.Errorf("consumer stopped: %w", ch.consumer.ctx.Err())
		case <-timeoutChan:
			return TimeoutError{}
		case data, ok := <-ch.rawMessages:
			if !ok {
				// Without this check a closed channel yields nil data, which
				// unmarshals into an empty message and is reported as success.
				return fmt.Errorf("%s: channel closed", funcName)
			}
			if err := proto.Unmarshal(data, msg); err != nil {
				ch.logger.Errorf("%s: can not unmarshal message: %s", funcName, err.Error())
				continue
			}
			return nil
		}
	}
}

// Read decodes the next message from the channel into msg, waiting without a
// time limit. It is ReadWithDeadline with a zero deadline.
func (ch *Channel) Read(msg proto.Message) error {
	var noDeadline time.Time
	return ch.ReadWithDeadline(msg, noDeadline)
}
