package observability

import (
	"errors"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"

	"code.justin.tv/eventbus/controlplane/internal/db"
)

// lastCallLogger is a test double used as observabilityData.logger. It produces
// no output; each method only overwrites the message and fields recorded for its
// own level, so after a call to log the test can inspect exactly what was emitted.
type lastCallLogger struct {
	// Written by Debug.
	lastDebugMsg    string
	lastDebugFields []zap.Field

	// Written by Error.
	lastErrorMsg    string
	lastErrorFields []zap.Field

	// Written by Warn.
	lastWarnMsg    string
	lastWarnFields []zap.Field
}

// Debug records msg and fields as the most recent debug-level call, replacing
// whatever an earlier Debug call stored.
func (l *lastCallLogger) Debug(msg string, fields ...zap.Field) {
	l.lastDebugMsg, l.lastDebugFields = msg, fields
}

// Error records msg and fields as the most recent error-level call, replacing
// whatever an earlier Error call stored.
func (l *lastCallLogger) Error(msg string, fields ...zap.Field) {
	l.lastErrorMsg, l.lastErrorFields = msg, fields
}

// Warn records msg and fields as the most recent warn-level call, replacing
// whatever an earlier Warn call stored.
func (l *lastCallLogger) Warn(msg string, fields ...zap.Field) {
	l.lastWarnMsg, l.lastWarnFields = msg, fields
}

func TestObservabilityLogging(t *testing.T) {
	t.Run("Info", func(t *testing.T) {
		logger := &lastCallLogger{}
		tsEnd := time.Now()
		tsStart := tsEnd.Add(-2 * time.Second)
		fields := []zap.Field{zap.String("field1", "val1"), zap.Int("field2", 123)}
		obsData := &observabilityData{
			action: "SomeDatabaseOperation",
			start:  tsStart,
			end:    tsEnd,
			fields: fields,
			logger: logger,
			err:    nil,
		}
		log(obsData)
		assert.Nil(t, logger.lastErrorFields)
		assert.Empty(t, logger.lastErrorMsg)
		assert.NotNil(t, logger.lastDebugFields)
		assert.NotEmpty(t, logger.lastDebugMsg)
		assert.True(t, containsKeyOfType("start", zapcore.TimeType, logger.lastDebugFields))
		assert.True(t, containsKeyOfType("end", zapcore.TimeType, logger.lastDebugFields))
		assert.True(t, containsKeyOfType("duration", zapcore.DurationType, logger.lastDebugFields))
		assert.True(t, containsKeyOfType("action", zapcore.StringType, logger.lastDebugFields))
		assert.True(t, containsKeyOfType("field1", zapcore.StringType, logger.lastDebugFields))
		assert.True(t, containsKeyOfType("field2", zapcore.Int64Type, logger.lastDebugFields))
		assert.True(t, containsKeyOfType("query", zapcore.NamespaceType, logger.lastDebugFields))
		assert.False(t, containsKeyOfType("error", zapcore.ErrorType, logger.lastDebugFields))
	})
	t.Run("Error", func(t *testing.T) {
		logger := &lastCallLogger{}
		tsEnd := time.Now()
		tsStart := tsEnd.Add(-2 * time.Second)
		fields := []zap.Field{zap.String("field1", "val1"), zap.Int("field2", 123), zap.Bool("field3", true)}
		obsData := &observabilityData{
			action: "SomeDatabaseOperation",
			start:  tsStart,
			end:    tsEnd,
			fields: fields,
			logger: logger,
			err:    errors.New("this is an error"),
		}
		log(obsData)
		assert.NotNil(t, logger.lastErrorFields)
		assert.NotEmpty(t, logger.lastErrorMsg)
		assert.Nil(t, logger.lastDebugFields)
		assert.Empty(t, logger.lastDebugMsg)
		assert.True(t, containsKeyOfType("start", zapcore.TimeType, logger.lastErrorFields))
		assert.True(t, containsKeyOfType("end", zapcore.TimeType, logger.lastErrorFields))
		assert.True(t, containsKeyOfType("duration", zapcore.DurationType, logger.lastErrorFields))
		assert.True(t, containsKeyOfType("action", zapcore.StringType, logger.lastErrorFields))
		assert.True(t, containsKeyOfType("field1", zapcore.StringType, logger.lastErrorFields))
		assert.True(t, containsKeyOfType("field2", zapcore.Int64Type, logger.lastErrorFields))
		assert.True(t, containsKeyOfType("field3", zapcore.BoolType, logger.lastErrorFields))
		assert.True(t, containsKeyOfType("query", zapcore.NamespaceType, logger.lastErrorFields))
		assert.True(t, containsKeyOfType("error", zapcore.ErrorType, logger.lastErrorFields))
	})
	t.Run("Warn", func(t *testing.T) {
		logger := &lastCallLogger{}
		tsEnd := time.Now()
		tsStart := tsEnd.Add(-2 * time.Second)
		fields := []zap.Field{zap.String("field1", "val1"), zap.Int("field2", 123), zap.Bool("field3", true)}
		obsData := &observabilityData{
			action: "SomeDatabaseOperation",
			start:  tsStart,
			end:    tsEnd,
			fields: fields,
			logger: logger,
			err:    db.ErrResourceNotFound,
		}
		log(obsData)
		assert.NotNil(t, logger.lastWarnFields)
		assert.NotEmpty(t, logger.lastWarnMsg)
		assert.Nil(t, logger.lastDebugFields)
		assert.Empty(t, logger.lastDebugMsg)
		assert.Nil(t, logger.lastErrorFields)
		assert.Empty(t, logger.lastErrorMsg)
		assert.True(t, containsKeyOfType("start", zapcore.TimeType, logger.lastWarnFields))
		assert.True(t, containsKeyOfType("end", zapcore.TimeType, logger.lastWarnFields))
		assert.True(t, containsKeyOfType("duration", zapcore.DurationType, logger.lastWarnFields))
		assert.True(t, containsKeyOfType("action", zapcore.StringType, logger.lastWarnFields))
		assert.True(t, containsKeyOfType("field1", zapcore.StringType, logger.lastWarnFields))
		assert.True(t, containsKeyOfType("field2", zapcore.Int64Type, logger.lastWarnFields))
		assert.True(t, containsKeyOfType("field3", zapcore.BoolType, logger.lastWarnFields))
		assert.True(t, containsKeyOfType("query", zapcore.NamespaceType, logger.lastWarnFields))
		assert.True(t, containsKeyOfType("error", zapcore.ErrorType, logger.lastWarnFields))
	})

}

// containsKeyOfType reports whether fields holds at least one entry whose Key
// equals key and whose Type equals typ. An entry with the right key but a
// different type does not count as a match.
func containsKeyOfType(key string, typ zapcore.FieldType, fields []zap.Field) bool {
	for i := range fields {
		if fields[i].Key != key {
			continue
		}
		if fields[i].Type == typ {
			return true
		}
	}
	return false
}
