package lib

import (
	"context"
	"encoding/json"
	"math/rand"
	"os"
	"sync"
	"time"

	"code.justin.tv/hygienic/messagebatch"
	"code.justin.tv/hygienic/messagebatch/ext/cwlogevent"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/endpoints"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
	"github.com/cep21/circuit/v3"
	"github.com/pkg/errors"

	logging "code.justin.tv/amzn/TwitchLogging"
)

// AccessLog adds cloudwatch log events so we can track access.
//
// Messages are sampled before they are handed to the batcher: a message is
// kept according to the sample rate configured for its Method, falling back
// to the default rate when the method has no override.
type AccessLog struct {
	batcher            cwlogevent.CloudwatchLogBatcher // receives the serialized messages that pass sampling
	rand               *rand.Rand                      // random source used for sampling decisions
	randMu             sync.Mutex                      // held around every use of rand in filterByRate
	defaultSampleRate  float64                         // rate applied when endpointSampleRate has no entry
	endpointSampleRate map[string]float64              // per-method overrides, keyed by Message.Method
}

// Configs are the configurations for AccessLog
type Configs struct {
	// LogGroupName is the CloudWatch log group name handed to the batcher's config.
	LogGroupName string
	// DefaultSampleRate defaults to 0.001 or 0.1%.
	// A zero value is treated as "unset" by New and replaced with that default,
	// so it cannot be used to disable logging globally.
	DefaultSampleRate float64
	// EndpointSampleRate can override DefaultSampleRate per endpoint, set to 0 if you want to disable logging for a particular endpoint.
	// Keys are matched against Message.Method; a rate of 1 logs every message for that endpoint.
	EndpointSampleRate map[string]float64
}

// New builds an AccessLog from configs.
//
// A zero configs.DefaultSampleRate is replaced by 0.001. The CloudWatch batcher
// is created through initializeBatcher using logger, cwCircuit and
// configs.LogGroupName; if that fails, the error is returned wrapped and no
// AccessLog is produced. The sampler's random source is seeded from the
// current time.
func New(logger logging.Logger, cwCircuit *circuit.Circuit, configs Configs) (*AccessLog, error) {
	effectiveRate := configs.DefaultSampleRate
	if effectiveRate == 0 {
		// Zero means "not configured"; fall back to sampling 0.1% of messages.
		effectiveRate = 0.001
	}

	cwBatcher, err := initializeBatcher(logger, cwCircuit, configs.LogGroupName)
	if err != nil {
		return nil, errors.Wrap(err, "cannot create batcher")
	}

	seed := time.Now().UnixNano()
	accessLog := &AccessLog{
		batcher:            cwBatcher,
		rand:               rand.New(rand.NewSource(seed)), //nolint
		defaultSampleRate:  effectiveRate,
		endpointSampleRate: configs.EndpointSampleRate,
	}
	return accessLog, nil
}

// shouldFilter reports whether log should be dropped instead of sent.
// The decision uses the sample rate registered for log.Method, or the default
// sample rate when that method has no entry.
//
// TODO: always log long running request?
func (a *AccessLog) shouldFilter(log *Message) bool {
	rate, overridden := a.endpointSampleRate[log.Method]
	if !overridden {
		rate = a.defaultSampleRate
	}
	return a.filterByRate(rate)
}

// filterByRate reports whether a message should be dropped for the given
// sample rate. A rate of exactly 0 always drops and a rate of exactly 1 never
// drops, both without consulting the random source. For any other rate the
// message is dropped when a random draw from [0, 1) is greater than rate.
func (a *AccessLog) filterByRate(rate float64) bool {
	switch rate {
	case 0:
		return true
	case 1:
		return false
	}

	// The random source is shared by all callers, so draw under the mutex.
	a.randMu.Lock()
	defer a.randMu.Unlock()
	draw := a.rand.Float64()
	return draw > rate
}

