package db

import (
	"database/sql"
	"errors"
	"io"
	"log"
	"math/rand"
	"sync"
	"time"

	"golang.org/x/net/context"

	"code.justin.tv/chat/semaphore"
	"code.justin.tv/common/chitin"
	"code.justin.tv/foundation/xray"
)

// Sentinel errors returned by the db client. Compare against these values
// rather than matching on their messages.
var (
	// ErrConnPoolExhausted is returned when the db client cannot acquire a
	// backend connection slot within the configured connAcquireTimeout.
	ErrConnPoolExhausted = errors.New("db: connection pool exhausted")

	// ErrRequestTimeout is returned when the db client cannot complete the
	// requested operation within the configured requestTimeout.
	ErrRequestTimeout = errors.New("db: request timed out")

	// ErrNoRows is returned when a query yields no rows. It is the very same
	// value as sql.ErrNoRows, so comparing against either works.
	ErrNoRows = sql.ErrNoRows

	// ErrTxDone is returned when a transaction is committed or rolled back
	// after it was already finished ("closed" twice); most commonly this
	// follows a context cancel.
	ErrTxDone = errors.New("db: Transaction has already been committed or rolled back")
)

// noReleaseConnFn is the do-nothing release hook attached to Row values that
// must not release a semaphore slot themselves (either none was acquired, or
// another goroutine is responsible for releasing it).
var noReleaseConnFn = func() {}

// config holds the settings applied by Options. Reload starts every reload
// from a zero config, so an option that is not passed reverts to its zero
// value.
type config struct {
	driverName string

	// Connection target and credentials.
	host                   string
	port                   int
	user, password, dbName string

	// Pool sizing. maxOpenConns sizes the semaphore that bounds concurrent
	// operations; maxQueueSize is that semaphore's second constructor argument
	// (presumably a bound on queued waiters — confirm in the semaphore
	// package). maxIdleConns is applied to the underlying pool.
	maxOpenConns, maxIdleConns, maxQueueSize int

	// connAcquireTimeout bounds the wait for a semaphore slot,
	// requestTimeout bounds each operation, and maxConnAge drives how often
	// the whole pool is recycled.
	connAcquireTimeout, requestTimeout, maxConnAge time.Duration

	disableBinaryParams, sslEnabled bool
}

// DBInfo is a snapshot of the db client's state at some point in time, as
// returned by DB.Info.
type DBInfo struct {
	// OpenConnsCap describes the size of the pool: the capacity of the
	// semaphore that bounds concurrent db operations.
	OpenConnsCap int

	// MaxOpenConns describes the peak number of connection slots in use at
	// some point during the window tracked by the semaphore (dbImpl.Info
	// computes it as OpenConnsCap - MinAvailableConns).
	MaxOpenConns int

	// MinAvailableConns describes the minimum number of free connection slots
	// observed during that same window, as reported by the semaphore's
	// MinAvailable (the window's length is defined by the semaphore package).
	MinAvailableConns int

	// SinceLastConnRecycle describes how much time has elapsed since a
	// connection recycle was last scheduled for the current state.
	SinceLastConnRecycle time.Duration

	// SinceLastConfigReload describes how much time has elapsed since the
	// configuration was last reloaded.
	SinceLastConfigReload time.Duration
}

// sqlDB abstracts a sql.DB, mostly for the purposes of testability.
type sqlDB interface {
	io.Closer
	Query(ctx context.Context, query string, args ...interface{}) (sqlRows, error)
	QueryRow(ctx context.Context, query string, args ...interface{}) sqlRow
	Exec(ctx context.Context, query string, args ...interface{}) (sqlResult, error)
	SetMaxIdleConns(int)
	Begin(ctx context.Context, opts *sql.TxOptions) (sqlTx, error)
}

// sqlDBImpl adapts *xray.DB to the sqlDB interface by narrowing its concrete
// return types to this package's sqlRows, sqlRow and sqlResult interfaces.
type sqlDBImpl struct {
	*xray.DB
}

// Query runs query on the wrapped pool.
func (s *sqlDBImpl) Query(ctx context.Context, query string, args ...interface{}) (sqlRows, error) {
	rows, err := s.DB.Query(ctx, query, args...)
	return rows, err
}

// QueryRow runs a single-row query on the wrapped pool.
func (s *sqlDBImpl) QueryRow(ctx context.Context, query string, args ...interface{}) sqlRow {
	row := s.DB.QueryRow(ctx, query, args...)
	return row
}

// Exec runs a statement that returns no rows on the wrapped pool.
func (s *sqlDBImpl) Exec(ctx context.Context, query string, args ...interface{}) (sqlResult, error) {
	res, err := s.DB.Exec(ctx, query, args...)
	return res, err
}

// sqlRow is the single-row result surface dbImpl needs (mirrors sql.Row).
type sqlRow interface {
	Scan(dest ...interface{}) error
}

// sqlRows is the multi-row cursor surface dbImpl needs (mirrors sql.Rows).
// Its method set is a superset of sqlRow's.
type sqlRows interface {
	sqlRow

	Columns() ([]string, error)
	Next() bool
	Err() error
	Close() error
}

