package metrics

import (
	"fmt"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/cloudwatch"
	"github.com/aws/aws-sdk-go/service/cloudwatch/cloudwatchiface"
	log "github.com/sirupsen/logrus"
)

const (
	// rexNamespace is the CloudWatch namespace every metric batch is published under.
	rexNamespace = "rex"

	// STAGE is the dimension name used to tag metric data with Config.Stage.
	STAGE = "STAGE"

	// COUNT is the CloudWatch unit attached to count metrics.
	COUNT = "Count"

	// AVAILABILITY_FORMAT is a fmt.Sprintf format turning a base metric name into
	// "<name>.Availability".
	AVAILABILITY_FORMAT = "%s.Availability"
)

// Config holds the settings shared by the metrics logger and flusher returned by
// New / NewFromCloudwatchClient.
type Config struct {
	// Stage is the value of the STAGE dimension attached to every metric datum.
	Stage string

	// AwsRegion is the region of the CloudWatch client that New creates.
	AwsRegion string

	// BufferSize is the capacity of the channel that queues metric data between
	// the logger and the flusher; metrics logged while it is full are dropped.
	BufferSize int

	// BatchSize is the maximum number of metric data sent in one PutMetricData call.
	BatchSize int

	// FlushInterval is how long after the last flush a periodic flush becomes due.
	FlushInterval time.Duration

	// FlushPollCheckDelay is how often FlushMetricsAtInterval checks whether a
	// flush is due.
	FlushPollCheckDelay time.Duration

	// BufferEmergencyFlushPercentage triggers an early flush once the buffer's fill
	// ratio reaches it. Despite the name it is compared against a fraction in
	// [0, 1] (queued / BufferSize), so 0.8 means 80% full.
	BufferEmergencyFlushPercentage float64
}

// rexMetricLogger implements IMetricLogger: it builds CloudWatch metric data and
// queues them on metricsBuffer for the rexMetricFlusher created alongside it to post.
type rexMetricLogger struct {
	// metricsBuffer is the queue this logger writes to; NewFromCloudwatchClient
	// hands the same channel to the flusher, which reads from it.
	metricsBuffer chan *cloudwatch.MetricDatum

	// config holds the settings this logger was created with (e.g. Stage).
	config Config
}

// rexMetricFlusher implements IMetricFlusher: it drains the metric data queued by a
// rexMetricLogger and posts them to CloudWatch in batches.
type rexMetricFlusher struct {
	// cloudWatch is the client used for PutMetricData calls.
	cloudWatch cloudwatchiface.CloudWatchAPI

	// metricsBuffer is the queue shared with the logger created alongside this flusher.
	metricsBuffer chan *cloudwatch.MetricDatum

	// config supplies the batch size, flush interval, poll delay and emergency threshold.
	config Config

	// lastFlushTime is set by the constructor and by FlushMetricsAtInterval (at start
	// and after each flush it performs); FlushMetrics itself does not update it.
	lastFlushTime time.Time
}

// IMetricLogger is the API a service uses to record metrics. The CloudWatch-backed
// implementation returned by New queues each datum for an IMetricFlusher to post;
// NilMetricsLogger ignores every call.
type IMetricLogger interface {
	// LogDurationMetric records a duration metric (the CloudWatch implementation
	// reports it in seconds).
	LogDurationMetric(name string, duration time.Duration)
	// LogDurationSinceMetric records the time elapsed since timeOfEvent as a
	// duration metric.
	LogDurationSinceMetric(name string, timeOfEvent time.Time)
	// LogCountMetric records a count metric.
	LogCountMetric(name string, count float64)
	// AddHttpDependencyCallMetrics records latency (as a duration metric, i.e. in
	// seconds), volume, status-class and availability metrics for an outbound HTTP
	// call to dependencyName.
	AddHttpDependencyCallMetrics(statusCode int, duration time.Duration, dependencyName string)
	// AddTwirpDependencyCallMetrics records latency (as a duration metric, i.e. in
	// seconds), volume and availability metrics for an outbound Twirp call to
	// dependencyName; a nil potentialError counts as available.
	AddTwirpDependencyCallMetrics(potentialError error, duration time.Duration, dependencyName string)
	// LogCountMetricsForStatusCode records a count for statusCode's class
	// (<metricName>.1xx ... <metricName>.5xx) plus <metricName>.Availability.
	LogCountMetricsForStatusCode(statusCode int, metricName string)
}

