package db

import (
	"context"
	"database/sql"
	"fmt"
	"net/url"
	"strconv"
	"strings"
	"time"

	"code.justin.tv/feeds/distconf"
	"code.justin.tv/feeds/errors"
	"code.justin.tv/feeds/log"

	// Register pq as the database driver.
	_ "github.com/lib/pq"
)

const (
	// driver is the database/sql driver name passed to sql.Open; it is registered by the blank
	// github.com/lib/pq import.
	driver        = "postgres"
	// inClauseLimit is the maximum number of values membershipTest will place in a single IN (...) clause.
	inClauseLimit = 1000
)

// key is an unexported type for this package's context keys, so they cannot collide with context keys
// defined by other packages.
type key int

const (
	// dbTxContextKey is the context key under which StartOrJoinTx stores the active *dbTx.
	dbTxContextKey key = iota
)

// Config holds the distconf-backed settings NewClient uses to open the Postgres connection pool. Load fills in
// hostname and maxOpenConns; LoadSecrets fills in the credentials. Each field is nil until its loader has run.
type Config struct {
	// hostname is read from "oracle_db.hostname" and placed after the '@' of the postgres:// connection URL.
	hostname     *distconf.Str
	// username is read from "oracle_db.username" by LoadSecrets.
	username     *distconf.Str
	// password is read from "oracle_db.password" by LoadSecrets.
	password     *distconf.Str
	// maxOpenConns is read from "oracle_db.max_open_connections" and passed to sql.DB.SetMaxOpenConns.
	maxOpenConns *distconf.Int
}

// Load binds the non-secret database settings from distconf: the host portion of the connection URL
// ("oracle_db.hostname", default "") and the connection-pool cap ("oracle_db.max_open_connections",
// default 0, which database/sql treats as "no limit"). It never fails and always returns nil.
func (c *Config) Load(d *distconf.Distconf) error {
	c.hostname, c.maxOpenConns = d.Str("oracle_db.hostname", ""), d.Int("oracle_db.max_open_connections", 0)
	return nil
}

// LoadSecrets binds the database credentials from distconf ("oracle_db.username" and "oracle_db.password",
// both defaulting to ""). It never fails and always returns nil.
func (c *Config) LoadSecrets(d *distconf.Distconf) error {
	c.username, c.password = d.Str("oracle_db.username", ""), d.Str("oracle_db.password", "")
	return nil
}

// Impl is the Postgres-backed implementation of DB. It embeds *sql.DB, so the pool's methods are available
// directly on Impl, and it reports non-fatal problems (closing rows, rolling back) through Log.
type Impl struct {
	*sql.DB

	Log log.Logger
}

// Compile-time check that *Impl satisfies the DB interface.
var _ DB = &Impl{}

// queryExecutor is the set of statement-running methods that both a transaction (*sql.Tx) and the pool
// (*Impl, via its embedded *sql.DB) provide. getTxIfJoined returns one or the other, so the same query
// code can run inside a joined transaction or directly against the pool.
type queryExecutor interface {
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}

// NewClient opens a Postgres connection pool described by config and wraps it in an Impl that logs through
// logger. Both Load and LoadSecrets must have been called on config first, since every field is dereferenced.
//
// sql.Open only validates its arguments; connections are established lazily, so bad credentials or an
// unreachable host surface on first use rather than here.
func NewClient(config *Config, logger log.Logger) (*Impl, error) {
	hostname := config.hostname.Get()
	logger.Log("hostname", hostname, "Opening connection to Postgres")

	// Percent-encode the credentials so reserved characters in them (e.g. '@', ':', '/', '#', '%') cannot
	// corrupt the URL that the driver parses. The hostname is appended verbatim, as before, because it may
	// carry more than a bare host (presumably host[:port][/dbname][?params] — confirm against config values).
	credentials := url.UserPassword(config.username.Get(), config.password.Get())
	connectionString := "postgres://" + credentials.String() + "@" + hostname

	db, err := sql.Open(driver, connectionString)
	if err != nil {
		return nil, err
	}

	// A value <= 0 means no limit on open connections (database/sql semantics).
	db.SetMaxOpenConns(int(config.maxOpenConns.Get()))

	return &Impl{
		DB:  db,
		Log: logger,
	}, nil
}

