package main

import (
	"fmt"
	"math"
	"math/rand"
	"net/http"
	"os"
	"time"

	"github.com/mediocregopher/radix/v3"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/dynamodb"

	"code.justin.tv/devhub/e2topics/config"
	"code.justin.tv/devhub/e2topics/e2topics"
)

// mainLoopTickMillis is the period of the main loop, in milliseconds.
// It also scales the "new connections per second" budget down to a per-tick budget.
const mainLoopTickMillis = 200

// main wires up the task and then runs the main loop forever:
//
//  1. parse the static config, set up logging and stats, load the dynamic config;
//  2. connect to the Redis cluster (and PING it) and create the DynamoDB client;
//  3. lock a task number in Redis and keep that lock refreshed in the background;
//  4. on every tick, pick up dynamic config changes and start a limited number of
//     new publishers and subscribers, then report gauges.
//
// Any unrecoverable setup error ends the process through logErrorAndExitf.
func main() {
	rand.Seed(time.Now().UnixNano())
	conf := config.ParseGoArgConfig()

	if conf.IsLocal() {
		log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) // Pretty print logs
	}

	var stats e2topics.Statter
	if conf.IsLocal() {
		stats = &e2topics.NullStatter{} // no stats
	} else {
		stats = e2topics.NewTelemetryStatter(conf)
	}

	log.Info().Msgf("Started with conf: %+v", conf)

	// Dynamic Configuration
	dynamicConfMngr := config.MustNewDynamicConfigMngr(conf.DynamicConfS3Bucket, conf.DynamicConfFile)
	dynamicConf := dynamicConfMngr.MustLoad()
	dynamicConfMngr.LoadedLast = dynamicConf // remember when re-loading with TickAndCheckUpdatesAsync
	log.Info().Msgf("DynamicConf: Loaded: %+v", dynamicConf)
	setGlobalLogLevel(dynamicConf.LogLevel)
	log.Info().Msgf("PubsubType: %s", dynamicConf.PubsubType)

	// Redis
	poolSize := 20 // number of connections to keep open for each Redis Shard (total = poolSize * clusterSize)
	poolFunc := func(network, addr string) (radix.Client, error) {
		return radix.NewPool(
			network, // "tcp"
			addr,    // Redis Shard address
			poolSize,

			radix.PoolWithTrace(e2topics.NewRedisPoolMetrics(stats)),
			radix.PoolOnEmptyErrAfter(200*time.Millisecond), // wait for available connections in the pool, or return error radix.ErrEmptyPool ("connection pool is empty")
			radix.PoolRefillInterval(20*time.Second),        // slowly add new connections if there are no available connections in the pool, but not too many because the pool size should be enough
		)
	}
	redisCli, err := radix.NewCluster(conf.RedisClusterAddrsList(), radix.ClusterPoolFunc(poolFunc))
	if err != nil {
		logErrorAndExitf(err, stats, "radix.NewCluster: Unable to connect")
	}
	redisBatch := e2topics.NewRedisBatchLoader(redisCli, stats)

	err = redisCli.Do(radix.Cmd(nil, "PING"))
	if err != nil {
		logErrorAndExitf(err, stats, "redisCli.Do: PING: Unable to connect")
	}

	// AWS DynamoDB
	sess := session.Must(session.NewSessionWithOptions(session.Options{
		Config: aws.Config{
			Region:     aws.String("us-west-2"),
			HTTPClient: &http.Client{Timeout: 10 * time.Second},
		},
	}))
	awsConfig := aws.NewConfig()
	if conf.IsLocal() {
		awsConfig = awsConfig.
			WithEndpoint("http://localhost:8000").
			WithCredentials(credentials.NewSharedCredentials("", "local")) // use the [local] profile as defined in README.md
	}
	dynamodbCli := dynamodb.New(sess, awsConfig)
	dynamoBatch := e2topics.NewDynamoBatchLoader(conf.DynamoDBTable, dynamodbCli)

	// Lock incremental taskNumber, that is relative to other tasks running on the same test instance
	taskNumberLockExpire := time.Duration(dynamicConf.TaskNumberExpire) * time.Second
	taskNumber := lockTaskNumber(redisCli, stats, dynamicConf.TaskNumberKey, dynamicConf.TaskCount, taskNumberLockExpire, 3)
	taskNumberKey := dynamicConf.TaskNumberKey // captured once: later dynamic config updates do not move the lock
	go func() {
		for range time.Tick(taskNumberLockExpire / 2) { // refresh again half way through next expiration
			mustRefreshTaskNumberLock(redisCli, stats, taskNumberKey, taskNumber, taskNumberLockExpire)
		}
	}()

	// Distribute channel publishers on this task.
	pubsDistribution := &PubsDistribution{
		Enabled:       dynamicConf.MsgsPerMinute > 0 && dynamicConf.MsgLen > 0,
		Channels:      dynamicConf.Channels,
		ChannelOffset: dynamicConf.ChannelOffset,
		TaskCount:     dynamicConf.TaskCount,
		TaskNumber:    taskNumber,
	}
	publishers := []*e2topics.Publisher{}

	// Distribute viewers on this task.
	subsDistribution := &SubsDistribution{
		Channels:             dynamicConf.Channels,
		ChannelOffset:        dynamicConf.ChannelOffset,
		AvgViewersPerChannel: dynamicConf.AvgViewersPerChannel,
		TaskCount:            dynamicConf.TaskCount,
		subs:                 map[int]bool{},
	}
	subscribers := []*e2topics.Subscriber{}

	// Main loop
	c := time.Tick(mainLoopTickMillis * time.Millisecond)
	for now := range c {

		// Check if dynamicConf was updated
		newDynamicConf, err := dynamicConfMngr.TickAndCheckUpdatesAsync(now)
		if err != nil {
			log.Error().Err(err).Msg("DynamicConf: Load error") // log the error for visibility, but do not kill the service because it is likely an intermittent network issue. Just try again later.
		} else if newDynamicConf != nil {
			dynamicConf = newDynamicConf
			log.Info().Msgf("DynamicConf: Updated: %+v", dynamicConf)
			setGlobalLogLevel(dynamicConf.LogLevel)

			// Allow to dynamically add more channels by adding more tasks and more channels:
			// e.g. if Channels = 100, TaskCount=10, then you can update Channels = 110, TaskCount=11
			pubsDistribution.Channels = dynamicConf.Channels
			pubsDistribution.TaskCount = dynamicConf.TaskCount
			subsDistribution.Channels = dynamicConf.Channels
			subsDistribution.TaskCount = dynamicConf.TaskCount

			// Allow to dynamically adjust publish speed and size.
			// This allows to calm down traffic to see if things would recover with backoffs or circuit breakers naturally.
			// A non-positive MsgsPerMinute cannot be turned into an interval (the integer division would panic),
			// so in that case the publishers keep their current interval and only the message size is updated.
			canSetInterval := dynamicConf.MsgsPerMinute > 0
			if !canSetInterval && len(publishers) > 0 {
				log.Warn().Msgf("DynamicConf: MsgsPerMinute=%v is not positive: keeping the current publish interval", dynamicConf.MsgsPerMinute)
			}
			for _, p := range publishers {
				if canSetInterval {
					p.SetPublishInterval(time.Minute / time.Duration(dynamicConf.MsgsPerMinute))
				}
				p.SetMsgLen(dynamicConf.MsgLen)
			}
		}

		// New connections allowed per tick, to avoid adding all at once which would affect latency metrics
		newConns := math.Ceil(float64(dynamicConf.NewConnsPerSeccond) * float64(mainLoopTickMillis) / 1000.0)
		for {
			// Connect Publisher.
			// Only ask for a channel while the publish rate is positive: the interval below divides by
			// MsgsPerMinute, and not calling NextChannel keeps the channel available for a later tick.
			nextPubChannel := ""
			if dynamicConf.MsgsPerMinute > 0 {
				nextPubChannel = pubsDistribution.NextChannel()
			}
			hasMorePubChannels := nextPubChannel != ""
			if hasMorePubChannels {
				pub := &e2topics.Publisher{
					RedisCli:        redisCli,
					RedisBatch:      redisBatch,
					DynamodbCli:     dynamodbCli,
					DynamodbTbl:     conf.DynamoDBTable,
					DynamoBatch:     dynamoBatch,
					PubsubType:      dynamicConf.PubsubType,
					Channel:         nextPubChannel,
					MsgLen:          dynamicConf.MsgLen,
					PublishInterval: time.Minute / time.Duration(dynamicConf.MsgsPerMinute),
					OnMsgSent: func(p *e2topics.Publisher, msg *e2topics.Message, d time.Duration, err error) {
						if err != nil {
							stats.Inc("PublishError", 1)
							log.Error().Err(err).Dur("durMs", d).Str("channel", p.Channel).Msg("Publish message error")
						} else {
							stats.Inc("PublishSuccess", 1)
							stats.Inc("PublishBytes", p.MsgLen)
							stats.Duration("PublishLatency", d)
							log.Trace().Str("channel", p.Channel).Msg("Publish message OK")
						}
					},
				}
				pub.StartPublishing()
				stats.Inc("PublisherStarted", 1)
				publishers = append(publishers, pub)
				newConns -= 1
			}

			// Connect Subscriber
			nextSubChannel := subsDistribution.NextChannel()
			hasMoreSubChannels := nextSubChannel != ""
			if hasMoreSubChannels {
				sub := &e2topics.Subscriber{
					RedisCli:    redisCli,
					RedisBatch:  redisBatch,
					DynamodbCli: dynamodbCli,
					DynamodbTbl: conf.DynamoDBTable,
					DynamoBatch: dynamoBatch,
					PubsubType:  dynamicConf.PubsubType,
					Channel:     nextSubChannel,
					OnMsgRead: func(s *e2topics.Subscriber, msg *e2topics.Message, d time.Duration, err error) {
						if err != nil {
							stats.Inc("SubscribeError", 1)
							log.Error().Err(err).Dur("durMs", d).Str("channel", s.Channel).Msg("Read message error")
						} else {
							stats.Inc("SubscribeSuccess", 1)
							stats.Duration("SubscribeLatency", d)
							roundtripLatency := time.Since(msg.SentAt)
							log.Trace().Str("channel", s.Channel).Dur("latencyMs", roundtripLatency).Msg("Read message OK")
							if roundtripLatency < 10*time.Second { // avoid zero time values or errors skewing the latency metric
								stats.Duration("MsgRoundtripLatency", roundtripLatency)
							}
						}
					},
				}
				sub.StartReading()
				stats.Inc("SubscriberStarted", 1)
				subscribers = append(subscribers, sub)
				newConns -= 1
			}

			// Stop once this tick's budget is used up, or when there is nothing left to start.
			if newConns <= 0 || (!hasMorePubChannels && !hasMoreSubChannels) {
				break
			}
		}

		// Track more stats
		stats.Gauge("GaugePublishers", len(publishers))
		stats.Gauge("GaugeSubscribers", len(subscribers))
	}
}

