package datastore

import (
	"context"
	"database/sql"
	"strconv"
	"strings"
	"time"

	"code.justin.tv/feeds/distconf"
	"code.justin.tv/feeds/errors"
)

// key is an unexported type used for this package's context keys, so that values stored by this package
// cannot collide with context values set by other packages.
type key int

const (
	// dbTxContextKey is the context key under which StartOrJoinTx stores the in-flight *dbTx.
	dbTxContextKey key = iota
)

const (
	// driver names the SQL driver. It is not referenced in this part of the file; presumably it is the
	// driver name handed to sql.Open elsewhere -- TODO confirm.
	driver = "postgres"
)

// DBConfig contains the configuration values for the database.
//
// The fields are nil until Load (hostname, maxOpenConns) and LoadSecrets (username, password) have been
// called.
type DBConfig struct {
	// hostname is read from the "meepo_db.hostname" key by Load.
	hostname     *distconf.Str
	// username is read from the "meepo_db.username" key by LoadSecrets.
	username     *distconf.Str
	// password is read from the "meepo_db.password" key by LoadSecrets.
	password     *distconf.Str
	// maxOpenConns is read from the "meepo_db.max_open_connections" key by Load.
	maxOpenConns *distconf.Int
}

// Load reads the non-secret database settings (hostname and maximum number of open connections) from d.
// It returns an error as soon as one of them resolves to its zero value; in that case the settings after
// the failing one are not looked up.
func (c *DBConfig) Load(d *distconf.Distconf) error {
	if c.hostname = d.Str("meepo_db.hostname", ""); c.hostname.Get() == "" {
		return errors.New("db hostname could not be loaded")
	}

	if c.maxOpenConns = d.Int("meepo_db.max_open_connections", 0); c.maxOpenConns.Get() == 0 {
		return errors.New("db max_open_connections could not be loaded")
	}

	return nil
}

// LoadSecrets reads the database credentials (username and password) from d.
// Both values are looked up first; an error is returned if either of them is empty.
func (c *DBConfig) LoadSecrets(d *distconf.Distconf) error {
	c.username, c.password = d.Str("meepo_db.username", ""), d.Str("meepo_db.password", "")

	if c.username.Get() != "" && c.password.Get() != "" {
		return nil
	}
	return errors.New("db username and password could not be loaded")
}

// queryExecutor is the set of context-aware statement methods needed to run queries. getTxIfJoined returns
// either a *sql.Tx or the datastore itself as a queryExecutor, so callers can issue statements without
// caring whether they are running inside a transaction.
type queryExecutor interface {
	// ExecContext runs a statement that does not return rows.
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	// QueryContext runs a statement that returns rows.
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	// QueryRowContext runs a statement that is expected to return at most one row.
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}

// dbTx is the transaction state that StartOrJoinTx stores in the context under dbTxContextKey.
type dbTx struct {
	// tx is the underlying database transaction.
	tx         *sql.Tx
	// committed is set by CommitTx just before it calls tx.Commit.
	committed  bool
	// rolledBack is set by RollbackTxIfNotCommitted just before it calls tx.Rollback.
	rolledBack bool
}

// StartOrJoinTx makes sure the returned context carries a transaction.
//
// If ctx already holds a *dbTx, ctx is returned unchanged and the bool result is false (the caller joined
// an existing transaction). Otherwise a new transaction is begun with opts, stored in a derived context,
// and the bool result is true (the caller created the transaction). The bool is meant to be passed on to
// CommitTx and RollbackTxIfNotCommitted. When beginning the transaction fails, the returned context is nil.
func (d *datastore) StartOrJoinTx(ctx context.Context, opts *sql.TxOptions) (context.Context, bool, error) {
	if _, joined := ctx.Value(dbTxContextKey).(*dbTx); joined {
		return ctx, false, nil
	}

	tx, err := d.BeginTx(ctx, opts)
	if err != nil {
		return nil, false, err
	}

	txCtx := context.WithValue(ctx, dbTxContextKey, &dbTx{tx: tx})
	return txCtx, true, nil
}

// getTxIfJoined picks the executor that statements should run on: the transaction stored in ctx when there
// is one, otherwise the datastore itself.
func (d *datastore) getTxIfJoined(ctx context.Context) queryExecutor {
	joined, ok := ctx.Value(dbTxContextKey).(*dbTx)
	if !ok {
		return d
	}
	return joined.tx
}

