package dbx

import (
	"context"
	"errors"
	"sync"

	"github.com/jmoiron/sqlx"
)

// MustBegin starts a new transaction on db and returns a child of ctx that carries it.
// Passing the returned context to the other dbx query methods runs them inside the
// transaction. The transaction must later be committed (Commit) or discarded
// (RollbackUnlessComitted). The transaction is opened with sqlx's DB.MustBegin,
// which panics if it cannot be started.
//
// Usage example:
//
//	ctx = dbx.MustBegin(ctx, sqlxdb)
//	defer dbx.RollbackUnlessComitted(ctx, nil)
//	// exec queries in sqlxdb
//	dbx.Commit(ctx)
func MustBegin(ctx context.Context, db *sqlx.DB) context.Context {
	return context.WithValue(ctx, txCtxKey, &Tx{
		dbtx:   db.MustBegin(),
		active: true, // stays true until Commit succeeds or a rollback is attempted
	})
}

// RollbackUnlessComitted rolls back the transaction carried by ctx. It does nothing
// if that transaction was already committed or rolled back.
// If errHandler is non-nil, it is called with the error returned by the rollback,
// or with an error when ctx carries no transaction. errHandler is invoked after the
// transaction's lock has been released, so it may use other dbx functions with ctx
// without deadlocking.
// This function is designed to be used in defer statements.
func RollbackUnlessComitted(ctx context.Context, errHandler func(err error)) {
	tx := getTx(ctx)
	if tx == nil {
		if errHandler != nil {
			errHandler(errors.New("dbx RollbackUnlessComitted: missing transaction in context"))
		}
		return
	}

	// tx.active is written by Commit under tx's mutex, so it is only read and
	// updated here while holding the same lock. An unlocked fast-path check
	// would be a data race.
	err := func() error {
		tx.Lock()
		defer tx.Unlock()

		if !tx.active {
			return nil // already committed or rolled back; nothing to do
		}
		// Flag the transaction inactive before calling Rollback, so that an
		// error or a panic from Rollback cannot leave it looking usable.
		// database/sql treats a Tx as finished once Rollback has been called,
		// so retrying the rollback would only yield sql.ErrTxDone.
		tx.active = false
		return tx.dbtx.Rollback()
	}()

	if err != nil && errHandler != nil {
		errHandler(err)
	}
}

// Commit commits the transaction carried by ctx.
// It returns an error when ctx carries no transaction, when the transaction was
// already committed or rolled back, or when the underlying commit fails. Only a
// successful commit marks the transaction inactive; after a failed commit it is
// still reported as active.
func Commit(ctx context.Context) error {
	t := getTx(ctx)
	if t == nil {
		return errors.New("dbx Commit: missing transaction in context")
	}

	t.Lock()
	defer t.Unlock()

	if !t.active {
		return errors.New("dbx Commit: already commited or rolled back")
	}

	// The transaction remains active exactly when the commit failed.
	err := t.dbtx.Commit()
	t.active = err != nil
	return err
}

var (
	// txCtxKey is the context key under which MustBegin stores the *Tx.
	// A freshly allocated pointer compares equal only to itself, so this key
	// cannot collide with context keys defined elsewhere.
	txCtxKey = new(int)
)

// Tx is the transaction state that MustBegin stores in a context.
// Commit and RollbackUnlessComitted hold the embedded mutex while they read and
// update active.
// NOTE(review): embedding sync.Mutex makes Lock/Unlock part of Tx's exported
// method set.
type Tx struct {
	sync.Mutex
	// dbtx is the underlying sqlx transaction opened by MustBegin.
	dbtx   *sqlx.Tx
	// active is set to true by MustBegin. It becomes false after a successful
	// Commit or once a rollback has been attempted.
	active bool
}

// getTx returns the *Tx that MustBegin stored in ctx, or nil if ctx carries none.
func getTx(ctx context.Context) *Tx {
	// A failed comma-ok assertion yields the zero value, i.e. a nil *Tx.
	found, _ := ctx.Value(txCtxKey).(*Tx)
	return found
}

// GetActiveTx returns the sqlx transaction carried by ctx. It returns nil when ctx
// carries no transaction, or when the transaction is no longer active.
// A transaction becomes inactive after a successful Commit or once
// RollbackUnlessComitted has attempted a rollback.
func GetActiveTx(ctx context.Context) *sqlx.Tx {
	tx := getTx(ctx)
	if tx == nil {
		return nil
	}

	// active is written under tx's mutex by Commit and RollbackUnlessComitted;
	// reading it without the same lock would be a data race.
	tx.Lock()
	defer tx.Unlock()

	if !tx.active {
		return nil
	}
	return tx.dbtx
}