// Begin starts a transaction on the wrapped pool with the given options.
//
// On error it returns a nil sqlTx rather than a wrapper around a nil
// *xray.Tx, so a caller that forgets to check err fails fast instead of
// calling methods on a wrapper with no transaction behind it.
func (s *sqlDBImpl) Begin(ctx context.Context, opts *sql.TxOptions) (sqlTx, error) {
	tx, err := s.DB.Begin(ctx, opts)
	if err != nil {
		return nil, err
	}
	return &sqlTxImpl{tx: tx}, nil
}

// The DB interface specifies how to interact with a db client. dbImpl
// is the canonical implementation, constructed by Open.
type DB interface {
	// Close shuts the client down. In dbImpl it stops the background
	// recycle loop and closes the currently active connection pool.
	io.Closer

	// Reload replaces the database's options with opts in a concurrency-safe
	// way. dbImpl starts each reload from a zero config, so any option that
	// is not passed again reverts to its zero value.
	Reload(opts ...Option) error

	// Recycle forces a recycle of the database's connections in a concurrency-safe way.
	Recycle() error

	// Begin starts a transaction. The name is any client-specified tag used
	// in emitted events. In dbImpl the returned Tx holds a connection slot
	// until Commit or Rollback is called.
	Begin(ctx context.Context, name string) (Tx, error)

	// Query executes a query that returns rows, typically a SELECT. The name is any
	// client-specified tag, and the args are for any placeholder parameters in the query
	// (as in the stdlib package). The caller must Close the returned Rows.
	Query(ctx context.Context, name, query string, args ...interface{}) (Rows, error)

	// QueryRow executes a query that returns one row, typically a SELECT. The name is any
	// client-specified tag, and the args are for any placeholder parameters in the query
	// (as in the stdlib package). QueryRow always returns a non-nil value.
	// Errors are deferred until Row's Scan method is called.
	QueryRow(ctx context.Context, name, query string, args ...interface{}) Row

	// Exec executes a query without returning any rows, typically an UPDATE or INSERT.
	// The name is any client-specified tag, and the args are for any placeholder parameters in the query
	// (as in the stdlib package).
	Exec(ctx context.Context, name, query string, args ...interface{}) (Result, error)

	// Info returns a DBInfo containing the database's current state in a
	// concurrency-safe way.
	Info() DBInfo

	// SetCallbacks sets the database's callbacks in a concurrency-safe way.
	// Either callback may be set to nil to disable it.
	SetCallbacks(dbCb DBCallback, runCb RunCallback)
}

// dbState is one generation of the client: a config, the pool built from it,
// and the recycle/reload timestamps for that generation. Reload and
// onRecycle install a new dbState; the previous one may keep serving a share
// of traffic while it is being retired.
type dbState struct {
	*config
	db sqlDB

	lastConnRecycle, nextConnRecycle, lastConfigReload time.Time
}

// recycleState tracks an in-progress recycle from oldState's pool to the
// current one. It is stored on dbImpl rather than dbState so that it
// persists across state swaps.
type recycleState struct {
	recycleDuration time.Duration
	recycleStart    time.Time
	oldState        *dbState
}

// shouldUseOld reports whether a request should be routed to the old pool.
// The new pool is chosen with probability elapsed/recycleDuration, so traffic
// shifts to it gradually and, once recycleDuration has passed, always goes
// to it.
func (s *recycleState) shouldUseOld() bool {
	progress := float64(time.Since(s.recycleStart)) / float64(s.recycleDuration)
	return rand.Float64() >= progress
}

// cleanup closes the old pool after requestTimeout+connAcquireTimeout. That
// delay covers the slot wait plus the request timeout of operations already
// routed to the old state, to avoid racing them with a closed pool.
// Transactions and still-open Rows are not bounded by that delay.
func (s *recycleState) cleanup() {
	grace := s.oldState.config.requestTimeout + s.oldState.config.connAcquireTimeout
	time.AfterFunc(grace, func() {
		s.oldState.db.Close()
	})
}

// eventManager holds the optional instrumentation callbacks and dispatches
// events to them. rwMu guards the two callback fields.
type eventManager struct {
	dbCb  DBCallback
	runCb RunCallback

	rwMu sync.RWMutex
}

// setCbs replaces both callbacks; either may be nil to disable it.
func (em *eventManager) setCbs(dbCb DBCallback, runCb RunCallback) {
	em.rwMu.Lock()
	defer em.rwMu.Unlock()
	em.dbCb = dbCb
	em.runCb = runCb
}

// emitDBEvent invokes the DB callback, if one is set, with the event name,
// the caller-supplied query name and a duration.
//
// The callback is read under the lock but invoked after releasing it. User
// code therefore never runs while rwMu is held, so a callback that calls
// SetCallbacks cannot self-deadlock and a slow callback cannot stall
// setCbs. A callback read just before a concurrent setCbs may still be
// invoked once after setCbs returns.
func (em *eventManager) emitDBEvent(evName, queryName string, d time.Duration) {
	em.rwMu.RLock()
	hook := em.dbCb
	em.rwMu.RUnlock()

	if hook != nil {
		hook(evName, queryName, d)
	}
}

// emitRunEvent invokes the run callback, if one is set, with the event name.
// Like emitDBEvent, it calls the callback outside the lock.
func (em *eventManager) emitRunEvent(evName string) {
	em.rwMu.RLock()
	hook := em.runCb
	em.rwMu.RUnlock()

	if hook != nil {
		hook(evName)
	}
}

