package beater

import (
	"fmt"
	"time"
	"math/rand"

	"github.com/elastic/beats/libbeat/beat"
	"github.com/elastic/beats/libbeat/cfgfile"
	"github.com/elastic/beats/libbeat/common"
	"github.com/elastic/beats/libbeat/logp"
	"github.com/elastic/beats/libbeat/publisher"
)

// Beater is a load-generating beat: it publishes synthetic "loadtest" events
// either at a configured rate or as fast as its ticker allows, and keeps the
// bookkeeping needed to report the achieved throughput.
type Beater struct {
	// NOTE(review): events is not read or written by the code visible in this
	// file; events are published through Beat.Events instead.
	events    publisher.Client
	// startTime is set when Run begins and is attached to every event as
	// "test_start".
	startTime common.Time
	// Beat is the framework handle stored by Setup and used to publish events.
	Beat      *beat.Beat
	// count is the number of events generated so far; Setup resets it and
	// sendMessage increments it.
	count     int64

	// BeaterConfig holds the raw settings filled in by cfgfile.Read in Config.
	BeaterConfig ConfigSettings

	// NOTE(review): waitTime is not assigned in the visible code; Run works
	// with a local ticker period instead.
	waitTime   time.Duration
	// msgRate is the target number of events per second (default 1).
	msgRate    int64
	// justGoFast switches Run to a one-nanosecond ticker (default false).
	justGoFast bool
	// dateRange, when greater than zero, makes sendMessage move each event's
	// timestamp back by a random number of whole days below this value.
	dateRange  int
	// outputTag is copied into each event under fields.output (default "none").
	outputTag  string

	// warningThreshold is the percentage shortfall against msgRate above which
	// Run prints a slowdown warning (default 5).
	warningThreshold int

	// periodCounts stores the per-second send counts; Setup allocates ten
	// slots and Run overwrites them cyclically.
	periodCounts []float64

	// done is created by Setup and closed by Stop to make Run return.
	done chan struct{}
}

// BeaterConfig lists the user-settable options of the beat. Every field is a
// pointer so that Config can tell an unset option (nil) from an explicit zero
// value and apply its default only in the former case.
type BeaterConfig struct {
	// Rate is the target number of events per second.
	Rate                 *int64  `config:"rate"`
	// RateWarningThreshold is the percentage shortfall that triggers a warning.
	RateWarningThreshold *int    `config:"rate_warning_threshold"`
	// JustGoFast ignores the rate and sends on a one-nanosecond ticker.
	JustGoFast           *bool   `config:"just_go_fast"`
	// DateRange is the exclusive upper bound, in days, of the random amount by
	// which event timestamps are moved into the past.
	DateRange            *int    `config:"date_range"`
	// OutputTag is the value written to fields.output of every event.
	OutputTag            *string `config:"output_tag"`
}

// ConfigSettings is the top-level structure handed to cfgfile.Read; the
// beat's own options are nested under its Input field.
type ConfigSettings struct {
	Input BeaterConfig
}

// New returns a pointer to a zero-valued Beater. Config and Setup fill in
// the settings and runtime state afterwards.
func New() *Beater {
	return new(Beater)
}

// Config loads the beat's settings through cfgfile.Read and copies them onto
// the Beater, substituting a default for every option that was left unset:
// rate 1, rate_warning_threshold 5, just_go_fast false, date_range 0 and
// output_tag "none".
//
// A read failure is logged and returned unchanged. The bb argument is unused.
func (b *Beater) Config(bb *beat.Beat) error {
	if err := cfgfile.Read(&b.BeaterConfig, ""); err != nil {
		logp.Err("Error reading config file: %v", err)
		return err
	}

	// Start from the defaults, then overlay whichever options were provided.
	b.msgRate = 1
	b.warningThreshold = 5
	b.justGoFast = false
	b.dateRange = 0
	b.outputTag = "none"

	input := b.BeaterConfig.Input
	if input.Rate != nil {
		b.msgRate = *input.Rate
	}
	if input.RateWarningThreshold != nil {
		b.warningThreshold = *input.RateWarningThreshold
	}
	if input.JustGoFast != nil {
		b.justGoFast = *input.JustGoFast
	}
	if input.DateRange != nil {
		b.dateRange = *input.DateRange
	}
	if input.OutputTag != nil {
		b.outputTag = *input.OutputTag
	}

	return nil
}

// Setup stores the framework handle and resets the runtime state: the event
// counter, the ten-slot per-second history and the done channel that Stop
// closes.
//
// NOTE(review): the "Test setup!" line is emitted at error level although
// nothing has failed at this point — confirm whether that is intended.
func (b *Beater) Setup(bb *beat.Beat) error {
	logp.Err("Test setup!")

	b.Beat = bb
	b.count = 0
	b.periodCounts = make([]float64, 10)
	b.done = make(chan struct{})

	return nil
}

