package main

import (
	"database/sql"
	"github.com/goadesign/goa"
	"github.com/jmoiron/sqlx"
	"log"
	"sync"
	"time"
)

// DB wraps a sqlx database handle and records goa metrics (duration, and for
// most methods a success/error count) for every query issued through it.
// NOTE(review): the embedded sync.Mutex is not used by any method in this
// section, but it does promote Lock/Unlock onto *DB — confirm whether callers
// rely on it before removing.
type DB struct {
	sync.Mutex
	conn *sqlx.DB
}

// Tx wraps a sqlx transaction and tags the metrics of every operation it runs
// with namespace. Methods Tx does not define itself are promoted from the
// embedded *sqlx.Tx. Note that Tx.Rollback takes an error argument and so
// shadows the embedded Rollback() with a different signature.
type Tx struct {
	*sqlx.Tx
	namespace string // metric namespace shared by all operations on this transaction
}

// NewDB opens a "postgres" database handle for connString and verifies that
// the server is reachable. It terminates the process via log.Fatalf if either
// step fails.
//
// sqlx.Open (like sql.Open) may only validate its arguments without dialing
// the server, so the explicit Ping is what makes an unreachable or
// misconfigured database fail fast here instead of on the first query.
// NOTE(review): assumes a "postgres" driver is registered elsewhere in the
// package (e.g. a blank driver import) — confirm.
func NewDB(connString string) *DB {
	conn, err := sqlx.Open("postgres", connString)
	if err != nil {
		log.Fatalf("Can't connect to db: %v", err)
	}
	if err = conn.Ping(); err != nil {
		log.Fatalf("Can't connect to db: %v", err)
	}
	return &DB{conn: conn}
}

// dbMetricKey names the metric series for a single query (queryName) within
// a namespace.
type dbMetricKey struct {
	namespace, queryName string
}

// buildKey returns the metric path "db", namespace, queryName, followed by any
// suffixes, in that order. Each call returns a freshly allocated slice.
func (key dbMetricKey) buildKey(suffixes ...string) []string {
	// TODO: sanitize metric names
	parts := make([]string, 0, 3+len(suffixes))
	parts = append(parts, "db", key.namespace, key.queryName)
	return append(parts, suffixes...)
}

// timeAndCount runs f, records its duration under the key's metric name, and
// increments either the "success" or the "error" counter. sql.ErrNoRows and
// sql.ErrTxDone are deliberately counted as successes, not errors.
func (key dbMetricKey) timeAndCount(f func() error) {
	began := time.Now()
	result := f()
	goa.MeasureSince(key.buildKey(), began)

	outcome := "error"
	switch result {
	case nil, sql.ErrNoRows, sql.ErrTxDone:
		outcome = "success"
	}
	goa.IncrCounter(key.buildKey(outcome), 1.0)
}

// time runs f and records only its duration under the key's metric name. No
// success/error counter is touched.
func (key dbMetricKey) time(f func()) {
	t0 := time.Now()
	f()
	metricName := key.buildKey()
	goa.MeasureSince(metricName, t0)
}

// Get runs query with args and scans the single resulting row into dest.
// Duration and success/error are recorded under namespace/queryName.
func (db *DB) Get(namespace, queryName string, dest interface{}, query string, args ...interface{}) (err error) {
	fetch := func() error {
		err = db.conn.Get(dest, query, args...)
		return err
	}
	dbMetricKey{namespace: namespace, queryName: queryName}.timeAndCount(fetch)
	return err
}

// QueryRowx runs query with args and returns the resulting row. Only the
// call's duration is recorded under namespace/queryName. Success or failure is
// not counted here, because a *sqlx.Row defers its error until it is scanned.
func (db *DB) QueryRowx(namespace, queryName, query string, args ...interface{}) (row *sqlx.Row) {
	timer := dbMetricKey{namespace: namespace, queryName: queryName}
	timer.time(func() { row = db.conn.QueryRowx(query, args...) })
	return row
}