// Log logs an entry of access log to CloudWatch.
//
// The message is first subjected to sampling (see shouldFilter); a message
// that is sampled out is dropped and msg is left untouched. Otherwise, when
// traceValues reports values for ctx, every non-empty set overwrites the
// matching Trace* field on msg, so the caller's message is modified in place.
// The message is then JSON-encoded and handed to the batcher together with
// msg.MsgTime.
//
// If encoding fails, the error is reported through the batcher's logger and
// nothing is queued.
func (a *AccessLog) Log(ctx context.Context, msg *Message) {
	if a.shouldFilter(msg) {
		return
	}

	if traceStrings, traceInts, floats, exists := traceValues(ctx); exists {
		if len(traceStrings) != 0 {
			msg.TraceStrings = traceStrings
		}
		if len(traceInts) != 0 {
			msg.TraceInts = traceInts
		}
		if len(floats) != 0 {
			msg.TraceFloats = floats
		}
	}

	s, err := json.Marshal(msg)
	if err != nil {
		a.batcher.Log.Log("could not marshal access msg", "err", err)
		// Stop here: s is nil on failure, and queueing it would emit an empty
		// log event instead of the access record.
		return
	}

	a.batcher.Event(string(s), msg.MsgTime)
}

// initializeBatcher initializes a CloudwatchLogBatcher.
// twitchLogger is the logger object used to instantiate the batcher.
// cwCircuit are the circuits defined for batcher calls.
// logGroupName is the CloudWatch log group set on the batcher's config.
//
// It creates an AWS session and CloudWatch Logs client, runs the batcher's
// Setup, and then launches Start in a background goroutine; an error returned
// by Start is only logged. On any failure the zero-value batcher is returned
// together with a wrapped error.
func initializeBatcher(twitchLogger logging.Logger, cwCircuit *circuit.Circuit, logGroupName string) (cwlogevent.CloudwatchLogBatcher, error) {
	awsConf := &aws.Config{
		STSRegionalEndpoint: endpoints.RegionalSTSEndpoint,
	}
	// Current region is pre-defined in the execution environment.
	// Only set it when the variable is actually present: an explicit empty
	// region is a non-nil value and would presumably take precedence over
	// whatever region the SDK session resolves by itself.
	if region := os.Getenv("AWS_DEFAULT_REGION"); region != "" {
		awsConf.Region = aws.String(region)
	}

	sess, err := session.NewSession(awsConf)
	if err != nil {
		return cwlogevent.CloudwatchLogBatcher{}, errors.Wrap(err, "failed to start AWS session")
	}

	// Named batchLog so the local does not shadow the logWrapper type.
	batchLog := &logWrapper{
		prefix: "accessLogger",
		logger: twitchLogger,
	}

	batcher := cwlogevent.CloudwatchLogBatcher{
		Batcher: messagebatch.Batcher{
			Log:    batchLog,
			Events: make(chan interface{}, 1000),
		},
		Config: &cwlogevent.Config{
			LogGroupName: logGroupName,
		},
		Client:  cloudwatchlogs.New(sess, awsConf),
		Circuit: cwCircuit,
	}

	if err := batcher.Setup(); err != nil {
		// Do not hand back a half-initialized batcher alongside an error.
		return cwlogevent.CloudwatchLogBatcher{}, errors.Wrap(err, "failed to set up batcher")
	}

	// NOTE(review): the goroutine runs Start on the local batcher while the
	// caller receives a copy of the struct; confirm the batcher is safe to
	// copy after Start (the Events channel is shared between both copies).
	go func() {
		if err := batcher.Start(); err != nil {
			batchLog.Log(err)
		}
	}()

	return batcher, nil
}

// logWrapper adapts a logging.Logger to a variadic Log(keyvals...) method so
// it can be installed as the batcher's logger. Every call is forwarded with
// prefix as the leading argument.
type logWrapper struct {
	prefix string         // passed as the first argument of every forwarded Log call
	logger logging.Logger // underlying logger that receives the forwarded calls
}

// Log forwards keyvals to the wrapped logger, passing the wrapper's prefix as
// the logger's first argument.
func (w *logWrapper) Log(keyvals ...interface{}) {
	w.logger.Log(w.prefix, keyvals...)
}