// dbTx is the transaction state that StartOrJoinTx stores in a context. committed and rolledBack record
// which finalizing call (CommitTx / RollbackTxIfNotCommitted) has already been issued, so repeated
// finalization can be detected. The flags are plain bools with no locking.
type dbTx struct {
	tx         *sql.Tx
	committed  bool
	rolledBack bool
}

// StartOrJoinTx ensures the returned context carries a transaction.
//
// If ctx already holds one (stored by an earlier StartOrJoinTx), that transaction is joined: ctx is returned
// as-is, the bool result is false, and opts is ignored. Otherwise a new transaction is begun with opts and
// stored in a derived context, and the bool result is true. Pass that bool to CommitTx and
// RollbackTxIfNotCommitted; both are no-ops when it is false, so only the creator finalizes the transaction.
//
// On error the original ctx is returned (never nil). A caller that reassigns ctx before checking err is
// therefore not left holding a nil context.
func (db *Impl) StartOrJoinTx(ctx context.Context, opts *sql.TxOptions) (context.Context, bool, error) {
	if _, joined := ctx.Value(dbTxContextKey).(*dbTx); joined {
		return ctx, false, nil
	}

	tx, err := db.BeginTx(ctx, opts)
	if err != nil {
		return ctx, false, err
	}
	return context.WithValue(ctx, dbTxContextKey, &dbTx{tx: tx}), true, nil
}

// getTxIfJoined picks where a statement should run. It returns the *sql.Tx stored in ctx by StartOrJoinTx
// when there is one, and otherwise db itself, so the statement goes straight to the connection pool.
func (db *Impl) getTxIfJoined(ctx context.Context) queryExecutor {
	joined, ok := ctx.Value(dbTxContextKey).(*dbTx)
	if !ok {
		return db
	}
	return joined.tx
}

// CommitTx commits the transaction stored in ctx, provided this caller created it (createdTx is the bool
// returned by StartOrJoinTx). When createdTx is false it does nothing and returns nil, leaving the commit
// to the creator.
//
// It returns an error without touching the driver if ctx holds no transaction, or if the transaction was
// already committed or rolled back. The transaction is marked committed before Commit is called, so a
// failed Commit is not retried or rolled back again by RollbackTxIfNotCommitted.
func (db *Impl) CommitTx(ctx context.Context, createdTx bool) error {
	if !createdTx {
		return nil
	}

	current, ok := ctx.Value(dbTxContextKey).(*dbTx)
	switch {
	case !ok:
		return errors.New("no transaction started")
	case current.committed:
		return errors.New("transaction already committed")
	case current.rolledBack:
		return errors.New("transaction already rolled back")
	}

	current.committed = true
	return current.tx.Commit()
}

// RollbackTxIfNotCommitted rolls back the transaction stored in ctx unless it has already been committed.
// createdTx is the bool returned by StartOrJoinTx. When it is false the caller only joined an existing
// transaction, so nothing is done and finalization is left to the creator.
//
// Nothing is returned. The following problems are logged and the function returns immediately:
//   - ctx carries no transaction;
//   - the transaction was already rolled back (no second Rollback is issued);
//   - the driver's Rollback fails.
func (db *Impl) RollbackTxIfNotCommitted(ctx context.Context, createdTx bool) {
	if !createdTx {
		return
	}

	current, ok := ctx.Value(dbTxContextKey).(*dbTx)
	if !ok {
		db.Log.Log("err", errors.New("no transaction started"), "error rolling back transaction")
		return
	}
	if current.committed {
		// Normal path after a successful CommitTx: nothing to undo.
		return
	}
	if current.rolledBack {
		db.Log.Log("err", errors.New("transaction already rolled back"), "error rolling back transaction")
		return
	}

	current.rolledBack = true
	if err := current.tx.Rollback(); err != nil {
		db.Log.Log("err", err, "error rolling back transaction")
	}
}

