package common

import (
	"context"
	"database/sql"
	"fmt"
	"reflect"
	"regexp"
	"strings"
	"time"

	"code.justin.tv/chat/golibs/errx"
	"code.justin.tv/chat/golibs/logx"
	"code.justin.tv/devrel/devsite-rbac/internal/errorutil"
	"github.com/Masterminds/squirrel"
	uuid "github.com/satori/go.uuid"
)

var (
	// PSQL is a squirrel statement builder that renders bind parameters in
	// PostgreSQL's positional style: write placeholders as "?" and they are
	// emitted as "$1", "$2", and so on.
	PSQL = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
)

// NewUUID returns a newly generated version 4 UUID rendered as a string.
func NewUUID() string {
	id := uuid.NewV4()
	return id.String()
}

// TimeNowStr returns the current UTC time formatted as "YYYY-MM-DD HH:MM"
// (minute precision, no seconds and no zone suffix).
func TimeNowStr() string {
	const layout = "2006-01-02 15:04"
	now := time.Now().UTC()
	return now.Format(layout)
}

// Paginate applies LIMIT/OFFSET to q. A zero limit is treated as "unspecified"
// and replaced by the default page size of 20; offset is applied as given.
func Paginate(q squirrel.SelectBuilder, limit, offset uint64) squirrel.SelectBuilder {
	pageSize := limit
	if pageSize == 0 {
		pageSize = 20 // default page size
	}
	return q.Limit(pageSize).Offset(offset)
}

// TimingStats returns a callback for methods wrapped with errxer --timings that
// records a DB method's duration under the metric "db.<method>.<status>".
// status is "norows" when errorutil.IsErrNoRows(err) reports true, "error" for
// any other non-nil err, and "success" otherwise. A failure to emit the stat is
// logged and otherwise ignored.
func TimingStats(stats timingStatter) func(context.Context, time.Duration, string, error) {
	record := func(ctx context.Context, elapsed time.Duration, method string, err error) {
		var outcome string
		switch {
		case errorutil.IsErrNoRows(err):
			outcome = "norows"
		case err != nil:
			outcome = "error"
		default:
			outcome = "success"
		}

		metric := "db." + method + "." + outcome
		if statErr := stats.TimingDuration(metric, elapsed, 1.0); statErr != nil {
			logx.Error(ctx, statErr, logx.Fields{"stat": metric, "db_method": method})
		}
	}
	return record
}

// timingStatter is the single capability TimingStats needs from a stats
// client: recording a duration under a metric name. rate is passed through
// unchanged (TimingStats always passes 1.0).
type timingStatter interface {
	TimingDuration(name string, d time.Duration, rate float32) error
}

// CountOverAs returns a select expression that attaches the total row count of
// the whole result set (a count(*) OVER() window) to every row, aliased as
// dbField. It is meant as an extra column for paginated queries, e.g. one
// scanned into a struct field tagged `db:"_total"`.
// dbField is spliced into the SQL verbatim, so it must be a trusted identifier.
func CountOverAs(dbField string) string {
	return fmt.Sprintf("count(*) OVER() as %s", dbField)
}

// FirstRowDBField finds the first element of list, reads the field carrying the
// given `db` tag, and assigns its value to the variable target points to.
//
// list must be a slice or array (or a pointer to one) whose elements are
// structs or pointers to structs. target must be a non-nil pointer; it is
// validated before list is inspected, so misuse is reported even when list is
// empty. An empty list leaves *target untouched and returns nil.
func FirstRowDBField(list interface{}, dbField string, target interface{}) error {
	ev := reflect.ValueOf(target)
	if ev.Kind() != reflect.Ptr {
		return errx.New("assignment interface must be pointer")
	}
	if ev.IsNil() {
		// Elem() of a nil pointer is an invalid Value that can never be set.
		return errx.New("assignment interface must be a non-nil pointer")
	}

	v := reflect.ValueOf(list)
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
		return errx.New("must be slice or array")
	}
	if v.Len() == 0 {
		return nil
	}
	return assignDBField(v.Index(0), ev.Elem(), dbField)
}

// FirstRowInt32DBField returns the int32 value of the field tagged
// `db:"<dbField>"` on the first element of list (see FirstRowDBField).
// It returns 0 when list is empty or no field carries the tag. On failure it
// also returns 0, and the error is logged instead of returned.
func FirstRowInt32DBField(list interface{}, dbField string) int32 {
	var value int32
	if err := FirstRowDBField(list, dbField, &value); err != nil {
		logFirstRowFieldError("FirstRowInt32DBField", err)
		return 0
	}
	return value
}