// NilMetricsLogger is an IMetricLogger that discards every metric (the equivalent
// of NullMetricsFactory). It holds no state, so its zero value, and even a nil
// *NilMetricsLogger, is ready to use.
type NilMetricsLogger struct{}

// LogDurationMetric discards the metric.
func (*NilMetricsLogger) LogDurationMetric(name string, duration time.Duration) {}

// LogDurationSinceMetric discards the metric.
func (*NilMetricsLogger) LogDurationSinceMetric(name string, timeOfEvent time.Time) {}

// LogCountMetric discards the metric.
func (*NilMetricsLogger) LogCountMetric(name string, count float64) {}

// AddHttpDependencyCallMetrics discards the metrics.
func (*NilMetricsLogger) AddHttpDependencyCallMetrics(statusCode int, duration time.Duration, dependencyName string) {
}

// AddTwirpDependencyCallMetrics discards the metrics.
func (*NilMetricsLogger) AddTwirpDependencyCallMetrics(potentialError error, duration time.Duration, dependencyName string) {
}

// LogCountMetricsForStatusCode discards the metrics.
func (*NilMetricsLogger) LogCountMetricsForStatusCode(statusCode int, metricName string) {}

// NewNilMetricsLogger returns an IMetricLogger that silently discards every metric,
// for use where metrics are disabled (e.g. tests or local runs).
func NewNilMetricsLogger() IMetricLogger {
	return new(NilMetricsLogger)
}

// IMetricFlusher publishes the metrics buffered by the IMetricLogger it was created with.
type IMetricFlusher interface {
	// ShouldFlush reports whether a flush is currently due.
	ShouldFlush() bool
	// FlushMetrics publishes the metrics that are currently buffered.
	FlushMetrics()
	// FlushMetricsAtInterval repeatedly checks ShouldFlush and flushes when it is
	// due. The CloudWatch-backed implementation never returns, so run it in its own
	// goroutine.
	FlushMetricsAtInterval()
}

// New builds a CloudWatch client for config.AwsRegion on a default AWS session and
// returns the logger/flusher pair produced by NewFromCloudwatchClient.
//
// NOTE(review): aws-sdk-go v1 deprecates session.New in favour of
// session.NewSession, which reports configuration errors immediately rather than on
// the first request. session.New is kept here because New has no error return.
func New(config Config) (IMetricLogger, IMetricFlusher) {
	regionConfig := aws.NewConfig().WithRegion(config.AwsRegion)
	return NewFromCloudwatchClient(config, cloudwatch.New(session.New(), regionConfig))
}

// NewFromCloudwatchClient builds an IMetricLogger and an IMetricFlusher around a
// single buffered channel of capacity config.BufferSize. The logger enqueues metric
// data, and the flusher drains the channel and posts the data through
// cloudwatchClient. Because the client is an interface, a stub can be injected.
func NewFromCloudwatchClient(config Config, cloudwatchClient cloudwatchiface.CloudWatchAPI) (IMetricLogger, IMetricFlusher) {
	queue := make(chan *cloudwatch.MetricDatum, config.BufferSize)

	flusher := &rexMetricFlusher{
		cloudWatch:    cloudwatchClient,
		metricsBuffer: queue,
		config:        config,
		// The FlushInterval clock starts at construction time.
		lastFlushTime: time.Now(),
	}

	return &rexMetricLogger{metricsBuffer: queue, config: config}, flusher
}

// LogDurationSinceMetric records how much time has passed between timeOfEvent and
// now as a duration metric called name.
func (l *rexMetricLogger) LogDurationSinceMetric(name string, timeOfEvent time.Time) {
	l.LogDurationMetric(name, time.Since(timeOfEvent))
}

// LogDurationMetric records duration, in seconds, as a metric called name, tagged
// with the STAGE dimension set to config.Stage. The datum is queued for the flusher
// without blocking; if the buffer is full it is dropped and an error is logged.
func (l *rexMetricLogger) LogDurationMetric(name string, duration time.Duration) {
	l.enqueueMetric(l.newMetricDatum(name, "Seconds", duration.Seconds()))
}

// LogCountMetric records count as a metric called name with unit COUNT, tagged with
// the STAGE dimension set to config.Stage. Like LogDurationMetric, it never blocks
// and drops the datum (logging an error) when the buffer is full.
func (l *rexMetricLogger) LogCountMetric(name string, count float64) {
	l.enqueueMetric(l.newMetricDatum(name, COUNT, count))
}

