package queue

import (
	"context"
	"database/sql"
	"encoding/base64"
	"time"

	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/reflect/protoregistry"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/metrics"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/travel/budapest/metapms/internal/pgclient"
)

// defaultPrefetchLimit is the maximum number of events a Queue fetches in a
// single polling round unless SetPrefetchLimit overrides it.
const (
	defaultPrefetchLimit = 100
)

// Queue consumes events of a single topic and dispatches each event to the
// handlers subscribed for its protobuf message type. The zero value is not
// usable (its handlers map and metrics are nil); construct it with New.
//
// Fields:
//   - name: consumer name. When empty, the read position lives only in memory
//     and starts at the newest event of the topic. Otherwise it is stored in
//     the database.
//   - topic: the topic whose events are read.
//   - pg: source of the primary database connection used for polling.
//   - handlers: subscribed handlers keyed by protobuf full message name, run in
//     subscription order.
//   - persistedQueue: the current read position (LastReadID), loaded at the
//     start of every polling round.
//   - prefetchLimit: maximum number of events fetched per polling round.
//   - protoTypes: registry of the subscribed message types, used to decode
//     payloads.
//   - metrics: counters and gauges reported by the queue.
type Queue struct {
	name           string
	topic          string
	pg             *pgclient.PGClient
	logger         log.Logger
	handlers       map[protoreflect.FullName][]EventHandler
	persistedQueue *persistentQueue
	prefetchLimit  int
	protoTypes     protoregistry.Types
	metrics        *Metrics
}

// Metrics groups the instruments a Queue reports. New registers them under the
// "queues" prefix, tagged with queue_name and topic.
//
// Fields:
//   - UnprocessedEvents: number of events of the topic newer than the read
//     position, sampled at the start of every polling round.
//   - Checks: incremented once per polling round.
//   - EventsHandled: incremented after an event was handled and the read
//     position advanced.
//   - EventHandlingErrors: incremented when a handler fails or the read
//     position cannot be saved.
type Metrics struct {
	UnprocessedEvents   metrics.Gauge
	Checks              metrics.Counter
	EventsHandled       metrics.Counter
	EventHandlingErrors metrics.Counter
}

// EventHandler processes one decoded event. payload is a new instance of the
// message type the handler was subscribed with, unmarshalled from the event.
// tx is the nested transaction the event is handled in. Returning a non-nil
// error rolls that nested transaction back and aborts the current polling
// round; the event is retried on a later round.
type EventHandler func(tx *gorm.DB, payload proto.Message) error

// New creates a Queue reading events of topic through pg.
//
// name identifies the consumer. When it is empty, the read position is kept
// only in memory and starts at the newest event of the topic. Otherwise the
// position is stored in the database under that name. The queue's instruments
// are registered in metrics under the "queues" prefix, tagged with the queue
// name and topic. Handlers are added afterwards with Subscribe.
func New(topic string, name string, pg *pgclient.PGClient, logger log.Logger, metrics metrics.Registry) *Queue {
	registry := metrics.
		WithPrefix("queues").
		WithTags(map[string]string{"queue_name": name, "topic": topic})

	queueMetrics := &Metrics{
		UnprocessedEvents:   registry.Gauge("unprocessed_events"),
		Checks:              registry.Counter("checks"),
		EventsHandled:       registry.Counter("events_handled"),
		EventHandlingErrors: registry.Counter("event_handling_errors"),
	}

	return &Queue{
		name:          name,
		topic:         topic,
		pg:            pg,
		logger:        logger,
		handlers:      map[protoreflect.FullName][]EventHandler{},
		prefetchLimit: defaultPrefetchLimit,
		metrics:       queueMetrics,
	}
}

// Subscribe registers handler to be called for every event whose payload has
// the protobuf type of message. Several handlers may be subscribed to the same
// type; they run in subscription order. Subscribe returns q to allow chaining.
//
// It panics if the message type cannot be added to the queue's type registry.
// Subscribing again to a type that is already registered is fine.
func (q *Queue) Subscribe(message proto.Message, handler EventHandler) *Queue {
	messageType := message.ProtoReflect().Type()
	fullName := messageType.Descriptor().FullName()

	// protoregistry.Types.RegisterMessage rejects a name that is already
	// registered. Register only on first subscription; otherwise adding a
	// second handler for the same type would panic.
	if _, err := q.protoTypes.FindMessageByName(fullName); err != nil {
		if err := q.protoTypes.RegisterMessage(messageType); err != nil {
			panic(err)
		}
	}

	// append on a missing key starts a new slice, so no existence check is needed.
	q.handlers[fullName] = append(q.handlers[fullName], handler)
	return q
}

// SetPrefetchLimit changes how many events are fetched per polling round
// (defaultPrefetchLimit unless set) and returns q to allow chaining.
//
// NOTE(review): the value is passed to gorm's Limit unchanged. How gorm treats
// 0 or negative values depends on its version (no limit vs. LIMIT 0), so
// callers should pass a positive number.
func (q *Queue) SetPrefetchLimit(limit int) *Queue {
	q.prefetchLimit = limit
	return q
}