// FirstRowInt64DBField returns the int64 value of the field tagged
// `db:"<dbField>"` on the first element of list (see FirstRowDBField).
// It returns 0 when list is empty or no field carries the tag. On failure it
// also returns 0, and the error is logged instead of returned.
func FirstRowInt64DBField(list interface{}, dbField string) int64 {
	var value int64
	if err := FirstRowDBField(list, dbField, &value); err != nil {
		logFirstRowFieldError("FirstRowInt64DBField", err)
		return 0
	}
	return value
}

// FirstRowUInt64DBField returns the uint64 value of the field tagged
// `db:"<dbField>"` on the first element of list (see FirstRowDBField).
// It returns 0 when list is empty or no field carries the tag. On failure it
// also returns 0, and the error is logged instead of returned.
func FirstRowUInt64DBField(list interface{}, dbField string) uint64 {
	var value uint64
	if err := FirstRowDBField(list, dbField, &value); err != nil {
		logFirstRowFieldError("FirstRowUInt64DBField", err)
		return 0
	}
	return value
}

// logFirstRowFieldError logs a FirstRowDBField failure on behalf of the typed
// helper named fn, wrapping err with "<fn> failure".
func logFirstRowFieldError(fn string, err error) {
	// NOTE(review): the 5s timeout presumably bounds any work logx.Error does
	// with ctx; kept as-is because logx's use of ctx is not visible here.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	logx.Error(ctx, errx.Wrap(err, fn+" failure"))
}

// assignDBField copies the value of the field tagged `db:"<dbField>"` from v
// (a struct, or a pointer to one) into ev, which must be a valid, settable
// value. It returns nil without assigning anything when no field carries the tag.
//
// When the field's type is assignable to ev's type, the field is set directly.
// When the kinds match but the types are distinct (e.g. a `type Count int32`
// field read into an int32), the field is converted. Every other mismatch is
// reported as an error rather than letting reflect.Value.Set panic.
func assignDBField(v, ev reflect.Value, dbField string) error {
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return errx.New("slice items must be a struct")
	}
	t := v.Type()

	// look for the field (of any type) carrying the requested db tag
	for i := 0; i < v.NumField(); i++ {
		if t.Field(i).Tag.Get("db") != dbField {
			continue
		}
		field := v.Field(i)

		if !ev.IsValid() || !ev.CanSet() {
			return errx.New("assignment target must be a settable value")
		}
		if ev.Kind() != reflect.Interface && field.Kind() != ev.Kind() {
			return errx.New(fmt.Sprintf("wanted %s, got %s", field.Kind(), ev.Kind()))
		}
		if !field.CanInterface() {
			// reflect refuses to copy values read through unexported fields.
			return errx.New(fmt.Sprintf("db field %q is unexported", dbField))
		}

		switch {
		case field.Type().AssignableTo(ev.Type()):
			ev.Set(field)
		case ev.Kind() != reflect.Interface && field.Type().ConvertibleTo(ev.Type()):
			ev.Set(field.Convert(ev.Type()))
		default:
			return errx.New(fmt.Sprintf("cannot assign %s to %s", field.Type(), ev.Type()))
		}
		return nil
	}
	return nil
}

// NewSQLNullString wraps value in a sql.NullString that is always marked valid
// (non-NULL), including when value is the empty string.
func NewSQLNullString(value string) sql.NullString {
	var ns sql.NullString
	ns.String, ns.Valid = value, true
	return ns
}

// NewSQLNullInt64 wraps value in a sql.NullInt64 that is always marked valid
// (non-NULL), including when value is zero.
func NewSQLNullInt64(value int64) sql.NullInt64 {
	var ni sql.NullInt64
	ni.Int64, ni.Valid = value, true
	return ni
}

// identifierStripRe matches runs of characters that are not ASCII letters or
// digits. It is compiled once at package initialization instead of on every
// Identifier call.
var identifierStripRe = regexp.MustCompile("[^a-zA-Z0-9]+")

// Identifier converts a company title into an identifier: it strips every
// character that is not an ASCII letter or digit (including spaces,
// punctuation and non-ASCII letters) and lowercases the result,
// e.g. "Acme, Inc." -> "acmeinc".
func Identifier(name string) string {
	return strings.ToLower(identifierStripRe.ReplaceAllString(name, ""))
}
