package requestlog

import (
	"context"
	"encoding/json"
	"log"
	"math/rand"
	"net/http"
	"os"
	"strconv"
	"sync"
	"time"

	logging "code.justin.tv/amzn/TwitchLogging"
	"github.com/cep21/circuit/v3"
	"github.com/pkg/errors"
	"golang.a2z.com/fulton/twirpmiddleware"

	"code.justin.tv/hygienic/messagebatch/ext/cwlogevent"
)

// GetUserID returns the user ID stored in ctx under ctxKeyUserID (withReqInfo
// sets it from the User-Id request header), or "" when none is present.
func GetUserID(ctx context.Context) string {
	if userID, ok := ctx.Value(ctxKeyUserID).(string); ok {
		return userID
	}
	return ""
}

// GetClientID returns the client ID stored in ctx under ctxKeyClientID
// (withReqInfo sets it from the Client-Id header, falling back to
// Twitch-Client-Id), or "" when none is present.
func GetClientID(ctx context.Context) string {
	if clientID, ok := ctx.Value(ctxKeyClientID).(string); ok {
		return clientID
	}
	return ""
}

// GetTraceID returns the trace ID stored in ctx under ctxKeyTraceID
// (withReqInfo sets it from the X-Amz-Trace-Id request header), or "" when
// none is present.
func GetTraceID(ctx context.Context) string {
	if traceID, ok := ctx.Value(ctxKeyTraceID).(string); ok {
		return traceID
	}
	return ""
}

// GetRepoName returns the calling repository's name stored in ctx under
// ctxKeyRepoName (withReqInfo sets it from the Twitch-Repository request
// header), or "" when none is present.
func GetRepoName(ctx context.Context) string {
	if repoName, ok := ctx.Value(ctxKeyRepoName).(string); ok {
		return repoName
	}
	return ""
}

// ContextMiddleware adds certain request info to the request's context before handling the request.
func ContextMiddleware(handler http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		req = req.WithContext(withReqInfo(withTrace(req.Context()), req))
		handler.ServeHTTP(rw, req)
	})
}

func withTrace(ctx context.Context) context.Context {
	return context.WithValue(ctx, ctxKeyBody, &ContextTrace{})
}

// withReqInfo copies selected request headers into ctx so they can later be
// read with GetClientID, GetUserID, GetTraceID and GetRepoName.
//
// The client ID comes from "Client-Id" (GraphQL), falling back to
// "Twitch-Client-Id" (Helix). A header that is missing or empty is not
// stored at all.
//
// XXX: it might be nice to add stuff from the edge/headers library if that were synced to gitfarm
func withReqInfo(ctx context.Context, req *http.Request) context.Context {
	header := req.Header

	clientID := header.Get("Client-Id") // GraphQL header
	if clientID == "" {
		clientID = header.Get("Twitch-Client-Id") // Helix header
	}

	// Applied in this order so the context chain is layered the same way
	// regardless of which headers are present.
	fields := []struct {
		key   interface{}
		value string
	}{
		{ctxKeyClientID, clientID},
		{ctxKeyUserID, header.Get("User-Id")},
		{ctxKeyTraceID, header.Get("X-Amz-Trace-Id")},
		{ctxKeyRepoName, header.Get("Twitch-Repository")},
	}

	for _, f := range fields {
		if f.value == "" {
			continue
		}
		ctx = context.WithValue(ctx, f.key, f.value)
	}
	return ctx
}

// event contains the fields we're sending to cloudwatch, marshalled as one
// JSON record per request. The field order below is the key order of the
// marshalled JSON.
type event struct {
	// When and where the record was produced. MessageTimeUnix is the integer
	// form of MessageTime (logEventImpl writes it in Unix nanoseconds).
	MessageTime     time.Time `json:"msg_time"`
	MessageTimeUnix int64     `json:"msg_time_int"`
	Hostname        string    `json:"hostname"`

	// What was called and how it ended. StatusCode is kept as a string and is
	// empty for events logged through LogEvent. ContextTrace is set by
	// sendEvent from getTraceOrNil(ctx) and omitted when nil.
	Method       string        `json:"method"`
	StatusCode   string        `json:"status_code"`
	ContextTrace *ContextTrace `json:"context_trace,omitempty"`

	// Request metadata copied out of the context by fillDefaults; omitted when
	// empty. UserID and ClientID are only populated when enabled in Configs.
	UserID     string `json:"user_id,omitempty"`
	ClientID   string `json:"client_id,omitempty"`
	RequestID  string `json:"request_id,omitempty"`
	TraceID    string `json:"trace_id,omitempty"`
	CallerRepo string `json:"caller_repo,omitempty"`

	// Timing. ReceivedTime is the instant recorded by WithEventStart
	// (ReceivedTimeUnix is it in Unix nanoseconds). DurationMS is the time
	// elapsed since WithEventStart and SendDurationMS the time elapsed since
	// WithEventHandled, both truncated to whole milliseconds when the event is
	// built. Each stays zero when the corresponding context value is absent.
	ReceivedTime     time.Time `json:"recv_time"`
	ReceivedTimeUnix int64     `json:"recv_time_int"`
	DurationMS       int64     `json:"duration_ms"`
	SendDurationMS   int64     `json:"send_time_ms"`

	// Error is the message recorded with WithEventError, if any.
	Error string `json:"error,omitempty"`
}