// -------
// Helpers
// -------

// PubsDistribution decides which channels this task publishes on.
//
// There is only one publisher per channel, and publishers would be equally
// distributed by the load balancer, so the channels are simply split across
// tasks in equal parts. For example, with 200 channels over 10 tasks,
// task 0 publishes on channels 0-19, task 1 on channels 20-39, and so on.
type PubsDistribution struct {
	// Enabled turns publishing on; while false, NextChannel hands out nothing.
	Enabled bool
	// Channels is the total number of channels shared by all tasks.
	// It can be raised at any time to add more publishers.
	Channels int
	// ChannelOffset is added to every channel number handed out.
	ChannelOffset int
	// TaskCount is the number of tasks the channels are split across.
	TaskCount int
	// TaskNumber is the zero-based position of this task among the TaskCount tasks.
	TaskNumber int

	// channel is the next channel number to hand out; 0 means nothing was handed out yet.
	channel int
}

// NextChannel returns the Redis key of the next channel this task should start
// publishing on, advancing the internal cursor by one. It returns "" when
// publishing is disabled, when TaskCount is not positive, or when every channel
// in this task's share has already been handed out.
//
// The share is recomputed from the current field values on every call, so
// raising Channels/TaskCount later can make more channels available.
func (d *PubsDistribution) NextChannel() string {
	if !d.Enabled {
		return ""
	}
	if d.TaskCount <= 0 {
		return "" // nothing to split by; dividing by a zero TaskCount would panic
	}

	limit := d.Channels / d.TaskCount              // split publishers evenly across tasks
	offset := d.ChannelOffset + d.TaskNumber*limit // first channel for this task

	if d.channel == 0 {
		d.channel = offset // start from the offset for this TaskNumber
	}
	if d.channel >= offset+limit {
		return "" // all channels returned
	}

	channel := d.channel
	d.channel++
	log.Debug().Int("channel", channel).Msg("Started Publisher")
	return channelRedisKey(channel)
}

