package pgclient

import (
	"context"
	"database/sql"
	"log"
	"os"
	"time"

	"github.com/jackc/pgconn"
	"golang.yandex/hasql"
	"golang.yandex/hasql/checkers"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"

	appLog "a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/xerrors"
)

// UnableToLockCode is the PostgreSQL SQLSTATE code ("lock_not_available")
// that IsLockError treats as a failure to obtain a lock.
const UnableToLockCode = "55P03"

// AdvisoryLockErr is the sentinel error returned by WithLock when
// pg_try_advisory_lock reports that the advisory lock was not acquired.
var AdvisoryLockErr = xerrors.NewSentinel("unable to acquire advisory lock")

// gormWrapper opens a gorm session on top of the *sql.DB already owned by
// the given hasql node. Statements are logged to stdout at logLevel, and
// queries slower than two seconds are reported as slow.
func gormWrapper(node hasql.Node, logLevel logger.LogLevel) (*gorm.DB, error) {
	// Reuse the node's existing connection pool instead of dialing a new one.
	dialector := postgres.New(postgres.Config{Conn: node.DB()})

	stdoutWriter := log.New(os.Stdout, "\r\n", log.LstdFlags)
	loggerCfg := logger.Config{
		SlowThreshold: 2 * time.Second,
		LogLevel:      logLevel,
	}
	gormCfg := &gorm.Config{Logger: logger.New(stdoutWriter, loggerCfg)}

	return gorm.Open(dialector, gormCfg)
}

// nodeChecker reports whether db is a primary (i.e. not in recovery) and
// measures how long the check took. The measured latency is handed to
// c.onCheckedNode on a separate goroutine regardless of whether the check
// succeeded, so the caller is never blocked by that callback.
func (c *PGClient) nodeChecker(ctx context.Context, db *sql.DB) (bool, error) {
	begin := time.Now()
	primary, checkErr := checkers.Check(ctx, db, "SELECT NOT pg_is_in_recovery()")
	elapsed := time.Since(begin)

	go c.onCheckedNode(db, elapsed)

	return primary, checkErr
}

// IsLockError reports whether err means that a lock could not be taken.
//
// If the chain contains a *pgconn.PgError, the verdict depends solely on
// whether its code equals UnableToLockCode. Otherwise the error is a lock
// error only when it matches the AdvisoryLockErr sentinel.
func IsLockError(err error) bool {
	var pgErr *pgconn.PgError
	if xerrors.As(err, &pgErr) {
		return pgErr.Code == UnableToLockCode
	}
	return xerrors.Is(err, AdvisoryLockErr)
}

// runHealthCheckLoop starts a background goroutine that runs "SELECT 1"
// through db every five seconds and returns a context derived from ctx that
// is cancelled as soon as that goroutine exits.
//
// The goroutine exits when:
//   - the ping query returns an error,
//   - the ping query returns something other than 1, or
//   - ctx is done.
//
// On exit, for any of the reasons above, the goroutine stops its ticker,
// sets *alreadyCancelled to true and cancels the returned context.
//
// NOTE(review): *alreadyCancelled is written from the background goroutine
// without synchronisation; callers reading it concurrently are subject to a
// data race. Fixing that requires changing the parameter type.
func runHealthCheckLoop(ctx context.Context, db *gorm.DB, logger appLog.Logger, lockID int, alreadyCancelled *bool) context.Context {
	resultCtx, cancel := context.WithCancel(ctx)
	ticker := time.NewTicker(time.Second * 5)
	go func() {
		defer func() {
			// Release the ticker's resources; without Stop it is never
			// cleaned up once this goroutine is gone.
			ticker.Stop()
			*alreadyCancelled = true
			cancel()
		}()
		logger.Debug("Started health check loop", appLog.Int("LockID", lockID))
		for {
			select {
			case <-ticker.C:
				var ret int
				if err := db.Raw("SELECT 1").
					Scan(&ret).Error; err != nil {
					logger.Error("Lock: unable to check DB health", appLog.Int("LockID", lockID), appLog.Error(err))
					return
				}
				if ret != 1 {
					logger.Error("Lock: unexpected result on ping", appLog.Int("LockID", lockID))
					return
				}
				logger.Debug("Lock: tx is alive", appLog.Int("LockID", lockID))
			case <-ctx.Done():
				logger.Debug("Stopping lock ping loop", appLog.Int("LockID", lockID))
				return
			}
		}
	}()
	return resultCtx
}

// WithLock runs toRun while holding the PostgreSQL advisory lock lockID.
//
// It obtains the primary via pg.GetPrimary, and inside db.Connection tries
// to take the lock with pg_try_advisory_lock (non-blocking). If the lock is
// not obtained, AdvisoryLockErr is returned and toRun is not called.
//
// While toRun executes, runHealthCheckLoop pings the database through the
// same tx; the context passed to toRun is cancelled when that loop exits
// (ping failure or ctx cancellation).
//
// After toRun returns, the lock is released with pg_advisory_unlock, retrying
// up to five times with a two-second pause while the function reports false.
// The release is skipped when the health-check loop has already exited.
// Release failures are only logged; the returned error is always the one
// produced by acquiring the lock or by toRun.
//
// NOTE(review): alreadyCancelled is shared with the health-check goroutine
// without synchronisation, so the read in the deferred cleanup can race with
// the goroutine's write.
func WithLock(ctx context.Context, pg *PGClient, logger appLog.Logger, lockID int, toRun func(ctx context.Context, tx *gorm.DB) error) error {
	db, err := pg.GetPrimary()
	if err != nil {
		return xerrors.Errorf("unable to connect to db to take a lock: %w", err)
	}
	return db.Connection(func(tx *gorm.DB) error {
		var acquired bool
		if err := tx.Raw("SELECT pg_try_advisory_lock(?)", lockID).Scan(&acquired).Error; err != nil {
			return xerrors.Errorf("error while acquiring lock: %w", err)
		}
		if !acquired {
			logger.Debug("Lock: unable to acquire", appLog.Int("LockID", lockID))
			return AdvisoryLockErr
		}
		logger.Debug("Lock: acquired", appLog.Int("LockID", lockID))

		pingerContext, pingerCancel := context.WithCancel(ctx)
		var alreadyCancelled bool
		defer func() {
			// Snapshot the flag BEFORE cancelling: cancelling makes the
			// health-check goroutine exit and set the flag itself, which
			// must not be mistaken for an earlier failure.
			loopExited := alreadyCancelled
			// Always cancel so pingerContext is detached from its parent
			// even when the loop has already stopped on its own.
			pingerCancel()
			if loopExited {
				logger.Debug("No need to to release the lock since we are already cancelled")
				return
			}
			var released bool
			for i := 0; i < 5; i++ {
				if err := tx.Raw("SELECT pg_advisory_unlock(?)", lockID).Scan(&released).Error; err != nil {
					logger.Error("Lock: error while trying to release", appLog.Int("LockID", lockID), appLog.Error(err))
					return
				}
				if released {
					logger.Debug("Lock: released", appLog.Int("LockID", lockID), appLog.Int("Attempt", i))
					return
				}
				logger.Debug("Lock: not released", appLog.Int("LockID", lockID), appLog.Int("Attempt", i))
				time.Sleep(2 * time.Second)
			}
			logger.Error("Lock: unable to release", appLog.Int("LockID", lockID))
		}()

		lockContext := runHealthCheckLoop(pingerContext, tx, logger, lockID, &alreadyCancelled)
		return toRun(lockContext, tx)
	})
}