// newMetricDatum builds a datum timestamped now (UTC) and tagged with the stage dimension.
func (l *rexMetricLogger) newMetricDatum(name, unit string, value float64) *cloudwatch.MetricDatum {
	return &cloudwatch.MetricDatum{
		MetricName: aws.String(name),
		Dimensions: []*cloudwatch.Dimension{
			{
				Name:  aws.String(STAGE),
				Value: aws.String(l.config.Stage),
			},
		},
		Timestamp: aws.Time(time.Now().UTC()),
		Unit:      aws.String(unit),
		Value:     aws.Float64(value),
	}
}

// enqueueMetric adds datum to the buffer without ever blocking the caller.
func (l *rexMetricLogger) enqueueMetric(datum *cloudwatch.MetricDatum) {
	// A select with default makes "is there room?" and the send a single step.
	// Checking len() first and then sending is racy: several goroutines can pass the
	// check for the last free slot, and all but one then block on the send.
	select {
	case l.metricsBuffer <- datum:
	default:
		log.Error("Metrics buffer at capacity. Additional metrics will be dropped")
	}
}

// FlushMetricsAtInterval runs the flush loop and never returns, so start it in its
// own goroutine. Each iteration asks ShouldFlush whether a flush is due and, if so,
// flushes. The FlushPollCheckDelay wait is measured from the start of the iteration,
// so time spent flushing counts toward the delay rather than adding to it.
//
// NOTE(review): there is no way to stop this loop, and a non-positive
// FlushPollCheckDelay makes it spin without pausing.
func (f *rexMetricFlusher) FlushMetricsAtInterval() {
	f.lastFlushTime = time.Now()
	for {
		pollDelayElapsed := time.After(f.config.FlushPollCheckDelay)
		f.flushIfDue()
		<-pollDelayElapsed
	}
}

// flushIfDue flushes when ShouldFlush says a flush is due and then restarts the
// FlushInterval clock from the moment the flush finished.
func (f *rexMetricFlusher) flushIfDue() {
	if !f.ShouldFlush() {
		return
	}
	f.FlushMetrics()
	f.lastFlushTime = time.Now()
}

// ShouldFlush reports whether a flush is due. A flush is due when FlushInterval has
// elapsed since the last flush, or when the buffer has reached the emergency
// threshold. The buffer check, which logs when it fires, is consulted only if the
// interval has not yet elapsed.
func (f *rexMetricFlusher) ShouldFlush() bool {
	if f.isNextFlushInterval() {
		return true
	}
	return f.isBufferApproachingCapacity()
}

// isNextFlushInterval reports whether strictly more than config.FlushInterval has
// elapsed since lastFlushTime.
func (f *rexMetricFlusher) isNextFlushInterval() bool {
	return time.Since(f.lastFlushTime) > f.config.FlushInterval
}

// isBufferApproachingCapacity reports whether the buffer's fill ratio (queued
// metrics / BufferSize) has reached config.BufferEmergencyFlushPercentage, and logs
// an error when it has. The threshold is compared against a fraction in [0, 1],
// not a 0-100 percentage.
//
// An empty buffer never counts as approaching capacity. Otherwise an unset (zero)
// threshold would satisfy 0 >= 0 and report an "emergency flush" at error level on
// every poll, even though there is nothing to flush.
func (f *rexMetricFlusher) isBufferApproachingCapacity() bool {
	queued := len(f.metricsBuffer)
	if queued == 0 {
		return false
	}

	bufferUsage := float64(queued) / float64(f.config.BufferSize)
	approachingCapacity := bufferUsage >= f.config.BufferEmergencyFlushPercentage
	if approachingCapacity {
		log.WithField("bufferUsage", bufferUsage).Error("The metrics buffer is approaching capacity. Executing emergency flush.")
	}
	return approachingCapacity
}

