package gosql

import (
	"context"
	"database/sql"
)

// BeginTxOption customizes how WithTx begins its transaction. It is handed
// pointers to the context and the *sql.TxOptions that WithTx is about to use
// and may overwrite either one.
type BeginTxOption func(ctx *context.Context, opts **sql.TxOptions)

// WithContext returns a BeginTxOption that sets the context used to begin
// the transaction to ctx. The *sql.TxOptions value is left untouched.
func WithContext(ctx context.Context) BeginTxOption {
	setCtx := func(dst *context.Context, _ **sql.TxOptions) {
		*dst = ctx
	}
	return setCtx
}

// WithTxOptions returns a BeginTxOption that sets the *sql.TxOptions passed
// to BeginTx to opts (nil is stored as-is). The context is left untouched.
func WithTxOptions(opts *sql.TxOptions) BeginTxOption {
	setOpts := func(_ *context.Context, dst **sql.TxOptions) {
		*dst = opts
	}
	return setOpts
}

// WithTx runs fn inside a transaction begun on conn.
//
// Options are applied in the order given, so a later option overrides an
// earlier one of the same kind; nil options are skipped. If no option sets a
// context, context.Background() is used; if none sets transaction options,
// nil is passed to BeginTx.
//
// If fn returns nil, the transaction is committed and the result of Commit is
// returned. If fn returns an error, panics, or terminates its goroutine via
// runtime.Goexit (for example t.FailNow in a test), the transaction is rolled
// back. fn's error is returned unchanged, and a panic keeps propagating after
// the rollback.
func WithTx(
	conn TxBeginner,
	fn func(tx *sql.Tx) error,
	options ...BeginTxOption,
) error {
	var ctx context.Context
	var opts *sql.TxOptions
	for _, option := range options {
		if option == nil {
			continue
		}
		option(&ctx, &opts)
	}
	if ctx == nil {
		ctx = context.Background()
	}
	tx, err := conn.BeginTx(ctx, opts)
	if err != nil {
		return err
	}
	// A deferred Rollback covers every exit that does not reach a successful
	// Commit: an error from fn, a panic, and runtime.Goexit (which recover
	// cannot observe). After a successful Commit, Rollback does nothing and
	// returns sql.ErrTxDone, so its error is deliberately ignored.
	defer func() { _ = tx.Rollback() }()
	if err := fn(tx); err != nil {
		// fn's error is what callers act on; a failure of the deferred
		// Rollback is intentionally not reported in its place.
		return err
	}
	return tx.Commit()
}

// WithTxContext runs fn inside a transaction begun on conn with the given
// context and transaction options. It is equivalent to
// WithTx(conn, fn, WithContext(ctx), WithTxOptions(opts)).
//
// Deprecated: use WithTx with the WithContext and WithTxOptions options instead.
func WithTxContext(
	ctx context.Context, conn TxBeginner, opts *sql.TxOptions,
	fn func(tx *sql.Tx) error,
) error {
	return WithTx(conn, fn, WithContext(ctx), WithTxOptions(opts))
}
