package main

import (
	"crypto/tls"
	"fmt"
	"log"
	"math/rand"
	"net/url"
	"strconv"
	"sync/atomic"
	"time"

	"code.justin.tv/devhub/lib-lifecycle/src/lifecycle"

	"code.justin.tv/devhub/e2ml/libs/logging"
	"code.justin.tv/devhub/e2ml/libs/metrics"
	"code.justin.tv/devhub/e2ml/libs/metrics/datadog"
	"code.justin.tv/devhub/e2ml/libs/metrics/logged"
	"code.justin.tv/devhub/e2ml/libs/stream"
	"code.justin.tv/devhub/e2ml/libs/stream/auth/empty"
	"code.justin.tv/devhub/e2ml/libs/stream/auth/s2s"
	"code.justin.tv/devhub/e2ml/libs/timeout"
	"code.justin.tv/devhub/e2ml/libs/websocket"

	"code.justin.tv/devhub/e2ml-loadtest/config"
	"code.justin.tv/devhub/e2ml-loadtest/emlclient"
)

// init seeds math/rand. The seed uses nanosecond resolution so that load-test
// processes launched within the same second still draw different random
// process prefixes in main; a seconds-resolution seed would give such
// processes identical prefixes.
func init() {
	rand.Seed(time.Now().UnixNano())
}

// addrFilter is the stream address filter key that newAddrSpread sets to a
// task's address slot number, spreading writers and listeners across addresses.
const addrFilter = "n"