// closeRows closes rows and logs, rather than returns, any error from Close. A nil rows is ignored, so
// this is safe to defer even on paths where the query never produced a result set.
func (db *Impl) closeRows(rows *sql.Rows) {
	if rows != nil {
		if err := rows.Close(); err != nil {
			db.Log.Log("err", err, "error closing rows")
		}
	}
}

// now returns the current time normalized by ConvertToDBTime (UTC, truncated to microseconds). The value
// can therefore be written to Postgres without further loss of precision.
func now() time.Time {
	current := time.Now()
	return ConvertToDBTime(current)
}

// ConvertToDBTime converts the given time to a time that can be reliably represented by Postgres. It does
// two things:
//   - It moves t to UTC, so no time-zone information has to round-trip through the database.
//   - It drops precision below one microsecond, because Postgres date/time types resolve only to
//     microseconds while time.Time carries nanoseconds.
func ConvertToDBTime(t time.Time) time.Time {
	return t.UTC().Truncate(time.Microsecond)
}

// placeholder renders the Postgres positional parameter for position num, e.g. 3 -> "$3".
func placeholder(num int) string {
	return "$" + strconv.FormatInt(int64(num), 10)
}

// generatePlaceholders returns num consecutive Postgres positional parameters starting at start, joined by
// commas. For example, (3, 2) yields "$3,$4", and num == 0 yields "".
func generatePlaceholders(start, num int) string {
	parts := make([]string, 0, num)
	for next := start; len(parts) < num; next++ {
		parts = append(parts, "$"+strconv.Itoa(next))
	}
	return strings.Join(parts, ",")
}

// TimeWindow is an optional half-open time range [Start, End) used by timeWindowTest to build a SQL
// constraint. Either bound may be nil to leave that side of the range open.
type TimeWindow struct {
	Start *time.Time
	End   *time.Time
}

// membershipTest builds a SQL predicate asserting that fieldName equals one of values. It returns the
// predicate text, the matching query arguments, and an error.
//
// Parameters are numbered from placeholderStart. A single value produces "<field> = $N"; several values
// produce "<field> in ($N,$N+1,...)". It returns an error if values is empty or longer than inClauseLimit.
// fieldName is interpolated verbatim, so it must be a trusted column name, never user input.
func membershipTest(fieldName string, values []string, placeholderStart int) (string, []interface{}, error) {
	switch n := len(values); {
	case n == 0:
		return "", nil, errors.Errorf("not enough values to do a membership test on field %s", fieldName)
	case n > inClauseLimit:
		return "", nil, errors.Errorf("too many arguments provided for IN on %s: %d", fieldName, n)
	}

	args := make([]interface{}, 0, len(values))
	for _, v := range values {
		args = append(args, v)
	}

	if len(args) == 1 {
		return fmt.Sprintf("%s = %s", fieldName, placeholder(placeholderStart)), args, nil
	}
	return fmt.Sprintf("%s in (%s)", fieldName, generatePlaceholders(placeholderStart, len(args))), args, nil
}

// timeWindowTest builds a SQL predicate restricting fieldName to the half-open range [window.Start,
// window.End). Parameters are numbered from placeholderCounter, and each bound is converted to UTC.
//
// A present Start contributes "<field> >= $N" and a present End contributes "<field> < $M"; the two are
// joined with " AND ". A nil window, or one with both bounds nil, yields "" and nil args. fieldName is
// interpolated verbatim, so it must be a trusted column name.
func timeWindowTest(fieldName string, window *TimeWindow, placeholderCounter int) (string, []interface{}) {
	if window == nil || (window.Start == nil && window.End == nil) {
		return "", nil
	}

	clauses := make([]string, 0, 2)
	args := make([]interface{}, 0, 2)
	// addBound appends one comparison against the next positional parameter and records its argument.
	addBound := func(op string, value interface{}) {
		clauses = append(clauses, fmt.Sprintf("%s %s %s", fieldName, op, placeholder(placeholderCounter)))
		args = append(args, value)
		placeholderCounter++
	}

	if window.Start != nil {
		addBound(">=", window.Start.UTC())
	}
	if window.End != nil {
		addBound("<", window.End.UTC())
	}

	return strings.Join(clauses, " AND "), args
}