// dbImpl is the canonical DB implementation.
//
// rwMu guards state, reload and recycler. sem bounds concurrent operations
// across both the current and (during a recycle) the retiring pool.
type dbImpl struct {
	state *dbState
	sem   semaphore.Semaphore

	// reload is closed and replaced by Reload so manage() restarts its
	// recycle timer; done is closed by Close to stop manage().
	reload, done chan empty

	rwMu      sync.RWMutex
	evManager *eventManager
	recycler  *recycleState
}

// Open opens a database specified by its options
// (see below for an exhaustive list of options).
//
// It builds the first connection pool via Reload and, on success, starts the
// background goroutine (manage) that periodically recycles connections.
func Open(opts ...Option) (DB, error) {
	client := &dbImpl{
		state:     &dbState{config: &config{}},
		reload:    make(chan empty),
		done:      make(chan empty),
		evManager: &eventManager{},
	}

	if err := client.Reload(opts...); err != nil {
		return nil, err
	}

	go client.manage()
	return client, nil
}

// Reload replaces the client's options with opts in a concurrency-safe way.
//
// Every reload starts from a zero config, so any option not passed reverts to
// its zero value. Reload wakes manage() so its recycle timer restarts,
// creates or resizes the connection semaphore for the new maxOpenConns, and
// recycles onto a pool built from the new config.
//
// If building the new pool fails, the previous state and semaphore size are
// restored so the client keeps serving from its existing pool, and the error
// is returned.
func (db *dbImpl) Reload(opts ...Option) error {
	db.rwMu.Lock()
	defer db.rwMu.Unlock()

	// Closing the current reload channel wakes manage(); it picks up the
	// replacement channel on its next iteration.
	close(db.reload)
	db.reload = make(chan empty)

	prevState := db.state
	db.state = &dbState{
		config:           &config{},
		lastConfigReload: time.Now(),
	}
	for _, opt := range opts {
		opt(db)
	}

	if db.sem == nil {
		db.sem = semaphore.New(db.state.config.maxOpenConns, db.state.config.maxQueueSize)
	} else {
		db.sem.Resize(db.state.config.maxOpenConns)
	}

	if err := db.recycleDBLocked(prevState); err != nil {
		// Without this rollback db.state would keep a nil pool and every
		// later operation would panic on it. On the initial build from Open,
		// prevState has no pool and Open discards the client anyway.
		db.state = prevState
		if prevState.db != nil {
			db.sem.Resize(prevState.config.maxOpenConns)
		}
		return err
	}
	return nil
}

// Recycle forces a recycle of the database's connections onto a fresh pool
// built from the current config, in a concurrency-safe way. If a recycle is
// already in progress it only pushes the next scheduled recycle out.
//
// The new pool is installed on a new dbState. Previously the live state was
// passed to recycleDBLocked as "old". The new pool then overwrote its db
// field, the original pool leaked, and the recycler's cleanup later closed
// the pool that was still serving traffic.
func (db *dbImpl) Recycle() error {
	db.rwMu.Lock()
	defer db.rwMu.Unlock()

	if db.recycler != nil {
		log.Printf("Already recycling db [%s]\n", db.state.config.dbName)
		db.scheduleRecycleLocked()
		return nil
	}

	prevState := db.state
	db.state = &dbState{
		config:           prevState.config,
		lastConfigReload: prevState.lastConfigReload,
	}
	if err := db.recycleDBLocked(prevState); err != nil {
		// Keep serving from the existing pool.
		db.state = prevState
		return err
	}

	// Wake manage() so its timer follows the schedule set on the new state.
	close(db.reload)
	db.reload = make(chan empty)
	return nil
}

// Close stops the background manage() loop and closes the currently active
// connection pool.
//
// Close may be called more than once: done is closed only the first time.
// Closing a closed channel panics, which a second Close used to trigger. The
// check-and-close runs under rwMu so concurrent Close calls cannot both close
// it. Repeated calls return whatever the pool's own Close returns.
func (db *dbImpl) Close() error {
	db.rwMu.Lock()
	defer db.rwMu.Unlock()

	select {
	case <-db.done:
		// Already closed by an earlier Close.
	default:
		close(db.done)
	}

	return db.state.db.Close()
}

// Info returns a snapshot of the pool's semaphore statistics and of how long
// ago the current state was last recycled and reloaded.
func (db *dbImpl) Info() DBInfo {
	capacity := db.sem.Size()
	minFree := db.sem.MinAvailable()

	db.rwMu.RLock()
	defer db.rwMu.RUnlock()

	cur := db.state
	return DBInfo{
		OpenConnsCap:          capacity,
		MaxOpenConns:          capacity - minFree,
		MinAvailableConns:     minFree,
		SinceLastConnRecycle:  time.Since(cur.lastConnRecycle),
		SinceLastConfigReload: time.Since(cur.lastConfigReload),
	}
}