// main runs the load test. It builds the websocket, metrics and registry
// factories, starts conf.LimitConnecting worker goroutines that start, stop and
// update tasks, and then runs a 10ms control loop. The loop scales writers and
// listeners toward the counts in the dynamic configuration, spending at most
// conf.LimitAddedPerTick task operations per tick.
func main() {
	conf := config.ParseFlagsEnvConfig()
	envConf := conf.TargetENVConf() // validates TargetENV
	dynamicLogger := emlclient.NewDynamicLogger("default", logging.Info)
	logger := dynamicLogger.Log
	logger(logging.Info, fmt.Sprintf("Conf: %+v", conf))

	// Use a different logger for registry, so we can use a higher log level to avoid noise from their traces
	dynamicLoggerReg := emlclient.NewDynamicLogger("reg", logging.Info)
	loggerReg := dynamicLoggerReg.Log

	// Use the same address namespace. The "n" filter will be used to evenly distribute writers and listeners
	baseAddr, err := stream.NewAddress("loadtest", 1, map[string]string{})
	if err != nil {
		log.Fatalf("invalid address: %v", err)
	}

	// Manager to register closing hooks and make sure they are all closed at the end of the program.
	shutdownMgr := lifecycle.NewManager()
	// Graceful shutdown.
	// NOTE(review): this deferred call only runs if main returns normally. The
	// tick loop below never exits, and log.Fatalf calls os.Exit, which skips
	// deferred functions.
	defer exit(shutdownMgr, logger)

	// Web Socket Connection Factory
	wssFactory := websocket.NewClientResolver(&websocket.Settings{
		Certs:     &tls.Config{InsecureSkipVerify: conf.Insecure},
		Lifecycle: shutdownMgr,
		Logger:    logger,
		Timeout:   timeout.NewConstantSampler(20 * time.Second),
	})

	// Metrics
	stats, tickRate, err := buildMetricsTracker(envConf, loggerReg)
	if err != nil {
		log.Fatalf("datadog metrics tracker: %v", err)
	}
	shutdownMgr.RegisterHook(stats, stats.Close)
	shutdownMgr.TickUntilClosed(stats.Tick, tickRate) // flush metric aggregations to the agent on tick

	// stats for tracking our activity
	const (
		statName    = "tasks.finished"
		statOK      = "result:ok"
		statFailed  = "result:failed"
		statStarted = "action:start"
		statStopped = "action:stop"
		statUpdated = "action:update"
	)
	spawned := stats.Count("tasks.spawned", []string{})
	startedOk := stats.Count(statName, []string{statStarted, statOK})
	startedFailed := stats.Count(statName, []string{statStarted, statFailed})
	stoppedOk := stats.Count(statName, []string{statStopped, statOK})
	stoppedFailed := stats.Count(statName, []string{statStopped, statFailed})
	updatedOk := stats.Count(statName, []string{statUpdated, statOK})
	updatedFailed := stats.Count(statName, []string{statUpdated, statFailed})

	// Named targetURL so the net/url package is not shadowed for the rest of main.
	targetURL, err := url.Parse(envConf.URL)
	if err != nil {
		log.Fatalf("Unable to parse connectionURL: %v", err)
	}

	// Message Factory: shared on writers and listeners in this process, used to identify elapsed time
	prefix := make([]byte, 32) // random prefix to identify this process, to ensure elapsed time is only reported on our own messages
	rand.Read(prefix)          // math/rand.Read always returns a nil error
	msgFactory := emlclient.NewMessageFactory(0, 0, prefix)
	regFactory := emlclient.NewRegistryFactory(targetURL, wssFactory, buildAuthSource(conf), loggerReg, stats)

	// Dynamic Configuration Manager, checks if it was updated from local file or remote s3 bucket
	dynamicConfMngr, err := config.NewDynamicConfigMngr(conf.DynamicConfS3Bucket, conf.DynamicConfFile)
	if err != nil {
		log.Fatalf("DynamicConf: init error: %v", err)
	}
	var dynamicConf *config.DynamicConfig

	// Address mapping parameters; written by the control loop and read
	// atomically by the worker goroutines.
	numAddrs := int32(0)
	numAddrsOffset := int32(0)
	desiredWriters := 0
	desiredListeners := 0

	listeners := make([]*emlclient.Listener, 1000)
	writers := make([]*emlclient.Writer, 1000)

	type connReq struct {
		op   string
		task emlclient.Task
	}
	// Worker pool: conf.LimitConnecting goroutines execute start/stop/update
	// requests from connChan so the control loop does not run them itself.
	// Failed requests go to retryChan and are re-evaluated by the control loop.
	// NOTE(review): the control loop is the only consumer of retryChan and also a
	// producer into connChan. If both buffers fill, the workers (blocked sending
	// to retryChan) and the loop (blocked sending to connChan) can deadlock.
	retryChan := make(chan connReq, 10000)
	connChan := make(chan connReq, 10000)
	for i := 0; i < conf.LimitConnecting; i++ {
		go func() {
			for req := range connChan {
				switch req.op {
				case "start":
					addr := newAddrSpread(baseAddr, req.task.ID(), atomic.LoadInt32(&numAddrs), atomic.LoadInt32(&numAddrsOffset))
					if req.task.Start(addr) {
						startedOk.Add(1)
					} else {
						startedFailed.Add(1)
						retryChan <- req
					}
				case "stop":
					if req.task.Stop() {
						stoppedOk.Add(1)
					} else {
						stoppedFailed.Add(1)
						retryChan <- req
					}
				case "update":
					addr := newAddrSpread(baseAddr, req.task.ID(), atomic.LoadInt32(&numAddrs), atomic.LoadInt32(&numAddrsOffset))
					if req.task.Start(addr) {
						updatedOk.Add(1)
					} else {
						updatedFailed.Add(1)
						retryChan <- req
					}
				default:
					logger(logging.Debug, "Unknown task operation", req.op)
				}
			}
		}()
	}

	// The counters below track the number of active connections and the number of connections that need to have their addresses changed.
	// The content of each array is expected to have segments like the following:
	// | <- stale? -> | <- started? -> | <- desired/stopped? -> | uninitialized
	//
	// When an address change happens, all started addresses are marked stale and they are corrected from high index to low -- this ensures
	// that even if the started count increases only items that were stale at the point of address change will be reset

	startedWriters := 0
	startedListeners := 0

	staleWriters := 0
	staleListeners := 0

	// Main loop
	c := time.Tick(10 * time.Millisecond)
	for now := range c {
		// Prompt writers to send any messages that are due
		for i := 0; i < startedWriters; i++ {
			writers[i].Tick(now)
		}

		// Capacity represents a work budget that is reset every cycle; it is meant
		// to act as a CPU throttle so that writers can reliably send their messages
		// at the correct rate; right now all tasks are equally weighted with a cost
		// of 1 per task
		capacity := conf.LimitAddedPerTick

		// The order of operations below has been chosen to optimize stability and
		// the ability to rapidly scale down a runaway test
		// (1) retry failures that are not obsolete, dropping any that no longer apply
		// (2) scale writers down
		// (3) remap writers to updated address counts
		// (4) remap listeners to updated address counts
		// (5) scale listeners up/down
		// (6) scale writers up

	retryLoop:
		for capacity > 0 { // (1)
			select {
			case req, found := <-retryChan:
				if !found {
					break retryLoop
				}
				desired := desiredListeners
				if req.task.Type() == "writer" {
					desired = desiredWriters
				}
				// Tasks are created with their slot index (NewListener/NewWriter(i, ...)),
				// so slots [0, desired) are wanted. Retry a stop only for an unwanted
				// slot, and a start/update only for a wanted one.
				// NOTE(review): assumes Task.ID() returns that 0-based index — confirm in emlclient.
				wanted := req.task.ID() < desired
				if (req.op == "stop") != wanted {
					connChan <- req
				}
			default:
				break retryLoop
			}
			capacity -= 1
		}

		if startedWriters > desiredWriters && capacity > 0 { // (2)
			target := max(desiredWriters, startedWriters-capacity)
			for i := startedWriters - 1; i >= target; i-- {
				connChan <- connReq{op: "stop", task: writers[i]}
			}
			capacity -= startedWriters - target
			startedWriters = target
		}

		if staleWriters > 0 && capacity > 0 { // (3)
			staleWriters = min(staleWriters, startedWriters)
			end := max(0, staleWriters-capacity)

			for i := staleWriters - 1; i >= end; i-- {
				connChan <- connReq{op: "update", task: writers[i]}
			}
			capacity -= staleWriters - end
			staleWriters = end
		}

		if staleListeners > 0 && capacity > 0 { // (4)
			staleListeners = min(staleListeners, startedListeners)
			end := max(0, staleListeners-capacity)

			for i := staleListeners - 1; i >= end; i-- {
				connChan <- connReq{op: "update", task: listeners[i]}
			}
			capacity -= staleListeners - end
			staleListeners = end
		}

		if startedListeners != desiredListeners && capacity > 0 { // (5)
			target := desiredListeners
			if startedListeners > target {
				target = max(target, startedListeners-capacity)
				for i := startedListeners - 1; i >= target; i-- {
					connChan <- connReq{op: "stop", task: listeners[i]}
				}
			} else {
				target = min(target, startedListeners+capacity)
				for len(listeners) < target {
					prev := listeners
					listeners = make([]*emlclient.Listener, 2*len(prev))
					copy(listeners, prev)
				}
				for i := startedListeners; i < target; i++ {
					if listeners[i] == nil {
						listeners[i] = emlclient.NewListener(i, regFactory.Create(), msgFactory, logger, stats)
					}
					connChan <- connReq{op: "start", task: listeners[i]}
				}
			}
			// Charge the budget for the listener operations just issued.
			capacity -= abs(target - startedListeners)
			startedListeners = target
		}

		if startedWriters < desiredWriters && capacity > 0 { // (6)
			target := min(desiredWriters, startedWriters+capacity)
			for len(writers) < target {
				prev := writers
				writers = make([]*emlclient.Writer, 2*len(prev))
				copy(writers, prev)
			}
			for i := startedWriters; i < target; i++ {
				if writers[i] == nil {
					writers[i] = emlclient.NewWriter(i, regFactory.Create(), msgFactory, logger)
				}
				connChan <- connReq{op: "start", task: writers[i]}
			}
			capacity -= abs(target - startedWriters)
			startedWriters = target
		}

		spawned.Add(int64(conf.LimitAddedPerTick - capacity))

		// Check if dynamicConf was updated
		newDynamicConf, err := dynamicConfMngr.TickAndCheckUpdatesAsync(now)
		if err != nil {
			// log the error for visibility, but do not kill the service because it is likely an intermittent network issue. Just try again later.
			logger(logging.Error, fmt.Sprintf("DynamicConf: Load error: %s", err))

		} else if newDynamicConf != nil {
			dynamicConf = newDynamicConf
			logger(logging.Info, fmt.Sprintf("DynamicConf: Updated: %+v", dynamicConf))

			dynamicLogger.UpdateLevelStr(dynamicConf.LogLevel)
			dynamicLoggerReg.UpdateLevelStr(dynamicConf.LogLevelReg)
			msgFactory.Update(dynamicConf)

			// Address mapping needs to be updated when either the address count or
			// the offset changes. Compare the previous value returned by Swap against
			// the new value. Reading the variable in the same expression as the Swap
			// call would be order-dependent, because the Go spec does not order
			// variable reads relative to function calls.
			newAddrs := int32(dynamicConf.Addrs)
			newOffset := int32(dynamicConf.Offset)
			addrsChanged := atomic.SwapInt32(&numAddrs, newAddrs) != newAddrs
			offsetChanged := atomic.SwapInt32(&numAddrsOffset, newOffset) != newOffset
			if addrsChanged || offsetChanged {
				staleWriters = startedWriters
				staleListeners = startedListeners
			}
			desiredListeners = dynamicConf.Listeners
			desiredWriters = dynamicConf.Writers
		}
	}
}