// loadPersistedState makes q.persistedQueue hold the queue's read position for
// the polling round running in tx.
//
// Unnamed queue: the position is initialised once, from the largest event id
// of the topic, and afterwards kept only in memory.
//
// Named queue: the stored row is re-read on every call with a row lock
// (SELECT ... FOR UPDATE). If no row exists yet, an unsaved state starting at
// the largest event id of the topic is prepared; checkTx/onEventHandled save it.
//
// NOTE(review): a lock on a missing row locks nothing. Two consumers starting
// at the same time with the same name may both try to create the row; confirm
// the table's constraints handle that.
func (q *Queue) loadPersistedState(tx *gorm.DB) error {
	if q.name == "" {
		if q.persistedQueue != nil {
			return nil
		}
		latestID, err := getMaxID(tx, q.topic)
		if err != nil {
			return xerrors.Errorf("unable to init queue's persisted state: %w", err)
		}
		q.persistedQueue = &persistentQueue{LastReadID: latestID}
		return nil
	}

	var stored persistentQueue
	lookup := persistentQueue{Name: q.name}
	res := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Limit(1).Find(&stored, lookup)
	if res.Error != nil {
		return xerrors.Errorf("unable to get queue's persisted state: %w", res.Error)
	}

	if res.RowsAffected == 0 {
		// First run of this consumer: start after the newest existing event.
		initialID, err := getMaxID(tx, q.topic)
		if err != nil {
			return xerrors.Errorf("unable to init queue with initial value: %w", err)
		}
		stored.LastReadID = initialID
		stored.Name = q.name
	}

	q.persistedQueue = &stored
	return nil
}

// getMaxID returns the largest event id of topic, or 0 when the topic has no
// events yet.
//
// NOTE(review): the topic filter is a struct condition, and gorm ignores
// zero-value struct fields. With an empty topic this therefore looks at all
// events. Confirm an empty topic is never used.
func getMaxID(tx *gorm.DB, topic string) (int64, error) {
	var latest sql.NullInt64
	err := tx.Model(&event{}).
		Where(event{Topic: topic}).
		Select("max(id)").
		Scan(&latest).
		Error
	if err != nil {
		return 0, xerrors.Errorf("unable to find latest event id for topic: %w", err)
	}
	if !latest.Valid {
		// max() over no rows is NULL.
		return 0, nil
	}
	return latest.Int64, nil
}

// Listen polls the queue every pollInterval until ctx is cancelled, handling
// any new events on each tick. It always returns nil.
//
// A failed poll does not stop listening. Lock errors are ignored silently and
// other errors are logged. An error seen after ctx was cancelled ends
// listening without being logged. pollInterval must be positive, because
// time.NewTicker panics otherwise.
func (q *Queue) Listen(ctx context.Context, pollInterval time.Duration) error {
	ticker := time.NewTicker(pollInterval)
	// Release the ticker when listening stops; it was previously never stopped.
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			err := q.check(ctx)
			if err == nil {
				continue
			}
			if ctx.Err() != nil {
				// The poll failed during shutdown; that is not worth an error log.
				return nil
			}
			if pgclient.IsLockError(err) {
				// presumably another consumer holds the lock for this round — TODO confirm
				continue
			}
			q.logger.Error("Error while trying to get new messages of topic "+q.topic, log.Error(err))
		case <-ctx.Done():
			return nil
		}
	}
}

// check runs one polling round (checkTx) in a transaction on the primary
// database, bound to ctx.
func (q *Queue) check(ctx context.Context) error {
	primary, err := q.pg.GetPrimary()
	if err != nil {
		return xerrors.Errorf("unable to get db connection: %w", err)
	}
	session := primary.WithContext(ctx)
	return session.Transaction(q.checkTx)
}