// Begin starts a transaction on the backend chosen by getBackend. The name is
// a client-specified tag used in emitted events.
//
// It waits up to connAcquireTimeout for a semaphore slot. If the wait fails
// it returns ctx.Err() when ctx has ended, otherwise ErrConnPoolExhausted.
// The slot is held until the returned Tx is committed or rolled back.
func (db *dbImpl) Begin(ctx context.Context, name string) (Tx, error) {
	state := db.getBackend()

	// The derived context only bounds the slot wait. Cancel it as soon as the
	// wait resolves so its timer is released promptly; discarding the
	// CancelFunc leaks it until the timeout fires (go vet: lostcancel).
	semCtx, cancel := context.WithTimeout(ctx, state.config.connAcquireTimeout)
	acquired, duration := db.sem.AcquireContext(semCtx)
	cancel()
	if !acquired {
		db.evManager.emitDBEvent("acquire.fail", name, duration)

		err := ctx.Err()
		if err == nil {
			err = ErrConnPoolExhausted
		}
		return nil, err
	}
	db.evManager.emitDBEvent("acquire.success", name, duration)

	sqlTx, err := state.db.Begin(ctx, nil)
	if err != nil {
		db.sem.Release()
		return nil, err
	}

	tx := &txImpl{
		db:    db,
		done:  make(chan empty),
		state: state,
		tx:    sqlTx,
	}

	// We do not handle ctx.Done() here because it leads to committing on
	// timeout regardless of whether the application transaction is complete.
	// During a multistatement transaction, if the context times out after
	// the first statement but before the last, the second statement is then
	// considered not a part of the transaction and will commit the second
	// half of the statement. When a db has the autocommit setting enabled,
	// this will happen.
	go func() {
		<-tx.done // closed only by Commit() or Rollback()
		db.sem.Release()
	}()

	return tx, nil
}

// Query executes a row-returning query on the backend chosen by getBackend.
//
// It waits up to connAcquireTimeout for a semaphore slot. If the wait fails
// it returns ctx.Err() when ctx has ended, otherwise ErrConnPoolExhausted.
// It then waits up to requestTimeout for the query. On success the returned
// Rows owns the slot and the caller must Close it.
func (db *dbImpl) Query(ctx context.Context, name, query string, args ...interface{}) (Rows, error) {
	state := db.getBackend()

	// Bound only the slot wait, and cancel the derived context right after
	// so its timer is not held for the full timeout (go vet: lostcancel).
	semCtx, cancel := context.WithTimeout(ctx, state.config.connAcquireTimeout)
	acquired, duration := db.sem.AcquireContext(semCtx)
	cancel()
	if !acquired {
		db.evManager.emitDBEvent("acquire.fail", name, duration)

		err := ctx.Err()
		if err == nil {
			err = ErrConnPoolExhausted
		}
		return nil, err
	}
	db.evManager.emitDBEvent("acquire.success", name, duration)

	ctx = chitin.ExperimentalNewSpan(ctx)
	queryFn := func() (sqlRows, error) {
		traceStart(ctx, state.config, query)
		defer traceEnd(ctx)
		return state.db.Query(ctx, query, args...)
	}
	releaseConnFn := func() {
		db.sem.Release()
	}
	return db.query(ctx, state, name, queryFn, releaseConnFn)
}

// query runs queryFn on a worker goroutine and waits for whichever happens
// first: its result, state.config.requestTimeout, or ctx ending. Each outcome
// emits a "query.*" event.
//
// Slot ownership: if queryFn fails, the worker calls releaseConnFn at once.
// On success the returned Rows owns releaseConnFn and the caller must Close
// it. If this function stops waiting (timeout or cancel), the worker closes
// the rows itself once the query returns, so no connection is leaked.
func (db *dbImpl) query(ctx context.Context, state *dbState, name string, queryFn func() (sqlRows, error), releaseConnFn func()) (Rows, error) {
	began := time.Now()
	results := make(chan *asyncRows)
	abandoned := make(chan empty)
	defer close(abandoned)

	go func() {
		dbRows, err := queryFn()

		var rows *rowsImpl
		if err == nil {
			rows = &rowsImpl{
				releaseFunc: releaseConnFn,
				rows:        dbRows,
				name:        name,
				evManager:   db.evManager,
			}
		} else {
			releaseConnFn()
		}

		select {
		case results <- &asyncRows{rows, err}:
			// Handed off: the caller is now responsible for rows.Close().
		case <-abandoned:
			// Nobody will receive these rows; close them so the
			// connection is not leaked.
			if rows != nil {
				rows.Close()
			}
		}
	}()

	deadline := time.NewTimer(state.config.requestTimeout)
	defer deadline.Stop()

	select {
	case res := <-results:
		if res.err != nil {
			db.evManager.emitDBEvent("query.error", name, time.Since(began))
			return nil, res.err
		}
		db.evManager.emitDBEvent("query.success", name, time.Since(began))
		return res.rows, nil
	case <-deadline.C:
		db.evManager.emitDBEvent("query.timeout", name, time.Since(began))
		return nil, ErrRequestTimeout
	case <-ctx.Done():
		// Reported as a cancellation, not a failure: it says nothing about
		// the health of the query itself.
		db.evManager.emitDBEvent("query.cancelled", name, time.Since(began))
		return nil, ctx.Err()
	}
}