// newAddrSpread returns baseAddr with the addrFilter filter set to the address
// slot assigned to task i. Tasks are spread round-robin over numAddrs slots
// (i % numAddrs), and slot numbers start at numAddrsOffset.
//
// A numAddrs below 1 is treated as 1. Without this, the modulo would panic
// with an integer divide-by-zero inside the calling worker goroutine, for
// example while numAddrs is still at its initial 0 or when the dynamic config
// sets Addrs to 0. It panics if the filter cannot be applied to baseAddr.
func newAddrSpread(baseAddr stream.Address, i int, numAddrs int32, numAddrsOffset int32) stream.Address {
	if numAddrs < 1 {
		numAddrs = 1
	}
	nAddr := int(numAddrsOffset) + i%int(numAddrs) // spread evenly
	addrWithFilter, err := baseAddr.WithFilter(addrFilter, strconv.Itoa(nAddr))
	if err != nil {
		panic("Could not set address filter " + addrFilter + "=" + strconv.Itoa(nAddr) + ": " + err.Error())
	}
	return addrWithFilter
}

// buildMetricsTracker creates the metrics tracker for envConf under the
// "eml_loadtest." namespace. It also returns the interval at which the caller
// should invoke the tracker's Tick.
// When envConf.UseDatadog is set, it uses a Datadog tracker pointed at
// localhost with a 10s interval. Otherwise it uses a tracker that writes
// through logger at Debug level, with a 2s interval.
func buildMetricsTracker(envConf config.TargetENVConf, logger logging.Function) (metrics.Tracker, time.Duration, error) {
	const namespace = "eml_loadtest."
	if envConf.UseDatadog {
		ddTracker, ddErr := datadog.NewTracker("localhost", namespace, envConf)
		return ddTracker, 10 * time.Second, ddErr
	}
	logTracker, logErr := logged.NewTracker(namespace, envConf, logging.Debug, logger)
	return logTracker, 2 * time.Second, logErr
}