// Logger sends request logs via the CloudwatchLogBatcher, applying sampling
// rules in shouldSend. Create one with New and release it with Close.
type Logger struct {
	// batcher receives each JSON-encoded event (via Event) for publishing.
	batcher            cwlogevent.CloudwatchLogBatcher
	// currentHostname is resolved once in New and stamped on every event.
	currentHostname    string
	// rand drives sampling decisions; access is serialized by randMutex.
	rand               *rand.Rand
	// methodSLOs maps method names to latency SLOs (see passesSLO).
	methodSLOs         map[string]time.Duration
	// defaultSampleRate is the fraction of requests logged for methods without
	// an entry in sampleRates.
	defaultSampleRate  float64
	// sampleRates holds per-method overrides of defaultSampleRate.
	sampleRates        map[string]float64
	// sampleClientErrors, when true, lets errored requests whose status code
	// parses as an integer below 500 go through sampling instead of always
	// being logged.
	sampleClientErrors bool
	// randMutex guards rand, because a *rand.Rand built from rand.NewSource is
	// not safe for concurrent use.
	randMutex          sync.Mutex
	// logClientID / logUserID control whether fillDefaults copies the client
	// and user IDs from the request context into the event.
	logClientID        bool
	logUserID          bool
}

// Configs holds config values for the Logger struct and is passed to New.
type Configs struct {
	LogGroupName   string         // required: name of the Cloudwatch Log Group to write logs to
	LogGroupRegion string         // required: region of the Cloudwatch Log Group to write logs to
	ErrorLogger    logging.Logger // required: logger for errors writing request logs

	CloudwatchCircuit *circuit.Circuit // circuit for wrapping the Cloudwatch Log publishing (handed to initializeBatcher)

	// Sampling: forced traces, errors (subject to SampleClientErrors) and SLO
	// misses are always logged; every other request is logged with probability
	// equal to its method's sample rate.
	MethodSLOs         map[string]time.Duration // map of method names to their SLOs. When non-empty, methods missing from the map are always logged.
	DefaultSampleRate  float64                  // fraction (not percentage) of requests to log. Values <= 0 are replaced with the default 0.001 = 0.1%.
	MethodSampleRate   map[string]float64       // override the default sample rate per method. Map is {Method Name => Sample Rate}
	SampleClientErrors bool                     // by default, all errors are logged. if true, errors whose status code parses as an integer below 500 (e.g. HTTP 400-499) are sampled.
	LogClientID        bool                     // should ClientID (from the Client-Id header, or Twitch-Client-Id as a fallback) be logged
	LogUserID          bool                     // should UserID (from the User-Id header) be logged
}

// validate checks that the required fields of c are set and fills defaults for
// the optional sampling settings. It mutates c: a DefaultSampleRate <= 0
// becomes 0.001, and a nil MethodSampleRate becomes an empty map.
func (c *Configs) validate() error {
	switch {
	case c.LogGroupName == "":
		return errors.New("log group name is required")
	case c.LogGroupRegion == "":
		return errors.New("log group region is required")
	case c.ErrorLogger == nil:
		return errors.New("error logger must be non-nil")
	}

	// Optional settings: fall back to defaults.
	if c.DefaultSampleRate <= 0 {
		c.DefaultSampleRate = 0.001
	}
	if c.MethodSampleRate == nil {
		c.MethodSampleRate = make(map[string]float64)
	}
	return nil
}

