package db

import (
	"database/sql"
	"fmt"
	"io"
	"math/rand"
	"net"
	"strconv"
	"time"

	"code.justin.tv/common/chitin/dbtrace"

	"golang.org/x/net/context"
)

type (
	// Result is the result of a database update. It is a defined interface type
	// with exactly the method set of sql.Result (not a true `=` alias), declared
	// here for consistency with the package's other wrapper types. Any
	// sql.Result value can be assigned to it.
	Result sql.Result

	// sqlResult is an unexported interface type with the same method set as
	// sql.Result.
	sqlResult sql.Result
)

// Rows is the result of a query. It mirrors the stdlib's *sql.Rows (Next, Scan,
// Columns, Err, Close), but is wrapped here so that the package can emit
// events, release semaphores, etc. around those calls.
type Rows interface {
	// Next advances to the next row, reporting whether one is available.
	Next() bool
	// Scan copies the current row's columns into dest.
	Scan(dest ...interface{}) error
	// Columns returns the column names of the result set.
	Columns() ([]string, error)
	// Err returns the error, if any, encountered during iteration.
	Err() error

	// Close releases the result set.
	io.Closer
}

// rowsImpl is the package's Rows implementation. Each method forwards to the
// wrapped rows; Next additionally emits a timing event, and Close invokes
// releaseFunc the first time it runs.
type rowsImpl struct {
	// name is passed as the name argument of the events emitted for these rows.
	name        string
	// evManager receives the "nextrow" timing events emitted by Next.
	evManager   *eventManager
	// releaseFunc is invoked at most once, by the first Close (Next calls Close
	// itself once the underlying Next returns false).
	releaseFunc func()
	// closed records whether releaseFunc has already been invoked.
	closed      bool

	// rows is the wrapped result set every method delegates to.
	// NOTE(review): sqlRows is declared elsewhere; presumably it abstracts *sql.Rows.
	rows sqlRows
}

// Compile-time check that *rowsImpl satisfies Rows.
var _ Rows = (*rowsImpl)(nil)

// Close closes the wrapped rows, like the stdlib's implementation, and returns
// whatever error doing so produced.
//
// The wrapped rows are closed on every call, but releaseFunc is invoked only
// on the first one; the closed flag is set right after it runs.
func (r *rowsImpl) Close() error {
	closeErr := r.rows.Close()
	if r.closed {
		return closeErr
	}
	r.releaseFunc()
	r.closed = true
	return closeErr
}

// Next acts just like the stdlib's implementation: it advances to the next row
// and reports whether one is available.
//
// Every call emits a "nextrow" event carrying the time spent in the wrapped
// Next. Once the wrapped Next returns false, the rows are closed immediately,
// so releaseFunc runs without waiting for the caller's own Close.
func (r *rowsImpl) Next() bool {
	start := time.Now()
	hasRow := r.rows.Next()
	r.evManager.emitDBEvent("nextrow", r.name, time.Since(start))
	if !hasRow {
		// Next has no way to return the close error, so it is deliberately
		// dropped; iteration errors remain available through Err, which
		// forwards to the wrapped rows.
		_ = r.Close()
	}
	return hasRow
}

// Err reports the error, if any, hit while iterating. It forwards to the
// wrapped rows, mirroring the stdlib's (*sql.Rows).Err.
func (r *rowsImpl) Err() error {
	iterErr := r.rows.Err()
	return iterErr
}

// Scan copies the current row's columns into dest. It forwards to the wrapped
// rows, mirroring the stdlib's (*sql.Rows).Scan.
func (r *rowsImpl) Scan(dest ...interface{}) error {
	if err := r.rows.Scan(dest...); err != nil {
		return err
	}
	return nil
}

// Columns returns the result set's column names. It forwards to the wrapped
// rows, mirroring the stdlib's (*sql.Rows).Columns.
func (r *rowsImpl) Columns() ([]string, error) {
	names, err := r.rows.Columns()
	return names, err
}

// Row is the result of a query on a single row. It mirrors the stdlib's
// *sql.Row, but is wrapped here so that the package can emit events, release
// semaphores, etc. around the scan.
type Row interface {
	// Scan copies the row's columns into dest, like (*sql.Row).Scan.
	Scan(dest ...interface{}) error
}

// rowImpl is the package's Row implementation. It wraps a single-row result,
// emits scan events, and invokes releaseFunc at most once.
type rowImpl struct {
	// name is passed as the name argument of the events emitted by Scan.
	name        string
	// evManager receives the "queryrow.scan.*" events emitted by Scan.
	evManager   *eventManager
	// releaseFunc is invoked at most once, by doClose.
	releaseFunc func()
	// closed records whether releaseFunc has already been invoked.
	closed      bool

	// row is the wrapped single-row result Scan delegates to; it is not
	// consulted when err is non-nil.
	// NOTE(review): sqlRow is declared elsewhere; presumably it abstracts *sql.Row.
	row sqlRow
	// err, when non-nil, is returned by Scan instead of scanning row.
	// NOTE(review): presumably set by QueryRow elsewhere in the package — confirm.
	err error
}

// Compile-time check that *rowImpl satisfies Row.
var _ Row = (*rowImpl)(nil)