// exit runs all shutdown hooks registered on mgr. It logs that shutdown is
// starting, and logs any error returned by ExecuteAll.
func exit(mgr lifecycle.Manager, logger logging.Function) {
	logger(logging.Info, "Exit: Shutdown")
	shutdownErr := mgr.ExecuteAll()
	if shutdownErr == nil {
		return
	}
	logger(logging.Error, "Exit: Shutdown error:", shutdownErr)
}

// min returns the smaller of a and b.
func min(x, y int) int {
	if y < x {
		return y
	}
	return x
}

// max returns the larger of x and y.
func max(x, y int) int {
	if y > x {
		return y
	}
	return x
}

// abs returns the absolute value of x. As with plain negation, the minimum int
// value is returned unchanged.
func abs(x int) int {
	if x >= 0 {
		return x
	}
	return -x
}

// buildAuthSource returns the auth source for registry connections.
// When conf.GreeterClientAuth is set, it builds an s2s auth source for the
// "eml-loadtest" -> "eml-greeter" pair and panics if that cannot be created.
// Otherwise it returns a source that produces empty auth requests.
func buildAuthSource(conf config.FlagsEnvConfig) stream.AuthSource {
	if !conf.GreeterClientAuth.IsEmpty() {
		s2sSource, buildErr := s2s.NewAuthSource(conf.GreeterClientAuth, "eml-loadtest", "eml-greeter")
		if buildErr != nil {
			panic("Unable to create auth source: " + buildErr.Error())
		}
		return s2sSource
	}
	return func() stream.AuthRequest { return empty.NewRequest() }
}
