package common

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"code.justin.tv/devrel/devsite-rbac/config"
)

var (
	// ctxBkg is the root context handed to the helpers under test that take a
	// context.Context argument.
	ctxBkg = context.Background()
)

// TestPostgresConnStr checks that PostgresConnStr renders the RBAC Postgres
// settings as a space-separated key=value string, with PGSSLEnabled=true
// rendered as sslmode=require and a statement_timeout=2000 suffix.
func TestPostgresConnStr(t *testing.T) {
	cfg := &config.RBAC{
		PGHost:       "hhh",
		PGPort:       "333",
		PGUser:       "uuu",
		PGPassword:   "ppp",
		PGDBName:     "dbn",
		PGSSLEnabled: true,
	}
	const want = "host=hhh port=333 user=uuu password=ppp dbname=dbn sslmode=require statement_timeout=2000"

	require.Equal(t, want, PostgresConnStr(cfg))
}

// TestFirstRowIntDBField exercises FirstRowDBField and FirstRowInt32DBField,
// which resolve a field by its `db` struct tag on the first element of a
// slice of structs. It covers a typed int32 destination, a typed uint64
// destination, and an untyped interface{} destination.
func TestFirstRowIntDBField(t *testing.T) {
	// The second row deliberately carries different values: if a helper read
	// any row other than the first, the assertions below would catch it.
	list := []*Yolo{
		{Name: "Yo", XXX_Total: 666, AnotherTotal: 777},
		{Name: "Lo", XXX_Total: 111, AnotherTotal: 222},
	}

	// Typed destination matching the field's declared int32 type.
	var total int32
	err := FirstRowDBField(list, "_total", &total)
	require.NoError(t, err)
	require.Equal(t, int32(666), total)

	// Convenience wrapper for int32 fields.
	require.Equal(t, int32(666), FirstRowInt32DBField(list, "_total"))

	// Typed destination matching the field's declared uint64 type.
	var anotherTotal uint64
	err = FirstRowDBField(list, "a_total", &anotherTotal)
	require.NoError(t, err)
	require.Equal(t, uint64(777), anotherTotal)

	// Untyped destination: format with %d so the assertion does not depend on
	// which concrete integer type ends up stored in the interface.
	var interfaceTotal interface{}
	err = FirstRowDBField(list, "a_total", &interfaceTotal)
	require.NoError(t, err)
	require.Equal(t, "777", fmt.Sprintf("%d", interfaceTotal))
}

func TestNewUUID(t *testing.T) {
	require.NotEmpty(t, NewUUID())
	require.NotEqual(t, NewUUID(), NewUUID(), "new uuids are unique")
}

func TestTimeNowStr(t *testing.T) {
	require.NotEmpty(t, TimeNowStr())
	currentYear := fmt.Sprintf("%d", time.Now().Year())
	require.Equal(t, currentYear, TimeNowStr()[:4])
}

// TestTimingStats checks that the callback returned by TimingStats reports a
// "db.<method>.<outcome>" metric name. The outcome is "success" for a nil
// error, "norows" for sql.ErrNoRows, and "error" for a generic error.
// The steps run in order against one fake, and each assertion inspects the
// most recently recorded metric.
func TestTimingStats(t *testing.T) {
	stats := &fakeTimingStatter{}
	record := TimingStats(stats)
	const elapsed = time.Duration(666)

	steps := []struct {
		err  error
		want string
	}{
		{nil, "db.MyMethod.success"},
		{errors.New("yolo crash"), "db.MyMethod.error"},
		{sql.ErrNoRows, "db.MyMethod.norows"},
	}
	for _, step := range steps {
		record(ctxBkg, elapsed, "MyMethod", step.err)
		require.Equal(t, step.want, stats.lastMetricName)
	}
}

func TestIdentifier(t *testing.T) {
	for _, testCase := range []struct {
		Name     string
		Input    string
		Expected string
	}{
		{
			Name:     "hello_world",
			Input:    "h!i! t?h?e?r?e?",
			Expected: "hithere",
		},
		{
			Name:     "malicious",
			Input:    "1234::admin::all?//ignorerest...",
			Expected: "1234adminallignorerest",
		},
	} {
		t.Run(testCase.Name, func(t *testing.T) {
			actual := Identifier(testCase.Input)
			if actual != testCase.Expected {
				t.Errorf("expected %q, got %q", testCase.Expected, actual)
			}
		})
	}
}

//
// Test helpers
//

// Yolo is a fixture row type for TestFirstRowIntDBField. That test looks up
// its fields by their `db` tag names ("_total", "a_total"), reading them into
// int32, uint64, and interface{} destinations.
// NOTE(review): field names, order, and types are kept exactly as-is. Confirm
// whether the FirstRow*DBField helpers depend on any of them before reshaping
// this struct.
type Yolo struct {
	Name         string `db:"name"`
	XXX_Total    int32  `db:"_total"`
	AnotherTotal uint64 `db:"a_total"`
}

// fakeTimingStatter is a test double passed to TimingStats. It records the
// metric name of the most recent TimingDuration call so tests can assert on it.
type fakeTimingStatter struct {
	lastMetricName string // metric argument of the latest TimingDuration call
}

// TimingDuration records name as the most recently reported metric. The
// duration and sample rate are ignored, and it always returns nil.
func (f *fakeTimingStatter) TimingDuration(name string, _ time.Duration, _ float32) error {
	f.lastMetricName = name
	return nil
}