// SubsDistribution decides which channels this task subscribes to.
//
// Viewers are evenly connected to edge instances, and an instance only needs a
// subscription on a channel when at least one of its viewers is on that channel.
// To find the channels with viewers on this instance, we compute how many
// viewers connect here and assign each of them to a random channel, with a
// higher probability of landing on the initial channels.
//
// Be careful when choosing the number of channels, viewers per channel and
// tasks in the dynamic config, and keep it realistic: instances can handle up
// to 17K connections. For example, 40K channels with an average of 20 viewers
// per channel on 20 tasks means a total of 40K * 21 = 840K users, i.e. 42K
// users per task, which is too much. With 20 viewers per channel the number of
// tasks should be 50, because 840K/50 = 16.8K.
type SubsDistribution struct {
	// Channels is the total number of channels; it can be raised at any time to add more subs.
	Channels int
	// ChannelOffset is added to every channel number handed out.
	ChannelOffset int
	// AvgViewersPerChannel is the average number of viewers on each channel.
	AvgViewersPerChannel int
	// TaskCount is the number of tasks the viewers are split across.
	TaskCount int

	// viewer counts the viewers of this task that were already assigned.
	viewer int
	// subs remembers the channel numbers that were already handed out.
	subs map[int]bool
}

// NextChannel returns the Redis key of the next channel this task should
// subscribe to, or "" when all viewers of this task have been assigned (or
// when TaskCount is not positive, since there is then nothing to split by).
//
// Every call consumes viewers one at a time. A viewer that lands on a channel
// this task is already subscribed to shares that subscription, so it is
// skipped and the next viewer is tried.
func (d *SubsDistribution) NextChannel() string {
	if d.TaskCount <= 0 {
		return "" // dividing by a zero TaskCount would panic
	}
	if d.subs == nil {
		d.subs = map[int]bool{} // writing to a nil map would panic
	}

	totalViewers := d.Channels * d.AvgViewersPerChannel
	viewersOnThisTask := totalViewers / d.TaskCount

	for d.viewer < viewersOnThisTask {
		d.viewer++

		// Each viewer is assigned to a random channel,
		// with higher probability of joining a lower channel number,
		// to simulate the real-world tendency of having some big channels and many small channels.
		r := rand.Float64() // random value in [0, 1) with uniform distribution
		r = r * r * r       // cubing skews the value towards 0, so most viewers join the lower channels
		channel := d.ChannelOffset + int(float64(d.Channels)*r)

		if d.subs[channel] { // this task already has a subscription on this channel
			continue // skip; this viewer would share the same subscription
		}
		d.subs[channel] = true // remember this channel

		log.Debug().Int("channel", channel).Int("viewer", d.viewer-1).Msg("Started Subscriber")
		return channelRedisKey(channel)
	}

	return "" // all channels returned
}