// CommitTx commits the transaction stored in ctx, but only for the caller that created it.
//
// When createdTx is false the call is a no-op and returns nil. Otherwise it returns an error if ctx holds no
// transaction, or if the transaction was already committed or rolled back through these helpers; in all
// other cases it returns the result of the underlying Commit.
func (d *datastore) CommitTx(ctx context.Context, createdTx bool) error {
	if !createdTx {
		return nil
	}

	state, ok := ctx.Value(dbTxContextKey).(*dbTx)
	if !ok {
		return errors.New("no transaction started")
	}

	switch {
	case state.committed:
		return errors.New("transaction already committed")
	case state.rolledBack:
		return errors.New("transaction already rolled back")
	}

	// The flag is raised before committing, so a failed Commit still counts as "committed" for
	// RollbackTxIfNotCommitted.
	state.committed = true
	return state.tx.Commit()
}

// RollbackTxIfNotCommitted rolls back the transaction stored in ctx unless it has already been committed.
//
// When createdTx is false the call is a no-op, because only the caller that created the transaction may end
// it. Nothing is returned: problems (no transaction in ctx, a transaction that was already rolled back, or
// a failing Rollback) are written to the datastore's log instead.
func (d *datastore) RollbackTxIfNotCommitted(ctx context.Context, createdTx bool) {
	if !createdTx {
		return
	}

	val, ok := ctx.Value(dbTxContextKey).(*dbTx)
	if !ok {
		d.log.LogCtx(ctx, "err", errors.New("no transaction started"), "error rolling back transaction")
		return
	}

	if val.committed {
		return
	}

	if val.rolledBack {
		// Stop here: rolling back a second time would only fail and log a second, misleading error.
		d.log.LogCtx(ctx, "err", errors.New("transaction already rolled back"), "error rolling back transaction")
		return
	}

	val.rolledBack = true
	if err := val.tx.Rollback(); err != nil {
		d.log.LogCtx(ctx, "err", errors.Wrap(err, "error rolling back transaction"))
	}
}

// closeRows closes rows and logs any error from doing so. A nil rows is ignored.
func (d *datastore) closeRows(rows *sql.Rows) {
	if rows == nil {
		return
	}
	if closeErr := rows.Close(); closeErr != nil {
		d.log.Log("err", closeErr, "error closing rows")
	}
}

// recordStats emits two stats for a finished operation: a timing under "<operationName>.time" measured
// from startTime, and a counter under "<operationName>.status.<status>".
//
// The status is "success" when succeeded is true. Otherwise it is "ctx_error" if ctx reports an error, and
// "db_error" if it does not.
func (d *datastore) recordStats(ctx context.Context, operationName string, startTime time.Time, succeeded bool) {
	// Latency of the operation.
	elapsed := time.Since(startTime)
	d.stats.TimingDurationC(operationName+".time", elapsed, 1)

	// Separate successes from failures, and context failures from all other failures.
	var status string
	switch {
	case succeeded:
		status = "success"
	case ctx.Err() != nil:
		status = "ctx_error"
	default:
		status = "db_error"
	}

	// Throughput, broken down by outcome.
	d.stats.IncC(operationName+".status."+status, 1, 1)
}

// now returns the datastore clock's current time, passed through ConvertToDBTime.
func (d *datastore) now() time.Time {
	current := d.clock.Now()
	return ConvertToDBTime(current)
}

// ConvertToDBTime converts the given time to a time that can be reliably represented by Postgres.
//
// Two adjustments are made:
//   - the time is moved to UTC, so no timezone information has to survive a round trip through the database;
//   - the time is cut down to microsecond resolution, the finest resolution Postgres date/time types
//     support (Go times carry nanoseconds).
func ConvertToDBTime(t time.Time) time.Time {
	return t.UTC().Truncate(time.Microsecond)
}

// generatePlaceholders returns num comma-separated positional placeholders numbered consecutively from
// start. For example, generatePlaceholders(3, 2) returns "$3,$4".
//
// It returns an empty string when num is zero or negative.
func generatePlaceholders(start, num int) string {
	// A negative count would make the slice allocation below panic; treat it like an empty list instead.
	if num <= 0 {
		return ""
	}

	placeholders := make([]string, num)
	for i := range placeholders {
		placeholders[i] = "$" + strconv.Itoa(start+i)
	}
	return strings.Join(placeholders, ",")
}