// QueryRow executes a single-row query on the backend chosen by getBackend.
// It always returns a non-nil Row; errors are deferred to its Scan.
//
// It waits up to connAcquireTimeout for a semaphore slot. If the wait fails
// the Row carries ctx.Err() when ctx has ended, otherwise
// ErrConnPoolExhausted.
func (db *dbImpl) QueryRow(ctx context.Context, name, query string, args ...interface{}) Row {
	state := db.getBackend()

	// Bound only the slot wait, and cancel the derived context right after
	// so its timer is not held for the full timeout (go vet: lostcancel).
	semCtx, cancel := context.WithTimeout(ctx, state.config.connAcquireTimeout)
	acquired, duration := db.sem.AcquireContext(semCtx)
	cancel()
	if !acquired {
		db.evManager.emitDBEvent("acquire.fail", name, duration)

		err := ctx.Err()
		if err == nil {
			err = ErrConnPoolExhausted
		}
		return &rowImpl{err: err, releaseFunc: noReleaseConnFn}
	}
	db.evManager.emitDBEvent("acquire.success", name, duration)

	ctx = chitin.ExperimentalNewSpan(ctx)
	queryFn := func() sqlRow {
		traceStart(ctx, state.config, query)
		defer traceEnd(ctx)
		return state.db.QueryRow(ctx, query, args...)
	}
	releaseConnFn := func() {
		db.sem.Release()
	}
	return db.queryRow(ctx, state, name, queryFn, releaseConnFn)
}

// queryRow runs queryFn on a worker goroutine and waits up to
// state.config.requestTimeout (or until ctx ends) for its row. It always
// returns a non-nil Row; errors surface from its Scan.
//
// On success the returned Row takes ownership of releaseConnFn, and calling
// its Scan releases the slot. On timeout or cancellation the returned Row
// carries the error and a no-op release. The worker goroutine releases the
// slot itself once the query actually finishes, so the resource stays held
// for as long as the operation really runs.
func (db *dbImpl) queryRow(ctx context.Context, state *dbState, name string, queryFn func() sqlRow, releaseConnFn func()) Row {
	startTime := time.Now()
	response := make(chan sqlRow)
	done := make(chan empty)
	defer close(done)
	go func() {
		dbRow := queryFn()

		select {
		case response <- dbRow:
			// Ownership of the slot passes to the caller's rowImpl.
		case <-done:
			// Nobody will Scan this row. Scan and close it so the connection
			// is not leaked. The slot was acquired before queryFn ran, so
			// release it unconditionally; previously a nil row leaked the
			// slot forever.
			if dbRow != nil {
				var discard interface{}
				dbRow.Scan(&discard)
			}
			releaseConnFn()
		}
	}()

	timeout := time.NewTimer(state.config.requestTimeout)
	defer timeout.Stop()

	// If we time out here, we don't want the client to release the semaphore
	// by calling Scan() (which triggers a release). Instead we want to hold
	// the resource until it actually finishes executing the operation.
	select {
	case async := <-response:
		db.evManager.emitDBEvent("queryrow.success", name, time.Since(startTime))
		return &rowImpl{
			row:         async,
			name:        name,
			evManager:   db.evManager,
			releaseFunc: releaseConnFn,
		}
	case <-timeout.C:
		db.evManager.emitDBEvent("queryrow.timeout", name, time.Since(startTime))
		return &rowImpl{err: ErrRequestTimeout, releaseFunc: noReleaseConnFn}
	case <-ctx.Done():
		db.evManager.emitDBEvent("queryrow.cancelled", name, time.Since(startTime))
		// We don't track failures here because this does not indicate an error with this query
		return &rowImpl{err: ctx.Err(), releaseFunc: noReleaseConnFn}
	}
}

// Exec executes a statement that returns no rows on the backend chosen by
// getBackend.
//
// It waits up to connAcquireTimeout for a semaphore slot. If the wait fails
// it returns ctx.Err() when ctx has ended, otherwise ErrConnPoolExhausted.
// It then waits up to requestTimeout for the statement.
func (db *dbImpl) Exec(ctx context.Context, name, query string, args ...interface{}) (Result, error) {
	state := db.getBackend()

	// Bound only the slot wait, and cancel the derived context right after
	// so its timer is not held for the full timeout (go vet: lostcancel).
	semCtx, cancel := context.WithTimeout(ctx, state.config.connAcquireTimeout)
	acquired, duration := db.sem.AcquireContext(semCtx)
	cancel()
	if !acquired {
		db.evManager.emitDBEvent("acquire.fail", name, duration)

		err := ctx.Err()
		if err == nil {
			err = ErrConnPoolExhausted
		}
		return nil, err
	}
	db.evManager.emitDBEvent("acquire.success", name, duration)

	ctx = chitin.ExperimentalNewSpan(ctx)
	execFn := func() (sqlResult, error) {
		traceStart(ctx, state.config, query)
		defer traceEnd(ctx)
		return state.db.Exec(ctx, query, args...)
	}
	releaseConnFn := func() {
		db.sem.Release()
	}
	return db.exec(ctx, state, name, execFn, releaseConnFn)
}