// channelRedisKey builds the Redis key for channel number i: "ch" followed by
// the number in decimal.
func channelRedisKey(i int) string {
	const prefix = "ch"
	return fmt.Sprint(prefix, i)
}

// logErrorAndExitf reports a fatal error and terminates the process.
// It increments the "TaskErrorAndExit" stat, logs err at error level with the
// message built from msg and msgf (printf-style), and exits with status 1.
// It never returns.
func logErrorAndExitf(err error, stats e2topics.Statter, msg string, msgf ...interface{}) {
	const exitCode = 1

	stats.Inc("TaskErrorAndExit", 1)

	event := log.Error().Err(err)
	event.Msgf(msg, msgf...)

	os.Exit(exitCode)
}

// lockTaskNumber claims a task number for this task and returns it.
//
// The task number uniquely identifies the task in a sorted position starting at
// zero, and is used to distribute channels across all running instances.
// Numbers are tried in order (0, 1, 2, ...): each one has its own lock key that
// is atomically incremented in Redis with INCR, and the number is ours when the
// counter comes back as 1. The lock is then given an expiration of expire, so
// it is eventually released.
//
// When INCR fails, or when the last number (taskCount-1) is already taken, the
// function waits for about the expiration time and starts over from zero with
// one retry less. Once no retries are left the process exits.
func lockTaskNumber(redisCli radix.Client, stats e2topics.Statter, taskNumberKey string, taskCount int, expire time.Duration, retries int) int {
	if retries <= 0 {
		logErrorAndExitf(fmt.Errorf("no more retries"), stats, "lockTaskNumber: unable to lock TaskNumber")
	}

	// Walk the candidate numbers upwards; every path that does not move on to a higher number returns.
	for taskNumber := 0; ; taskNumber++ {
		lockKey := taskNumberLockKey(taskNumberKey, taskNumber)

		var lockValue int
		if err := redisCli.Do(radix.Cmd(&lockValue, "INCR", lockKey)); err != nil {
			log.Warn().Err(err).Msgf("Redis: Unable to increment lockKey: %s. Retries left: %d", lockKey, retries-1)
			time.Sleep(expire) // wait for expiration time and try again from the beginning
			return lockTaskNumber(redisCli, stats, taskNumberKey, taskCount, expire, retries-1)
		}
		log.Trace().Int("taskNumber", taskNumber).Str("taskNumberKey", taskNumberKey).Str("lockKey", lockKey).Int("lockValue", lockValue).Msg("TaskNumber lock INCR")

		// First increment on this key: the lock is ours.
		if lockValue == 1 {
			log.Info().Int("taskNumber", taskNumber).Str("taskNumberKey", taskNumberKey).Msgf("TaskNumber lock aquired for task %d (%d/%d)", taskNumber, taskNumber+1, taskCount)
			mustRefreshTaskNumberLock(redisCli, stats, taskNumberKey, taskNumber, expire) // Add expiration to the lock, to make sure it is eventually released
			return taskNumber
		}

		// The number is taken. If it was the last one available, wait and start over.
		if taskNumber >= taskCount-1 {
			log.Warn().Msgf("lockTaskNumber: all locks were taken. Retries left: %d", retries-1)
			mustRefreshTaskNumberLock(redisCli, stats, taskNumberKey, taskNumber, expire) // sets the expiration on the last lock key again before waiting
			time.Sleep(expire + 2*time.Second)                                            // wait for expiration time and try again from the beginning
			return lockTaskNumber(redisCli, stats, taskNumberKey, taskCount, expire, retries-1)
		}

		// Otherwise the loop moves on and tries to lock a higher number.
	}
}