// New initialises a Logger and starts batching logs.
// Fails if
// - config is invalid
// - hostname cannot be loaded
// - batcher cannot be initialised
// The caller is responsible for closing the Logger.
//
// conf is taken by value, so the defaults filled in by validate apply to this
// copy only. The MethodSLOs and MethodSampleRate maps are stored by reference
// and read while logging, so callers should not mutate them after New returns.
func New(conf Configs) (*Logger, error) {
	if err := conf.validate(); err != nil {
		return nil, errors.Wrap(err, "invalid configuration")
	}

	// Resolve the hostname before creating the batcher. The batcher is the only
	// resource New acquires; if a later step failed after it was created, New
	// would return an error and nothing would ever Close the batcher.
	hostname, err := os.Hostname()
	if err != nil {
		return nil, errors.Wrap(err, "unable to resolve hostname")
	}

	batcher, err := initializeBatcher(conf.ErrorLogger, conf.CloudwatchCircuit, conf.LogGroupName, conf.LogGroupRegion)
	if err != nil {
		return nil, errors.Wrap(err, "failed to start log batcher")
	}

	// randMutex is left at its zero value, which is an unlocked mutex.
	return &Logger{
		rand:               rand.New(rand.NewSource(time.Now().UnixNano())),
		batcher:            batcher,
		currentHostname:    hostname,
		logClientID:        conf.LogClientID,
		logUserID:          conf.LogUserID,
		methodSLOs:         conf.MethodSLOs,
		defaultSampleRate:  conf.DefaultSampleRate,
		sampleRates:        conf.MethodSampleRate,
		sampleClientErrors: conf.SampleClientErrors,
	}, nil
}

// Close should be called to stop the Logger. It closes the underlying log
// batcher and returns the batcher's error, if any.
// Per this package's original contract, Close does not block; that relies on
// the batcher's own Close implementation.
func (l *Logger) Close() error {
	err := l.batcher.Close()
	return err
}

// fillDefaults completes ev before it is marshalled. It fills the time and
// host fields that are still unset and copies request metadata out of ctx. ev
// is modified in place and also returned.
//
// MessageTimeUnix defaults to MessageTime in Unix nanoseconds, the same unit
// logEventImpl writes, so "msg_time_int" never mixes seconds and nanoseconds
// across code paths. UserID and ClientID are copied only when the matching
// Configs flag was set. RequestID, TraceID and CallerRepo are always
// overwritten from ctx and may be empty.
func (l *Logger) fillDefaults(ctx context.Context, ev *event) *event {
	if ev.MessageTime.IsZero() {
		ev.MessageTime = time.Now()
	}

	if ev.MessageTimeUnix == 0 {
		ev.MessageTimeUnix = ev.MessageTime.UnixNano()
	}

	if ev.Hostname == "" {
		ev.Hostname = l.currentHostname
	}

	if l.logUserID {
		ev.UserID = GetUserID(ctx)
	}

	if l.logClientID {
		ev.ClientID = GetClientID(ctx)
	}

	ev.RequestID = twirpmiddleware.GetRequestID(ctx)
	ev.TraceID = GetTraceID(ctx)
	ev.CallerRepo = GetRepoName(ctx)

	return ev
}

// WithEventError commits err's message to the context; the event later built
// by LogEvent reports it in its "error" field.
// A nil err is ignored and ctx is returned unchanged. Calling err.Error() on a
// nil interface would panic.
func WithEventError(ctx context.Context, err error) context.Context {
	if err == nil {
		return ctx
	}
	return context.WithValue(ctx, eventErrKey, err.Error())
}

// WithEventStart returns a child of ctx carrying a fresh, empty trace (via
// withTrace) plus the current time as the event start. logEventImpl uses that
// time for recv_time and duration_ms.
func WithEventStart(ctx context.Context) context.Context {
	traced := withTrace(ctx)
	startedAt := time.Now()
	return context.WithValue(traced, eventStartTimeKey, startedAt)
}

// WithEventHandled returns a child of ctx recording the current time as the
// moment handling finished. logEventImpl uses it to compute send_time_ms.
func WithEventHandled(ctx context.Context) context.Context {
	handledAt := time.Now()
	return context.WithValue(ctx, eventHandledTimeKey, handledAt)
}

// LogEvent builds an event for method from the values stored in ctx and emits
// it, subject to sampling. No status code is attached. Because an empty status
// code never parses, any error recorded in ctx causes the event to be logged.
func (l *Logger) LogEvent(ctx context.Context, method string) {
	const noStatusCode = ""
	l.logEventImpl(ctx, method, noStatusCode)
}