// exec runs execFn on a worker goroutine and waits for whichever happens
// first: its result, state.config.requestTimeout, or ctx ending. Each outcome
// emits an "exec.*" event.
//
// An Exec leaves nothing open, so the worker releases the slot as soon as
// execFn returns, whether or not this function is still waiting.
func (db *dbImpl) exec(ctx context.Context, state *dbState, name string, execFn func() (sqlResult, error), releaseConnFn func()) (Result, error) {
	began := time.Now()
	results := make(chan *asyncResult)
	abandoned := make(chan empty)
	defer close(abandoned)

	go func() {
		dbRes, err := execFn()
		releaseConnFn()

		out := &asyncResult{result: dbRes, err: err}
		select {
		case results <- out:
		case <-abandoned:
		}
	}()

	deadline := time.NewTimer(state.config.requestTimeout)
	defer deadline.Stop()

	select {
	case r := <-results:
		if r.err == nil {
			db.evManager.emitDBEvent("exec.success", name, time.Since(began))
			return r.result, nil
		}
		db.evManager.emitDBEvent("exec.error", name, time.Since(began))
		return nil, r.err
	case <-deadline.C:
		db.evManager.emitDBEvent("exec.timeout", name, time.Since(began))
		return nil, ErrRequestTimeout
	case <-ctx.Done():
		// Reported as a cancellation, not a failure: it says nothing about
		// the health of the statement itself.
		db.evManager.emitDBEvent("exec.cancelled", name, time.Since(began))
		return nil, ctx.Err()
	}
}

// SetCallbacks installs the DB and run callbacks on the client's event
// manager; either may be nil to disable it.
func (db *dbImpl) SetCallbacks(dbCb DBCallback, runCb RunCallback) {
	em := db.evManager
	em.setCbs(dbCb, runCb)
}

// manage is the background loop started by Open. It sleeps until the current
// state's nextConnRecycle and then calls onRecycle. It restarts its wait
// whenever Reload closes the reload channel and exits when Close closes done.
func (db *dbImpl) manage() {
	for {
		// Compute the wait while holding the lock: other goroutines (e.g.
		// Recycle) may update db.state and its schedule under rwMu, so
		// reading nextConnRecycle after unlocking was a data race.
		db.rwMu.RLock()
		wait := time.Until(db.state.nextConnRecycle)
		reload := db.reload
		db.rwMu.RUnlock()

		// An explicit timer (rather than time.After) lets the loop release
		// it when woken by a reload or shutdown instead of by expiry.
		timer := time.NewTimer(wait)
		select {
		case <-timer.C:
			db.onRecycle()
		case <-reload:
			// After a reload we want to restart the recycle timer.
			timer.Stop()
		case <-db.done:
			timer.Stop()
			return
		}
	}
}

// onRecycle is called by manage() when the current state's recycle time
// arrives.
//
// If a recycle is already in progress it only pushes the schedule out.
// Otherwise it installs a new dbState that shares the current config and
// recycles onto a fresh pool. If building that pool fails, it logs, emits a
// "recycle.error" run event, restores the previous state and reschedules the
// next attempt.
func (db *dbImpl) onRecycle() {
	db.rwMu.Lock()
	defer db.rwMu.Unlock()

	if db.recycler != nil {
		log.Printf("Already recycling db [%s]\n", db.state.config.dbName)
		db.scheduleRecycleLocked()
		return
	}
	prevState := db.state
	db.state = &dbState{
		config:           prevState.config,
		lastConfigReload: prevState.lastConfigReload,
	}
	if err := db.recycleDBLocked(prevState); err != nil {
		log.Printf("Error managing db [%s]: %v\n", db.state.config.dbName, err)
		db.evManager.emitRunEvent("recycle.error")
		db.state = prevState
		// prevState.nextConnRecycle is the deadline that just fired. Left
		// alone, manage() would get a non-positive wait and call onRecycle
		// again immediately, spinning for as long as building the pool
		// keeps failing. Push the next attempt out on the normal schedule.
		// lastConnRecycle stays as is because no recycle happened.
		prevState.nextConnRecycle = time.Now().Add(jitter(prevState.maxConnAge, 0.7, 0.8))
	}
}

// scheduleRecycleLocked stamps the current state's lastConnRecycle with now
// and sets its nextConnRecycle to a random point between 70% and 80% of
// maxConnAge from now. The 80% ceiling leaves room for the recycle itself.
// The caller must hold db.rwMu for writing.
func (db *dbImpl) scheduleRecycleLocked() {
	now := time.Now()
	cur := db.state
	cur.lastConnRecycle = now
	cur.nextConnRecycle = now.Add(jitter(cur.maxConnAge, 0.7, 0.8))
}