// Queryx runs query with args and returns the resulting rows. Duration and
// success/error are recorded under namespace/queryName. The returned rows are
// not closed here; the caller must close them.
func (db *DB) Queryx(namespace, queryName, query string, args ...interface{}) (rows *sqlx.Rows, err error) {
	run := func() error {
		rows, err = db.conn.Queryx(query, args...)
		return err
	}
	dbMetricKey{namespace: namespace, queryName: queryName}.timeAndCount(run)
	return rows, err
}

// Select runs query with args and scans all resulting rows into dest.
// Duration and success/error are recorded under namespace/queryName.
func (db *DB) Select(namespace, queryName string, dest interface{}, query string, args ...interface{}) (err error) {
	metric := dbMetricKey{queryName: queryName, namespace: namespace}
	metric.timeAndCount(func() error {
		err = db.conn.Select(dest, query, args...)
		return err
	})
	return err
}

// GetNewTx begins a new transaction on the database connection. Duration and
// success/error are recorded under namespace as "startTransaction". The
// returned Tx tags all of its own metrics with namespace. If the transaction
// cannot be started, tx is nil and err is the error from Beginx.
func (db *DB) GetNewTx(namespace string) (tx *Tx, err error) {
	var inner *sqlx.Tx
	begin := func() error {
		inner, err = db.conn.Beginx()
		return err
	}
	dbMetricKey{namespace: namespace, queryName: "startTransaction"}.timeAndCount(begin)
	if err != nil {
		return nil, err
	}
	return &Tx{Tx: inner, namespace: namespace}, nil
}

// Commit commits the transaction. Duration and success/error are recorded
// under the transaction's namespace as "commitTransaction".
func (tx *Tx) Commit() (err error) {
	commit := func() error {
		err = tx.Tx.Commit()
		return err
	}
	dbMetricKey{namespace: tx.namespace, queryName: "commitTransaction"}.timeAndCount(commit)
	return err
}

// Rollback rolls back the transaction, but only when err is non-nil. A nil
// err is treated as "nothing went wrong" and the call is a no-op.
//
// The rollback is timed and counted under the transaction's namespace as
// "rollbackTransaction". Because this method returns nothing, a rollback
// failure is logged together with the error that triggered it instead of
// being silently dropped. sql.ErrTxDone is not logged: it only means the
// transaction had already been committed or rolled back.
func (tx *Tx) Rollback(err error) {
	if err == nil { // only rollback on errors
		return
	}
	key := dbMetricKey{tx.namespace, "rollbackTransaction"}
	// Keep the rollback result separate so the triggering err is preserved.
	var rbErr error
	key.timeAndCount(func() error {
		rbErr = tx.Tx.Rollback()
		return rbErr
	})
	if rbErr != nil && rbErr != sql.ErrTxDone {
		log.Printf("rollback of %q transaction failed: %v (rollback triggered by: %v)", tx.namespace, rbErr, err)
	}
}

// Exec runs a statement within the transaction and returns its result.
// Duration and success/error are recorded under the transaction's namespace
// and queryName.
func (tx *Tx) Exec(queryName, query string, args ...interface{}) (res sql.Result, err error) {
	exec := func() error {
		res, err = tx.Tx.Exec(query, args...)
		return err
	}
	dbMetricKey{namespace: tx.namespace, queryName: queryName}.timeAndCount(exec)
	return res, err
}

// QueryRow runs query within the transaction and returns the resulting
// *sql.Row. Only the duration is recorded, under the transaction's namespace
// and queryName. Any query error surfaces when the row is scanned.
func (tx *Tx) QueryRow(queryName, query string, args ...interface{}) (row *sql.Row) {
	timer := dbMetricKey{namespace: tx.namespace, queryName: queryName}
	timer.time(func() { row = tx.Tx.QueryRow(query, args...) })
	return row
}

// Get runs query within the transaction and scans the single resulting row
// into dest. Duration and success/error are recorded under the transaction's
// namespace and queryName.
func (tx *Tx) Get(queryName string, dest interface{}, query string, args ...interface{}) (err error) {
	metric := dbMetricKey{queryName: queryName, namespace: tx.namespace}
	metric.timeAndCount(func() error {
		err = tx.Tx.Get(dest, query, args...)
		return err
	})
	return err
}