// logEventImpl processes the constructed context (calculates timings, etc.) and
// emits a log event, with (optional) statusCode; pass "" when it is unknown.
//
// It reads the start time (WithEventStart), the handled time
// (WithEventHandled) and the error message (WithEventError) from ctx, then
// hands the event to sendEvent, which applies sampling.
//
// A single timestamp is taken up front, so msg_time, msg_time_int and both
// durations describe the same instant. Durations are truncated to whole
// milliseconds.
func (l *Logger) logEventImpl(ctx context.Context, method string, statusCode string) {
	now := time.Now()

	ev := &event{
		MessageTime:     now,
		MessageTimeUnix: now.UnixNano(),
		Hostname:        l.currentHostname,
		Method:          method,
		StatusCode:      statusCode,
	}

	if recvTime, ok := ctx.Value(eventStartTimeKey).(time.Time); ok {
		ev.ReceivedTime = recvTime
		ev.ReceivedTimeUnix = recvTime.UnixNano()
		ev.DurationMS = int64(now.Sub(recvTime) / time.Millisecond)
	}

	if handleTime, ok := ctx.Value(eventHandledTimeKey).(time.Time); ok {
		ev.SendDurationMS = int64(now.Sub(handleTime) / time.Millisecond)
	}

	if errMsg, ok := ctx.Value(eventErrKey).(string); ok {
		ev.Error = errMsg
	}

	l.sendEvent(ctx, ev)
}

// shouldSend reports whether ev should be written to Cloudwatch. The checks
// run in this order, and the first match wins:
//
//  1. the request's ContextTrace forces logging;
//  2. ev carries an error and its status code does not parse, is 500 or
//     above, or client-error sampling is disabled;
//  3. the request missed its method's SLO (see passesSLO);
//  4. no random source is configured (a warning is printed and the event is
//     kept rather than panicking);
//  5. otherwise ev is kept with probability equal to its method's sample
//     rate, or the default rate when the method has no override.
func (l *Logger) shouldSend(ev *event) bool {
	if trace := ev.ContextTrace; trace != nil && trace.forceLog {
		return true
	}

	if ev.Error != "" {
		code, parseErr := strconv.Atoi(ev.StatusCode)
		switch {
		case parseErr != nil:
			return true // status code missing or malformed
		case code >= 500:
			return true // server-side failure
		case !l.sampleClientErrors:
			return true // client errors are sampled only when configured
		}
	}

	if !l.passesSLO(ev.DurationMS, ev.Method) {
		return true
	}

	if l.rand == nil {
		log.Println("WARNING: logger configured with no random source")
		return true
	}

	// A *rand.Rand built from NewSource is not safe for concurrent use, so the
	// rate lookup and the draw both happen under randMutex.
	l.randMutex.Lock()
	defer l.randMutex.Unlock()

	rate, hasOverride := l.sampleRates[ev.Method]
	if !hasOverride {
		rate = l.defaultSampleRate
	}
	return l.rand.Float64() < rate
}

// sendEvent sends an event to cloudwatch logs as a JSON blob.
//
// It attaches the context's trace (if any), and drops ev when shouldSend
// rejects it. Otherwise it fills request metadata via fillDefaults, marshals
// ev and hands the JSON to the batcher. If marshalling fails, the error is
// logged and nothing is sent. json.Marshal returns a nil slice on error, so
// continuing would publish an empty record.
func (l *Logger) sendEvent(ctx context.Context, ev *event) {
	ev.ContextTrace = getTraceOrNil(ctx)

	if !l.shouldSend(ev) {
		return
	}

	ev = l.fillDefaults(ctx, ev)

	msg, err := json.Marshal(ev)
	if err != nil {
		l.batcher.Log.Log("err", err, "could not marshal request log")
		return
	}

	l.batcher.Event(string(msg), ev.MessageTime)
}

// passesSLO reports whether a request to methodName that took durationMS
// milliseconds is within that method's SLO.
//
// With no SLOs configured, every request passes. Otherwise a request passes
// only if its method has an SLO and the duration does not exceed it.
//
// NOTE(review): once any SLO is configured, methods absent from the map always
// "fail" and are therefore always logged by shouldSend. Confirm this is
// intended rather than treating unknown methods as passing.
func (l *Logger) passesSLO(durationMS int64, methodName string) bool {
	if len(l.methodSLOs) == 0 {
		return true
	}

	slo, known := l.methodSLOs[methodName]
	if !known {
		return false
	}

	elapsed := time.Duration(durationMS) * time.Millisecond
	return elapsed <= slo
}