// recycleDBLocked builds a pool for db.state and, when there is a previous
// pool to retire, starts a gradual recycle from old to new.
//
// Recycles happen over a period of time. During this period, the number
// of idle connections between the new and old pool are rebalanced, with
// the sum of idle connections being invariant according to the db config.
//
// Without this logic, we would either double our actual number of idle connections
// over the maximum or disable idle connections for the old pool, which would mean
// we'd be creating a new connection for each request to the old pool.
//
// All db operations go to either the new or old pool, as determined by
// a probability distribution which increasingly favors the new pool as
// more time elapses. This ensures a gradual change of db query distribution
// in case the application wants to change db backends, ensuring the new backend
// doesn't suddenly get hammered all at once.
//
// Only one recycle runs at a time. If one is already in progress, no new
// pool is built; if db.state has no pool yet, it keeps using old's pool.
//
// The caller must hold db.rwMu for writing. It is expected to have installed
// a fresh dbState in db.state and to pass the one it replaced as old.
func (db *dbImpl) recycleDBLocked(old *dbState) error {
	curTime := time.Now()
	db.scheduleRecycleLocked()
	if db.recycler != nil {
		// TODO: what if we try to Reload() during a recycle? Then we would
		// delay "finishing" the reload up to maxConnAge time. Maybe have
		// reload/recycle use a channel of Configs to process in serial?
		log.Printf("Already recycling db [%s]\n", old.config.dbName)
		// A caller such as Reload may have installed a brand-new dbState with
		// no pool. Keep serving from the previous pool until a later recycle
		// builds one from the new config. Otherwise every subsequent
		// operation (and finishRecycle) would dereference a nil sqlDB.
		if db.state.db == nil {
			db.state.db = old.db
		}
		return nil
	}
	newDB, err := buildDB(db.state.config)
	if err != nil {
		return err
	}
	// random duration between 5% and 10% of maxConnAge
	recycleDuration := jitter(db.state.config.maxConnAge, 0.05, 0.1)
	db.state.db = newDB

	// No old pool to recycle if this is during DB object creation.
	if old.db == nil {
		newDB.SetMaxIdleConns(db.state.config.maxIdleConns)
		return nil
	}

	// The recycler is cleaned up at the end of the recycle period, which
	// closes the old db pool. It belongs to db instead of db.state so it
	// persists across many reload/recycle events.
	db.recycler = &recycleState{
		oldState:        old,
		recycleDuration: recycleDuration,
		recycleStart:    curTime,
	}
	// The new pool starts off with no idle connections and gains them over
	// the recycle period (see dbImpl.recycleBackendLocked).
	newDB.SetMaxIdleConns(0)
	db.recycleBackendLocked()
	return nil
}

// recycleBackendLocked starts the goroutine that shifts idle-connection
// capacity from the retiring pool to the current one.
//
// Over recycleDuration it moves one idle slot per tick. There are
// maxIdleConns+1 ticks in total, and after the last one finishRecycle is
// called. The caller must hold db.rwMu for writing. The goroutine may use
// the captured state and recycler without the lock: they are only written
// during a recycle, and this is the only running recycle. SetMaxIdleConns is
// safe for concurrent use.
func (db *dbImpl) recycleBackendLocked() {
	recycler, state := db.recycler, db.state
	maxIdle := state.config.maxIdleConns
	interval := time.Duration(float64(recycler.recycleDuration) / float64(maxIdle+1))
	ticker := time.NewTicker(interval)

	go func() {
		defer ticker.Stop()

		newIdle, oldIdle := 0, maxIdle
		for range ticker.C {
			if oldIdle == 0 {
				db.finishRecycle()
				return
			}
			newIdle++
			oldIdle--
			state.db.SetMaxIdleConns(newIdle)
			recycler.oldState.db.SetMaxIdleConns(oldIdle)
		}
	}()
}

// finishRecycle ends the in-progress recycle. It restores the current pool's
// full idle-connection limit, schedules the old pool to be closed and clears
// db.recycler so a new recycle can start.
func (db *dbImpl) finishRecycle() {
	db.rwMu.Lock()
	defer db.rwMu.Unlock()

	cur := db.state
	cur.db.SetMaxIdleConns(cur.config.maxIdleConns)

	retiring := db.recycler
	retiring.cleanup()
	db.recycler = nil
}

// getBackend picks the state that should serve the next operation. Outside a
// recycle that is always the current state. During a recycle the old state
// is chosen with a probability that decays over the recycle period (see
// recycleState.shouldUseOld).
func (db *dbImpl) getBackend() *dbState {
	db.rwMu.RLock()
	defer db.rwMu.RUnlock()

	if r := db.recycler; r != nil && r.shouldUseOld() {
		return r.oldState
	}
	return db.state
}

type (
	// asyncRows carries a Query outcome from query()'s worker goroutine back
	// to the waiting caller.
	asyncRows struct {
		rows Rows
		err  error
	}

	// asyncResult carries an Exec outcome from exec()'s worker goroutine back
	// to the waiting caller.
	asyncResult struct {
		result Result
		err    error
	}

	// empty is the zero-size element type of signal-only channels.
	empty struct{}
)

// sqlTx is the subset of a transaction that txImpl depends on. It exists
// mostly so tests can substitute a fake for *xray.Tx.
type sqlTx interface {
	Query(ctx context.Context, query string, args ...interface{}) (sqlRows, error)
	QueryRow(ctx context.Context, query string, args ...interface{}) sqlRow
	Exec(ctx context.Context, query string, args ...interface{}) (sqlResult, error)

	Commit() error
	Rollback() error
}

// sqlTxImpl adapts *xray.Tx to sqlTx.
type sqlTxImpl struct {
	tx *xray.Tx
}

// Commit commits the wrapped transaction.
func (s *sqlTxImpl) Commit() error { return s.tx.Commit() }

// Rollback aborts the wrapped transaction.
func (s *sqlTxImpl) Rollback() error { return s.tx.Rollback() }