// FlushMetrics drains the metrics that are in the buffer when it is called and
// posts them to CloudWatch in batches of at most config.BatchSize. Metrics queued
// while the flush is running are left for the next flush. A batch that fails to
// post is logged and dropped; it is not re-queued.
func (f *rexMetricFlusher) FlushMetrics() {
	// Snapshot the length so that producers writing concurrently cannot keep this
	// loop running indefinitely.
	numMetricsToFlush := len(f.metricsBuffer)
	if numMetricsToFlush == 0 {
		return
	}

	// A BatchSize of 0 already meant "one metric per request" (the batch-full check
	// is always true), but a negative value made make() panic. Treat every
	// non-positive size as 1.
	batchSize := f.config.BatchSize
	if batchSize < 1 {
		batchSize = 1
	}

	numSucceeded, numFailed := 0, 0
	metricBatch := make([]*cloudwatch.MetricDatum, 0, batchSize)
	for i := 0; i < numMetricsToFlush; i++ {
		metricBatch = append(metricBatch, <-f.metricsBuffer)

		// Keep filling unless the batch is full or this is the last metric.
		if len(metricBatch) < batchSize && i < numMetricsToFlush-1 {
			continue
		}

		if err := f.postMetricBatch(metricBatch); err != nil {
			log.WithError(err).WithField("metricNum", len(metricBatch)).Error("Failed to send metrics while communicating with cloudwatch")
			numFailed += len(metricBatch)
		} else {
			numSucceeded += len(metricBatch)
		}
		metricBatch = make([]*cloudwatch.MetricDatum, 0, batchSize)
	}

	// Dropped metrics deserve more than Info level; the per-batch errors above are
	// already logged at Error.
	if numFailed > 0 {
		log.Warnf("Finished flushing metrics, but encountered some errors. Succeeded: [%d], Failed: [%d]", numSucceeded, numFailed)
	}
}

// postMetricBatch sends one PutMetricData request carrying metricsBatch under the
// rex namespace and returns the client's error, if any. The response is discarded.
func (f *rexMetricFlusher) postMetricBatch(metricsBatch []*cloudwatch.MetricDatum) error {
	_, err := f.cloudWatch.PutMetricData(&cloudwatch.PutMetricDataInput{
		Namespace:  aws.String(rexNamespace),
		MetricData: metricsBatch,
	})
	return err
}

// AddHttpDependencyCallMetrics records the metrics for one outbound HTTP call to
// dependencyName:
//   - <dependencyName>.Latency: the call duration, recorded by LogDurationMetric in seconds
//   - <dependencyName>.Volume: a count of 1
//   - <dependencyName>.<class>xx and <dependencyName>.Availability, via
//     LogCountMetricsForStatusCode
func (l *rexMetricLogger) AddHttpDependencyCallMetrics(statusCode int, duration time.Duration, dependencyName string) {
	l.LogDurationMetric(dependencyName+".Latency", duration)
	l.LogCountMetric(dependencyName+".Volume", 1.0)
	l.LogCountMetricsForStatusCode(statusCode, dependencyName)
}

// AddTwirpDependencyCallMetrics records the metrics for one outbound Twirp call to
// dependencyName:
//   - <dependencyName>.Latency: the call duration, recorded by LogDurationMetric in seconds
//   - <dependencyName>.Volume: a count of 1
//   - <dependencyName>.Availability: 1 when potentialError is nil, 0 otherwise
func (l *rexMetricLogger) AddTwirpDependencyCallMetrics(potentialError error, duration time.Duration, dependencyName string) {
	l.LogDurationMetric(dependencyName+".Latency", duration)
	l.LogCountMetric(dependencyName+".Volume", 1.0)

	// Any non-nil error, whatever its kind, counts as unavailable (the zero value).
	var availability float64
	if potentialError == nil {
		availability = 1.0
	}
	l.LogCountMetric(fmt.Sprintf(AVAILABILITY_FORMAT, dependencyName), availability)
}

// LogCountMetricsForStatusCode records a count of 1 under "<metricName>.<class>".
// The class is 1xx for anything below 200, 2xx, 3xx or 4xx for the matching
// hundreds, and 5xx for 500 and above. It then records "<metricName>.Availability"
// with the value from calculateAvailability.
func (l *rexMetricLogger) LogCountMetricsForStatusCode(statusCode int, metricName string) {
	classFormat := "%s.1xx"
	switch {
	case statusCode >= 500:
		classFormat = "%s.5xx"
	case statusCode >= 400:
		classFormat = "%s.4xx"
	case statusCode >= 300:
		classFormat = "%s.3xx"
	case statusCode >= 200:
		classFormat = "%s.2xx"
	}
	l.LogCountMetric(fmt.Sprintf(classFormat, metricName), 1.0)

	// Availability for the API: only server-side (5xx and above) codes count against it.
	availability := float64(calculateAvailability(statusCode))
	l.LogCountMetric(fmt.Sprintf(AVAILABILITY_FORMAT, metricName), availability)
}

// calculateAvailability maps an HTTP status code to an availability sample. It
// returns 0 for server errors (500 and above) and 1 for everything else, including
// 4xx client errors.
func calculateAvailability(statusCode int) int {
	if statusCode >= 500 {
		return 0
	}
	return 1
}