// doClose invokes releaseFunc the first time it is called and marks the row
// closed; any later call is a no-op.
func (r *rowImpl) doClose() {
	if r.closed {
		return
	}
	r.releaseFunc()
	r.closed = true
}

// Scan acts just like the stdlib's implementation: it copies the matched row's
// columns into dest.
//
// Whatever the outcome, the row's resources are released (via doClose) before
// Scan returns. If r.err is set, it is returned without touching the wrapped
// row. Otherwise either a "queryrow.scan.error" or a "queryrow.scan.success"
// event is emitted, carrying the time spent in the wrapped Scan.
func (r *rowImpl) Scan(dest ...interface{}) error {
	// doClose is a no-op once the row is closed, so deferring it unconditionally
	// is safe.
	defer r.doClose()

	if r.err != nil {
		// No need to emit an error here, since it was already emitted when this field was set.
		// See the end of the QueryRow function in db.go.
		return r.err
	}

	start := time.Now()
	err := r.row.Scan(dest...)
	event := "queryrow.scan.success"
	if err != nil {
		event = "queryrow.scan.error"
	}
	r.evManager.emitDBEvent(event, r.name, time.Since(start))

	return err
}

// buildDB opens a *sql.DB for conf.driverName and wraps it in an sqlDB.
//
// For driverName "mysql" the DSN has the form user:password@(host:port)/dbname.
// Any other driver name is treated as postgres and gets a libpq-style
// key=value connection string (sslmode=disable, binary_parameters=yes). In that
// case the password is included only when conf.password is non-empty.
//
// The pool's maximum number of open connections is set from conf.maxOpenConns.
// The only error returned is the one from sql.Open.
func buildDB(conf *config) (sqlDB, error) {
	var connect string
	if conf.driverName == "mysql" {
		connect = buildMySQLDSN(conf)
	} else {
		// Default to postgres connection format.
		connect = buildPostgresDSN(conf)
	}

	db, err := sql.Open(conf.driverName, connect)
	if err != nil {
		return nil, err
	}
	db.SetMaxOpenConns(conf.maxOpenConns)

	return &sqlDBImpl{DB: db}, nil
}

// buildMySQLDSN returns a DSN of the form user:password@(host:port)/dbname.
// The address is built with net.JoinHostPort, so IPv6 hosts are bracketed
// ("[::1]:3306") instead of producing an unparseable "::1:3306".
func buildMySQLDSN(conf *config) string {
	addr := net.JoinHostPort(conf.host, strconv.Itoa(conf.port))
	return fmt.Sprintf("%s:%s@(%s)/%s", conf.user, conf.password, addr, conf.dbName)
}

// buildPostgresDSN returns a libpq-style key=value connection string. Every
// string value is single-quoted and escaped. Without quoting, an empty value
// would swallow the following "key=" token, and a value with whitespace,
// backslashes or quotes (common in passwords) would be split or altered by the
// connection-string parser.
func buildPostgresDSN(conf *config) string {
	dsn := fmt.Sprintf(
		"host=%s port=%d user=%s dbname=%s sslmode=disable binary_parameters=yes",
		quotePostgresValue(conf.host),
		conf.port,
		quotePostgresValue(conf.user),
		quotePostgresValue(conf.dbName),
	)
	if conf.password != "" {
		dsn += " password=" + quotePostgresValue(conf.password)
	}
	return dsn
}

// quotePostgresValue wraps v in single quotes, backslash-escaping any backslash
// or single quote inside it, as required for quoted libpq connection values.
// It scans bytes, which is safe for UTF-8 input: both escaped characters are
// ASCII and never occur inside multi-byte sequences.
func quotePostgresValue(v string) string {
	quoted := make([]byte, 0, len(v)+2)
	quoted = append(quoted, '\'')
	for i := 0; i < len(v); i++ {
		if c := v[i]; c == '\\' || c == '\'' {
			quoted = append(quoted, '\\')
		}
		quoted = append(quoted, v[i])
	}
	quoted = append(quoted, '\'')
	return string(quoted)
}

// jitter scales d by a factor drawn uniformly from [low, high) using
// math/rand's shared source. For example, jitter(d, 0.8, 1.2) lands within
// about ±20% of d. The scaled value is truncated toward zero when converted
// back to a Duration.
func jitter(d time.Duration, low, high float64) time.Duration {
	base := float64(d)
	floor := base * low
	spread := base * (high - low) * rand.Float64()
	return time.Duration(floor + spread)
}

// dbHostport returns conf's database address as "host:port". It is built with
// net.JoinHostPort, so IPv6 hosts come out bracketed (e.g. "[::1]:5432").
func dbHostport(conf *config) string {
	port := strconv.Itoa(conf.port)
	return net.JoinHostPort(conf.host, port)
}

// traceStart sends the dbtrace event marking the start of query. The event is
// tagged with the database address (host:port), database name and user taken
// from conf.
func traceStart(ctx context.Context, conf *config, query string) {
	addr := dbHostport(conf)
	dbtrace.BasicRequestHeadPrepared(ctx, addr, conf.dbName, conf.user, query)
}

// traceEnd sends the dbtrace event marking completion of the query traced on
// ctx; it is the counterpart of traceStart.
func traceEnd(ctx context.Context) {
	dbtrace.BasicResponseHeadReceived(ctx)
}