// Query runs a row-returning query inside the wrapped transaction.
func (s *sqlTxImpl) Query(ctx context.Context, query string, args ...interface{}) (sqlRows, error) {
	rows, err := s.tx.Query(ctx, query, args...)
	return rows, err
}

// Exec runs a statement that returns no rows inside the wrapped transaction.
func (s *sqlTxImpl) Exec(ctx context.Context, query string, args ...interface{}) (sqlResult, error) {
	res, err := s.tx.Exec(ctx, query, args...)
	return res, err
}

// QueryRow runs a single-row query inside the wrapped transaction.
func (s *sqlTxImpl) QueryRow(ctx context.Context, query string, args ...interface{}) sqlRow {
	row := s.tx.QueryRow(ctx, query, args...)
	return row
}

// Tx is an in-progress database transaction, as returned by DB.Begin. Call
// exactly one of Commit or Rollback to finish it; in the dbImpl
// implementation any later Commit or Rollback returns ErrTxDone, and the
// transaction's connection slot is held until it is finished.
type Tx interface {
	// Commit commits the transaction.
	Commit() error

	// Query executes a query that returns rows, typically a SELECT, inside the
	// transaction. The name is any client-specified tag, and the args are for any
	// placeholder parameters in the query (as in the stdlib package).
	Query(ctx context.Context, name, query string, args ...interface{}) (Rows, error)

	// Exec executes a query without returning any rows, typically an UPDATE or INSERT,
	// inside the transaction. The name is any client-specified tag, and the args are
	// for any placeholder parameters in the query (as in the stdlib package).
	Exec(ctx context.Context, name, query string, args ...interface{}) (Result, error)

	// QueryRow executes a query that returns one row, typically a SELECT, inside the
	// transaction. The name is any client-specified tag, and the args are for any
	// placeholder parameters in the query (as in the stdlib package). QueryRow always
	// returns a non-nil value. Errors are deferred until Row's Scan method is called.
	QueryRow(ctx context.Context, name string, query string, args ...interface{}) Row

	// Rollback aborts the transaction.
	Rollback() error
}

// txImpl is the Tx returned by dbImpl.Begin. It pins the backend state the
// transaction was started on. Its done channel is closed exactly once, when
// the transaction is committed or rolled back, which lets Begin's watcher
// goroutine release the semaphore slot.
type txImpl struct {
	db    *dbImpl
	done  chan empty
	state *dbState
	tx    sqlTx

	// isDoneMu guards isDone, which records that Commit or Rollback has run.
	isDoneMu sync.Mutex
	isDone   bool
}

// Commit commits the transaction. It returns ErrTxDone if the transaction was
// already committed or rolled back.
func (tx *txImpl) Commit() error {
	return tx.finishOnce(func() error { return tx.tx.Commit() })
}

// Rollback aborts the transaction. It returns ErrTxDone if the transaction was
// already committed or rolled back.
func (tx *txImpl) Rollback() error {
	return tx.finishOnce(func() error { return tx.tx.Rollback() })
}

// finishOnce runs end (a Commit or Rollback on the wrapped sqlTx) only on the
// first call and returns ErrTxDone on every later call. done is closed after
// end returns, while isDoneMu is still held.
func (tx *txImpl) finishOnce(end func() error) error {
	tx.isDoneMu.Lock()
	defer tx.isDoneMu.Unlock()

	if tx.isDone {
		return ErrTxDone
	}
	tx.isDone = true
	defer close(tx.done)
	return end()
}

// Query runs a row-returning query inside the transaction. The transaction
// already holds a connection slot, so no extra slot is released when the
// rows are closed; requestTimeout still bounds the wait.
func (tx *txImpl) Query(ctx context.Context, name, query string, args ...interface{}) (Rows, error) {
	spanCtx := chitin.ExperimentalNewSpan(ctx)
	run := func() (sqlRows, error) {
		traceStart(spanCtx, tx.state.config, query)
		defer traceEnd(spanCtx)
		return tx.tx.Query(spanCtx, query, args...)
	}
	return tx.db.query(spanCtx, tx.state, name, run, noReleaseConnFn)
}

// Exec runs a statement that returns no rows inside the transaction. The
// transaction already holds a connection slot, so nothing extra is released;
// requestTimeout still bounds the wait.
func (tx *txImpl) Exec(ctx context.Context, name, query string, args ...interface{}) (Result, error) {
	spanCtx := chitin.ExperimentalNewSpan(ctx)
	run := func() (sqlResult, error) {
		traceStart(spanCtx, tx.state.config, query)
		defer traceEnd(spanCtx)
		return tx.tx.Exec(spanCtx, query, args...)
	}
	return tx.db.exec(spanCtx, tx.state, name, run, noReleaseConnFn)
}

// QueryRow runs a single-row query inside the transaction. It always returns
// a non-nil Row, and its Scan releases nothing extra because the transaction
// owns the connection slot.
func (tx *txImpl) QueryRow(ctx context.Context, name string, query string, args ...interface{}) Row {
	spanCtx := chitin.ExperimentalNewSpan(ctx)
	run := func() sqlRow {
		traceStart(spanCtx, tx.state.config, query)
		defer traceEnd(spanCtx)
		return tx.tx.QueryRow(spanCtx, query, args...)
	}
	return tx.db.queryRow(spanCtx, tx.state, name, run, noReleaseConnFn)
}
