package database

import (
	"fmt"
	"sync"

	"gorm.io/gorm"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/travel/library/go/metrics"
	"a.yandex-team.ru/travel/notifier/internal/models"
)

// notificationTransaction couples a gorm database handle with the single
// notification that is staged through Update and written out by Commit.
type notificationTransaction struct {
	// locker is taken by Update, Commit and Rollback around their changes to
	// the fields below.
	locker sync.Mutex
	// tx is the gorm handle on which Commit saves and commits, and on which
	// Rollback rolls back. Presumably it is an already started transaction;
	// it is not created in this part of the file.
	tx *gorm.DB
	// notification is the row staged in the transaction. Update overwrites
	// the pointed-to value and Commit passes the pointer to Save.
	notification *models.Notification
	// isClosed is set to true by the first Commit or Rollback call, before
	// the database operation runs. Nothing in the code shown here resets it.
	isClosed bool
}

// Update replaces the notification staged in the transaction with the given
// value. Nothing is written to the database here; the staged value is saved
// by Commit.
//
// It returns ErrTxClosed when the transaction has already been closed and
// errUpdateDifferentRow when the given notification has a different ID than
// the staged one. In both cases the staged value is left untouched and the
// refusal is logged at info level.
func (nt *notificationTransaction) Update(notification models.Notification, logger log.Logger) error {
	nt.locker.Lock()
	defer nt.locker.Unlock()

	// The cases are evaluated in order, so the closed check always runs before
	// the staged notification is inspected.
	switch {
	case nt.isClosed:
		logger.Info("notification won't be updated due to closed transaction", log.UInt64("notificationID", notification.ID))
		return ErrTxClosed
	case notification.ID != nt.notification.ID:
		logger.Info(
			"notification won't be updated due to changed id",
			log.UInt64("previousID", nt.notification.ID),
			log.UInt64("notificationID", notification.ID),
		)
		return errUpdateDifferentRow
	}

	*nt.notification = notification
	return nil
}

var (
	// ErrTxClosed is returned by Update, Commit and Rollback when the
	// transaction has already been closed.
	ErrTxClosed = fmt.Errorf("the transaction has already been closed")
	// errUpdateDifferentRow is returned by Update when the given notification
	// has a different ID than the one staged in the transaction.
	errUpdateDifferentRow = fmt.Errorf("the transaction is available only for the initial notification")
)

// Commit saves the staged notification and commits the underlying
// transaction.
//
// The transaction is marked closed before any database work starts, so every
// later Update, Commit or Rollback call returns ErrTxClosed no matter how this
// call ends. If saving the notification fails, the transaction is rolled back
// here, because no other call can do it once the wrapper is closed. The save
// error is returned and a failure of that rollback is only logged.
//
// The in-progress transactions gauge is decremented whenever this call closes
// the transaction, on success and on failure alike. The outcome is logged
// exactly once.
//
// It returns ErrTxClosed if the transaction has already been closed, the
// error reported by Save or Commit otherwise, or nil on success.
func (nt *notificationTransaction) Commit(logger log.Logger) (err error) {
	// Take the lock before looking at isClosed: the flag is written under the
	// lock, so an unlocked read would be a data race.
	nt.locker.Lock()
	defer nt.locker.Unlock()

	if nt.isClosed {
		return ErrTxClosed
	}
	nt.isClosed = true

	defer func() {
		// The wrapper is closed for good at this point, so the transaction is
		// no longer in progress regardless of the outcome.
		metrics.GlobalAppMetrics().GetOrCreateGauge("transactions", nil, "in_progress").Add(-1)
		if err != nil {
			logger.Error("failed to commit transaction", log.UInt64("notificationID", nt.notification.ID), log.Error(err))
			return
		}
		logger.Info("transaction has been commited", log.UInt64("notificationID", nt.notification.ID))
	}()

	result := nt.tx.Save(nt.notification)
	if result.Error != nil {
		// Release the database transaction; callers cannot do it themselves
		// because Rollback now returns ErrTxClosed.
		if rollbackErr := nt.tx.Rollback().Error; rollbackErr != nil {
			logger.Error(
				"failed to rollback transaction",
				log.UInt64("notificationID", nt.notification.ID),
				log.Error(rollbackErr),
			)
		}
		return result.Error
	}
	return result.Commit().Error
}
// Rollback rolls back the underlying transaction and discards the staged
// notification changes.
//
// The transaction is marked closed before the rollback is attempted, so every
// later Update, Commit or Rollback call returns ErrTxClosed no matter how this
// call ends.
//
// The in-progress transactions gauge is decremented whenever this call closes
// the transaction, on success and on failure alike. The outcome is logged
// exactly once.
//
// It returns ErrTxClosed if the transaction has already been closed, the
// error reported by the rollback otherwise, or nil on success.
func (nt *notificationTransaction) Rollback(logger log.Logger) (err error) {
	// Take the lock before looking at isClosed: the flag is written under the
	// lock, so an unlocked read would be a data race.
	nt.locker.Lock()
	defer nt.locker.Unlock()

	if nt.isClosed {
		return ErrTxClosed
	}
	nt.isClosed = true

	defer func() {
		// The wrapper is closed for good at this point, so the transaction is
		// no longer in progress regardless of the outcome.
		metrics.GlobalAppMetrics().GetOrCreateGauge("transactions", nil, "in_progress").Add(-1)
		if err != nil {
			logger.Error("failed to rollback transaction", log.UInt64("notificationID", nt.notification.ID), log.Error(err))
			return
		}
		logger.Info("transaction has been rolled back", log.UInt64("notificationID", nt.notification.ID))
	}()

	return nt.tx.Rollback().Error
}