// checkTx performs one polling round inside transaction tx:
//  1. load the read position,
//  2. report the backlog size,
//  3. fetch up to prefetchLimit newer events of the topic in id order,
//  4. handle each one in a nested transaction, advancing the position after
//     each success.
//
// The first handler or position-update failure aborts the round and is
// returned. The enclosing transaction is then rolled back, including the
// effects of events handled earlier in the same round, and those events are
// retried on the next round.
func (q *Queue) checkTx(tx *gorm.DB) (err error) {
	q.metrics.Checks.Inc()
	if loadErr := q.loadPersistedState(tx); loadErr != nil {
		return xerrors.Errorf("unable to load queue state: %w", loadErr)
	}

	// Rewind the in-memory position when the round fails. The transaction is
	// rolled back, so handler effects of already-processed events are undone.
	// An unnamed queue keeps its position only in memory and would otherwise
	// skip those events for good. (A failure of the final COMMIT is not
	// covered here.)
	startID := q.persistedQueue.LastReadID
	defer func() {
		if err != nil {
			q.persistedQueue.LastReadID = startID
		}
	}()

	// Build a fresh query each time: reusing one *gorm.DB chain for several
	// finisher calls (Count, then Find) is not safe in gorm.
	pending := func() *gorm.DB {
		return tx.Model(&event{}).
			Where(event{Topic: q.topic}).
			Where("id > ?", startID)
	}

	var count int64
	if countErr := pending().Count(&count).Error; countErr != nil {
		// Best effort: the backlog gauge is informational only.
		q.logger.Error("Error while counting events in the queue", log.Error(countErr))
	} else {
		q.metrics.UnprocessedEvents.Set(float64(count))
	}

	var events []*event
	if findErr := pending().Order("id").Limit(q.prefetchLimit).Find(&events).Error; findErr != nil {
		return xerrors.Errorf("unable to poll for events: %w", findErr)
	}

	for _, e := range events {
		handleErr := tx.Transaction(func(t *gorm.DB) error {
			return q.handleEvent(t, e)
		})
		if handleErr != nil {
			q.metrics.EventHandlingErrors.Inc()
			q.logger.Error("Error while handling event", log.Int64("EventID", e.ID), log.Error(handleErr))
			return xerrors.Errorf("error while handling event: %w", handleErr)
		}
		if stateErr := q.onEventHandled(tx, e); stateErr != nil {
			q.metrics.EventHandlingErrors.Inc()
			return xerrors.Errorf("unable to complete event handling: %w", stateErr)
		}
		q.metrics.EventsHandled.Inc()
	}

	// A named queue seen for the first time gets its row stored even when
	// there was nothing to handle, which fixes its starting position.
	if q.persistedQueue.ID == 0 && q.name != "" {
		if saveErr := tx.Save(q.persistedQueue).Error; saveErr != nil {
			return xerrors.Errorf("unable to update queue's persisted state: %w", saveErr)
		}
	}
	return nil
}

// handleEvent decodes the payload of e and passes it to every handler
// subscribed to its type, in subscription order, stopping at the first error.
// An event whose type has no subscribers is skipped and treated as handled.
//
// The payload is expected to be the base64 (standard encoding) of the protobuf
// wire form of the message named by e.Type, which is what Push stores.
func (q *Queue) handleEvent(tx *gorm.DB, e *event) (handleErr error) {
	defer func() {
		if handleErr != nil {
			// Include the cause; previously only the event id and type were logged.
			q.logger.Error("Error while handling event",
				log.Int64("EventID", e.ID),
				log.String("Type", e.Type),
				log.Error(handleErr))
			return
		}
		q.logger.Debug("Done handling event",
			log.Int64("EventID", e.ID),
			log.String("Type", e.Type))
	}()

	protoName := protoreflect.FullName(e.Type)
	handlers, exist := q.handlers[protoName]
	if !exist {
		q.logger.Debug("No handlers for event registered",
			log.Int64("EventID", e.ID),
			log.String("Type", e.Type))
		return nil
	}

	messageType, err := q.protoTypes.FindMessageByName(protoName)
	if err != nil {
		// Handlers exist, so this is a missing message type, not a missing handler.
		return xerrors.Errorf("unable to find message type for event %s: %w", e.Type, err)
	}
	decoded, err := base64.StdEncoding.DecodeString(e.Payload)
	if err != nil {
		return xerrors.Errorf("unable to decode payload: %w", err)
	}
	m := messageType.New().Interface()
	if err = proto.Unmarshal(decoded, m); err != nil {
		return xerrors.Errorf("unable to unmarshal payload: %w", err)
	}

	q.logger.Debug("Starting handling event",
		log.Int64("EventID", e.ID),
		log.String("Type", e.Type),
		log.Any("Payload", m),
	)
	for _, handler := range handlers {
		if err = handler(tx, m); err != nil {
			return err
		}
	}
	return nil
}

// onEventHandled advances the read position to the id of the handled event.
// For a named queue the new position is also saved through tx. An unnamed
// queue only updates its in-memory state.
func (q *Queue) onEventHandled(tx *gorm.DB, handled *event) error {
	q.persistedQueue.LastReadID = handled.ID
	if q.name == "" {
		return nil
	}
	if err := tx.Save(q.persistedQueue).Error; err != nil {
		return xerrors.Errorf("unable to update queue's persisted state: %w", err)
	}
	return nil
}

// Push enqueues message as a new event of topic using tx. When tx is a
// transaction, the event is committed or rolled back together with it.
//
// The event stores the message's protobuf full name as its type and the
// base64 (standard encoding) of its wire form as its payload. A nil message,
// including a typed nil pointer, is rejected.
func Push(tx *gorm.DB, topic string, message proto.Message) error {
	// A typed nil pointer is a non-nil interface with an invalid reflection
	// view. proto.Marshal accepts it, so without this check it would be
	// enqueued with an empty payload.
	if message == nil || !message.ProtoReflect().IsValid() {
		return xerrors.Errorf("unexpected nil payload")
	}
	src, err := proto.Marshal(message)
	if err != nil {
		return xerrors.Errorf("unable to marshal event payload: %w", err)
	}

	e := event{
		Topic:   topic,
		Type:    string(message.ProtoReflect().Type().Descriptor().FullName()),
		Payload: base64.StdEncoding.EncodeToString(src),
	}
	if err := tx.Create(&e).Error; err != nil {
		return xerrors.Errorf("unable to enqueue notification of type '%s' to topic '%s': %w", e.Type, topic, err)
	}
	return nil
}