// Run publishes synthetic events until Stop closes the done channel.
//
// In normal mode one event is sent per tick of a ticker whose period is
// time.Second / rate. In just-go-fast mode the ticker period is one
// nanosecond and the configured rate is not used for pacing.
//
// Once per second the number of events sent during that second is stored in
// b.periodCounts, which is overwritten cyclically. In just-go-fast mode the
// count is printed, together with the average of the stored counts once every
// slot has been written; otherwise a warning is printed when the count fell
// short of the rate by more than warningThreshold percent.
//
// Run returns nil after Stop. It returns an error, before sending anything,
// when the rate cannot be turned into a valid ticker period. The bb argument
// is unused.
func (b *Beater) Run(bb *beat.Beat) error {
	b.startTime = common.Time(time.Now())

	// If we gotta go fast, tick as often as the runtime allows.
	tickerTime := time.Nanosecond
	if b.justGoFast {
		fmt.Printf("Started beating superduper fast...\n")
	} else {
		// A rate of zero would divide by zero below, and a negative rate or
		// one above 1e9/s would hand time.NewTicker a non-positive period,
		// which panics. Reject those values with an error instead.
		if b.msgRate < 1 || b.msgRate > int64(time.Second) {
			return fmt.Errorf("invalid rate %d: must be between 1 and %d messages per second", b.msgRate, int64(time.Second))
		}
		tickerTime = time.Second / time.Duration(b.msgRate)
		fmt.Printf("Started beating at %d per second... (%d)\n", b.msgRate, tickerTime)
	}

	ticker := time.NewTicker(tickerTime)
	secondTicker := time.NewTicker(time.Second)
	defer ticker.Stop()
	defer secondTicker.Stop()

	var periodCount float64 // events sent since the last one-second tick
	periodCountCursor := 0  // next slot of b.periodCounts to overwrite
	showAverage := false    // becomes true once every slot has been written

	for {
		select {
		case <-b.done:
			return nil
		case <-secondTicker.C:
			if b.justGoFast {
				avgStr := ""
				if showAverage {
					// The average covers the stored history only; the second
					// that just ended is recorded after this message.
					avgStr = fmt.Sprintf(" %.0f/s avg", avg(b.periodCounts))
				}
				fmt.Printf("Logged %.0f messages over the last second!%s\n", periodCount, avgStr)
			} else if int64(periodCount) != b.msgRate {
				// Positive when fewer events than requested were sent.
				difference := (1 - (periodCount / float64(b.msgRate))) * 100
				if difference > float64(b.warningThreshold) {
					fmt.Printf("Logging %.2f%% slower than expected. Sent %.0f of %d in 1s.\n", difference, periodCount, b.msgRate)
				}
			}

			b.periodCounts[periodCountCursor] = periodCount
			if !showAverage && periodCountCursor == len(b.periodCounts)-1 {
				showAverage = true
			}
			periodCountCursor = (periodCountCursor + 1) % len(b.periodCounts)
			periodCount = 0
			continue
		case <-ticker.C:
		}

		if err := b.sendMessage(); err != nil {
			logp.Err("Error sending message: %v", err)
		}
		periodCount++
	}
}

// Cleanup prints the total number of events generated during the run and
// always returns nil. The bb argument is unused.
func (b *Beater) Cleanup(bb *beat.Beat) error {
	const summary = "Ending at %d messages"
	fmt.Printf(summary, b.count)
	return nil
}

// Stop closes the done channel, which makes the select loop in Run return.
//
// NOTE(review): closing a nil or already-closed channel panics, so Stop must
// be called at most once and only after Setup has created b.done.
func (b *Beater) Stop() {
	close(b.done)
}

// sendMessage builds one synthetic "loadtest" event and publishes it through
// the beat's event client.
//
// The message text embeds the current value of the event counter, which is
// incremented afterwards. The event also carries the time Run started
// ("test_start") and the configured output tag ("fields.output"). When a date
// range is configured, the event timestamp is moved back by a random whole
// number of days in [0, dateRange).
//
// The returned error is always nil: whatever PublishEvent reports is not
// inspected here.
func (b *Beater) sendMessage() error {
	messageEnd := "going as fast as possible"
	if !b.justGoFast {
		messageEnd = fmt.Sprintf("@ %d/s", b.msgRate)
	}

	timestamp := time.Now()
	if b.dateRange > 0 {
		// rand.Intn draws uniformly from [0, dateRange) and avoids the slight
		// modulo bias of rand.Int() % dateRange.
		daysBack := rand.Intn(b.dateRange)
		offset := time.Duration(daysBack) * 24 * time.Hour
		timestamp = timestamp.Add(-offset)
	}

	event := common.MapStr{
		"@timestamp": common.Time(timestamp),
		"message":    fmt.Sprintf("This is message number %d %s! It's about this big because why not lol right? 128 b-ish seems fineish or something.", b.count, messageEnd),
		"type":       "loadtest",
		"test_start": b.startTime,
		"fields": common.MapStr{
			"output": b.outputTag,
		},
	}

	b.count++

	b.Beat.Events.PublishEvent(event)
	return nil
}

// avg returns the arithmetic mean of arr, or 0 when arr is empty.
func avg(arr []float64) float64 {
	if len(arr) == 0 {
		// Without this guard the result would be 0/0, i.e. NaN.
		return 0
	}

	var total float64
	for _, value := range arr {
		total += value
	}
	return total / float64(len(arr))
}