// taskNumberLockKey builds the Redis key that holds the lock for taskNumber:
// the taskNumberKey, then "-lock", then the number in decimal.
func taskNumberLockKey(taskNumberKey string, taskNumber int) string {
	const infix = "-lock"
	return taskNumberKey + infix + fmt.Sprint(taskNumber)
}

// mustRefreshTaskNumberLock sets the expiration of the lock for taskNumber to
// expire. If the first attempt fails it waits one second and tries once more;
// if that fails too, the error is logged and the process exits.
func mustRefreshTaskNumberLock(redisCli radix.Client, stats e2topics.Statter, taskNumberKey string, taskNumber int, expire time.Duration) {
	if err := refreshTaskNumberLock(redisCli, taskNumberKey, taskNumber, expire); err == nil {
		return
	}

	time.Sleep(time.Second) // retry once

	if err := refreshTaskNumberLock(redisCli, taskNumberKey, taskNumber, expire); err != nil {
		logErrorAndExitf(err, stats, "Redis: Unable to expire lockKey: %s-lock%d. Please delete the key or change dynamicConf.TaskNumberKey and try again", taskNumberKey, taskNumber)
	}
}

// refreshTaskNumberLock sets the expiration of the lock key for taskNumber to
// expire, using the Redis EXPIRE command, and returns the error of that command.
//
// EXPIRE takes a whole number of seconds, so expire is rounded up to the next
// full second: the lock is never kept for less time than requested, and
// durations that are already whole seconds are sent unchanged.
func refreshTaskNumberLock(redisCli radix.Client, taskNumberKey string, taskNumber int, expire time.Duration) error {
	log.Trace().Int("taskNumber", taskNumber).Str("taskNumberKey", taskNumberKey).Msgf("TaskNumber lock expires in %v", expire)
	lockKey := taskNumberLockKey(taskNumberKey, taskNumber)
	seconds := int64((expire + time.Second - 1) / time.Second) // integer seconds, rounded up
	return redisCli.Do(radix.FlatCmd(nil, "EXPIRE", lockKey, seconds))
}

// setGlobalLogLevel applies logLevelStr as the global zerolog level.
// An empty string means "info". A value that zerolog cannot parse is logged as
// an error and the current level is left untouched.
func setGlobalLogLevel(logLevelStr string) {
	const defaultLevel = "info"

	name := logLevelStr
	if name == "" {
		name = defaultLevel
	}

	parsed, err := zerolog.ParseLevel(name)
	if err != nil {
		log.Error().Err(err).Msgf("DynamicConf: Invalid LogLevel: %s", name)
		return
	}

	// Setting the level again is harmless when it did not change.
	zerolog.SetGlobalLevel(parsed)
}
