package main

import (
	"bytes"
	"context"
	"errors"
	"flag"
	"fmt"
	"log"
	"math/bits"
	"math/rand"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"code.justin.tv/video/metrics-middleware/v2/operation"
	"github.com/go-redis/redis"
	"golang.org/x/sync/errgroup"
)

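// Pair this with redis-cli running a series of slow Lua scripts, which stall
// the Redis server's event loop and cause the clients to time out and
// disconnect. Point redis-cli's -p at the server given by -addr (the example
// below uses port 3000):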
//
//     $ date ; for i in {1..20} ; do env time redis-cli -p 3000 --eval <(echo -n -e 'local i = tonumber(ARGV[1])\nwhile (i > 0) do\ni = i-1\nend\nreturn 42\n') , 20000000 >/dev/null ; done

func main() {
	addr := flag.String("addr", "127.0.0.1:6379", "Address of Redis server")
	cmdCost := flag.Int64("cmd-cost", 1000, "Linearly-scaling cost of each Redis command")
	workers := flag.Int("workers", 20, "Count of Redis-consuming goroutines")
	poolSize := flag.Int("pool", 20, "Size of Redis connection pool")
	poolTimeout := flag.Duration("pool-timeout", 1*time.Second, "How long to wait for a connection")
	testDuration := flag.Duration("duration", 10*time.Second, "Duration of the test")
	readTimeout := flag.Duration("read-timeout", 100*time.Millisecond, "Timeout for reading Redis responses")
	pause := flag.Duration("pause", 0, "Delay between Redis commands")
	clientCount := flag.Int("client-count", 10, "Number of separate Redis clients")
	flag.Parse()

	// Workers are assigned to clients round-robin (i % len(clients)), so at
	// least one client is required to avoid a divide-by-zero panic.
	if *clientCount < 1 {
		log.Fatalf("-client-count must be at least 1, got %d", *clientCount)
	}

	rand.Seed(time.Now().UnixNano())

	ctx := context.Background()

	// mayLog is cleared every 100ms; a worker logs a command error only when it
	// wins the CAS from 0 to 1, so at most one error is logged per interval.
	var mayLog int32
	go func() {
		for range time.NewTicker(100 * time.Millisecond).C {
			atomic.StoreInt32(&mayLog, 0)
		}
	}()

	// Once per second, snapshot the counters (Redis ops and errors, read status
	// codes, read latency percentiles), subtract the snapshot from the live
	// counters, and log it, so each line covers roughly the last second.
	var counts opCounts
	go func() {
		for range time.NewTicker(1 * time.Second).C {
			c2 := counts.clone()
			counts.sub(c2)
			log.Printf("%s", c2.String())
		}
	}()
	// Wrap in a closure: arguments to a deferred call are evaluated when the
	// defer statement executes, so deferring log.Printf("%s", counts.String())
	// directly would log the empty counters as they are at this point.
	defer func() { log.Printf("%s", counts.String()) }()

	writeTimeout := *readTimeout * 2

	starter := &operation.Starter{
		OpMonitors: []operation.OpMonitor{
			&observer{counts: &counts},
		},
	}

	opts := &redis.ClusterOptions{
		Addrs:        []string{*addr},
		DialTimeout:  writeTimeout,
		ReadTimeout:  *readTimeout,
		WriteTimeout: writeTimeout,
		PoolTimeout:  *poolTimeout,
		PoolSize:     *poolSize,

		ReadOnly: true,
	}

	lim := &limiter{
		c: &aimd{
			addIncrease: 1.0,
			mulDecrease: 0.7,
			minTPS:      5.0,
			maxTPS:      1e6,

			tps: 5.0,
		},
	}
	go func() {
		for range time.NewTicker(1 * time.Second).C {
			log.Printf("tps=%0.1f", lim.Rate())
		}
	}()

	proc := &processor{
		limiter: lim,
		starter: starter,
	}

	clients := make([]*redis.ClusterClient, *clientCount)
	for i := range clients {
		rc := redis.NewClusterClient(opts)
		rc.WrapProcess(proc.wrapper)
		rc = rc.WithContext(ctx)
		clients[i] = rc
	}

	ctx, cancel := context.WithTimeout(ctx, *testDuration)
	defer cancel()
	eg, ctx := errgroup.WithContext(ctx)
	for i := 0; i < *workers; i++ {
		rc := clients[i%len(clients)]
		eg.Go(func() error {
			rc := rc.WithContext(ctx)

			// wait sleeps for d and reports whether the test is still running;
			// it returns early (false) once the test deadline passes.
			wait := func(d time.Duration) bool {
				timer := time.NewTimer(d)
				defer timer.Stop()
				select {
				case <-ctx.Done():
					return false
				case <-timer.C:
					return true
				}
			}

			for i := 0; ctx.Err() == nil; i++ {
				if *pause > 0 {
					if i == 0 {
						// uniform random jitter for the first request
						if !wait(time.Duration(rand.Int63n(int64(*pause)))) {
							break
						}
					}
					// exponentially distributed gaps make requests a Poisson process
					if !wait(time.Duration(rand.ExpFloat64() * float64(*pause))) {
						break
					}
				}
				var err error
				if *cmdCost > 0 {
					_, err = rc.Eval(`
local i = tonumber(ARGV[1])
while (i > 0) do
	i = i-1
end
return 42
				`, nil, *cmdCost).Result()
				} else {
					key := fmt.Sprintf("k%d", rand.Int63())
					_, err = rc.Get(key).Result()
					if err == redis.Nil {
						err = nil
					}
				}
				counts.addRedisOp()
				if err != nil {
					counts.addRedisErr()
					if atomic.CompareAndSwapInt32(&mayLog, 0, 1) {
						log.Printf("cmd error %v", err)
					}
				}
			}
			return nil
		})
	}
	if err := eg.Wait(); err != nil {
		log.Fatalf("%v", err)
	}
}

// observer is an operation.OpMonitor that records, into counts, the status
// code and latency of every completed operation whose method is "get" or
// "eval". Method names presumably come from cmd.Name() in processor.wrapper —
// confirm if other operation sources are added.
type observer struct {
	counts *opCounts
}

// MonitorOp returns ctx unchanged together with monitor points whose End hook
// feeds matching operations into o.counts. Other operations are ignored.
func (o *observer) MonitorOp(ctx context.Context, name operation.Name) (context.Context, *operation.MonitorPoints) {
	onEnd := func(report *operation.Report) {
		switch name.Method {
		case "get", "eval":
			elapsed := report.EndTime.Sub(report.StartTime)
			o.counts.addRead(report.Status.Code, elapsed)
		}
	}
	return ctx, &operation.MonitorPoints{End: onEnd}
}

// wrappedConn decorates a net.Conn so that every Read and Write is reported
// through starter as an operation in group "net". When fail is set, Read
// drops the underlying result and returns a timeoutError instead, which lets
// the harness simulate read timeouts.
type wrappedConn struct {
	net.Conn
	starter *operation.Starter
	fail    bool
}

// Write forwards b to the underlying connection, recording a "net"/"Write"
// operation around the call.
func (c *wrappedConn) Write(b []byte) (int, error) {
	return c.instrument("Write", func() (int, error) { return c.Conn.Write(b) })
}

// Read forwards to the underlying connection, recording a "net"/"Read"
// operation around the call. If c.fail is set, whatever was read is discarded
// and (0, timeoutError{}) is returned. The recorded operation still reflects
// the real outcome of the underlying read.
func (c *wrappedConn) Read(b []byte) (int, error) {
	n, err := c.instrument("Read", func() (int, error) { return c.Conn.Read(b) })
	if c.fail {
		return 0, timeoutError{}
	}
	return n, err
}

// instrument runs do inside an operation named {Group: "net", Method: method}.
// On failure it sets the operation status to 4 (DEADLINE_EXCEEDED) for
// net.Error timeouts and 2 (UNKNOWN) otherwise. On success the status is left
// untouched. It then ends the operation and returns do's results unchanged.
func (c *wrappedConn) instrument(method string, do func() (int, error)) (int, error) {
	_, op := c.starter.StartOp(context.Background(), operation.Name{Group: "net", Method: method})

	n, err := do()
	if err != nil {
		code := int32(2) // UNKNOWN
		if ne, ok := err.(net.Error); ok && ne.Timeout() {
			code = 4 // DEADLINE_EXCEEDED
		}
		op.SetStatus(operation.Status{Code: code})
	}

	op.End()
	return n, err
}

// timeoutError is the net.Error that wrappedConn.Read returns when fault
// injection is enabled. It reports itself as both a timeout and temporary,
// so callers classify it like a real network read timeout. Its message is
// non-empty so that logged failures remain identifiable.
type timeoutError struct{}

func (timeoutError) Error() string   { return "injected i/o timeout" }
func (timeoutError) Timeout() bool   { return true }
func (timeoutError) Temporary() bool { return true }

var _ error = timeoutError{}
var _ net.Error = timeoutError{}

// opCounts holds counters shared by the Redis workers, the operation
// observer, and the periodic reporter. Every field is read and written with
// sync/atomic, so a *opCounts may be updated from several goroutines.
//
// readLatency is a log2 histogram of read latencies in nanoseconds: bucket 0
// counts zero latencies and bucket i (i >= 1) counts latencies in
// [2^(i-1), 2^i). readCodes counts reads by operation status code. Codes
// outside the array are folded into 2 (UNKNOWN).
type opCounts struct {
	dials       int64
	redisOps    int64
	redisErrs   int64
	readLatency [65]int64
	readCodes   [14]int64
}

// addDial records one connection dial.
func (c *opCounts) addDial() { atomic.AddInt64(&c.dials, 1) }

// addRead records one completed read with the given status code and latency.
// Out-of-range codes are counted as 2 (UNKNOWN); negative latencies as zero.
func (c *opCounts) addRead(code int32, latency time.Duration) {
	if code < 0 || int(code) >= len(c.readCodes) {
		code = 2 // UNKNOWN
	}
	atomic.AddInt64(&c.readCodes[int(code)], 1)

	if latency < 0 {
		latency = 0
	}
	// The bucket index is the bit length of the latency, i.e. floor(log2)+1.
	atomic.AddInt64(&c.readLatency[bits.Len64(uint64(latency))], 1)
}

// addRedisOp records one Redis command issued by a worker.
func (c *opCounts) addRedisOp() { atomic.AddInt64(&c.redisOps, 1) }

// addRedisErr records one Redis command that returned an error.
func (c *opCounts) addRedisErr() { atomic.AddInt64(&c.redisErrs, 1) }

// String formats a snapshot of the counters as space-separated key=value
// pairs. It prints dial/op/error counts, the counts for read status codes 0,
// 2, 4 and 8, and approximate p99/p90/p50 read latencies. Each percentile is
// the upper bound of the histogram bucket it falls in (a power of two
// nanoseconds), printed in milliseconds.
func (c *opCounts) String() string {
	c = c.clone()
	var buf bytes.Buffer
	fmt.Fprintf(&buf, "dial=%d ", c.dials)
	fmt.Fprintf(&buf, "redisOps=%d ", c.redisOps)
	fmt.Fprintf(&buf, "redisErrs=%d ", c.redisErrs)
	fmt.Fprintf(&buf, "read0=%d read2=%d read4=%d read8=%d ", c.readCodes[0], c.readCodes[2], c.readCodes[4], c.readCodes[8])

	var readCount int64
	for _, count := range c.readLatency {
		readCount += count
	}
	for _, pct := range []int64{99, 90, 50} {
		if ns, ok := c.readLatencyPercentile(readCount, pct); ok {
			fmt.Fprintf(&buf, "readp%d=%0.3fms ", pct, ns/1e6)
		}
	}

	return buf.String()
}

// readLatencyPercentile returns the upper bound, in nanoseconds, of the
// readLatency bucket containing the pct-th percentile of total reads. total
// must be the sum of c.readLatency. ok is false only if the histogram never
// reaches the percentile, which cannot happen when total is that sum.
func (c *opCounts) readLatencyPercentile(total, pct int64) (ns float64, ok bool) {
	// remaining is how many reads must be passed to reach the percentile.
	remaining := total - (100-pct)*total/100
	for i, count := range c.readLatency {
		remaining -= count
		if remaining <= 0 {
			// uint64 keeps the bound exact on platforms with a 32-bit uint.
			return float64(uint64(1) << uint(i)), true
		}
	}
	return 0, false
}

// clone returns a snapshot of c, loading each counter atomically. The
// snapshot as a whole is not atomic: counters may change between loads.
func (c *opCounts) clone() *opCounts {
	c2 := &opCounts{
		dials:     atomic.LoadInt64(&c.dials),
		redisOps:  atomic.LoadInt64(&c.redisOps),
		redisErrs: atomic.LoadInt64(&c.redisErrs),
	}
	for i := range c.readLatency {
		c2.readLatency[i] = atomic.LoadInt64(&c.readLatency[i])
	}
	for i := range c.readCodes {
		c2.readCodes[i] = atomic.LoadInt64(&c.readCodes[i])
	}
	return c2
}

// sub subtracts every counter in c2 (typically a snapshot from clone) from
// c. Increments that reached c after the snapshot was taken remain in c.
func (c *opCounts) sub(c2 *opCounts) {
	atomic.AddInt64(&c.dials, -atomic.LoadInt64(&c2.dials))
	atomic.AddInt64(&c.redisOps, -atomic.LoadInt64(&c2.redisOps))
	atomic.AddInt64(&c.redisErrs, -atomic.LoadInt64(&c2.redisErrs))
	// Subtract the snapshot's values, not c's current ones. Subtracting c's
	// own value would silently discard increments made between clone and sub.
	for i := range c.readLatency {
		atomic.AddInt64(&c.readLatency[i], -atomic.LoadInt64(&c2.readLatency[i]))
	}
	for i := range c.readCodes {
		atomic.AddInt64(&c.readCodes[i], -atomic.LoadInt64(&c2.readCodes[i]))
	}
}

//

// processor holds what its wrapper needs to instrument and rate-limit Redis
// commands. Its wrapper method is installed on each client via WrapProcess.
type processor struct {
	limiter *limiter
	starter *operation.Starter
}

// wrapper returns a process function that wraps oldProcess. For each command
// it starts a "redis" client operation named after cmd.Name(), with a
// 2-second context. It then asks the limiter for a token: a denied command
// gets status 8 (RESOURCE_EXHAUSTED) and a "rate limited" error, and is not
// reported to the limiter. An admitted command runs through oldProcess. Its
// error is classified as 4 (DEADLINE_EXCEEDED) for net.Error timeouts, 2
// (UNKNOWN) for other errors, and success for nil or redis.Nil. Success or
// failure is then reported to the limiter. The operation's status is set
// before it ends.
func (p *processor) wrapper(oldProcess func(cmd redis.Cmder) error) func(cmd redis.Cmder) error {
	return func(cmd redis.Cmder) error {
		opCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()

		name := operation.Name{Group: "redis", Method: cmd.Name(), Kind: operation.KindClient}
		_, op := p.starter.StartOp(opCtx, name)
		defer op.End()

		var status operation.Status
		// Deferred after op.End, so it runs before it: the final status is
		// attached before the operation is ended.
		defer func() { op.SetStatus(status) }()

		if !p.limiter.Allow() {
			status.Code = 8 // RESOURCE_EXHAUSTED
			return errors.New("rate limited")
		}

		// Feedback to the rate controller, evaluated after status is final.
		defer func() {
			if status.Code != 0 {
				p.limiter.MarkFailure()
				return
			}
			p.limiter.MarkSuccess()
		}()

		err := oldProcess(cmd)
		switch ne, isNet := err.(net.Error); {
		case err == nil, err == redis.Nil:
			// redis.Nil is treated as success, matching main's Get path.
		case isNet && ne.Timeout():
			status.Code = 4 // DEADLINE_EXCEEDED
		default:
			status.Code = 2 // UNKNOWN
		}

		return err
	}
}

//

// control supplies a target request rate (requests per second) and accepts
// per-request outcome feedback that may adjust that rate.
type control interface {
	Rate() float64
	MarkSuccess()
	MarkFailure()
}

// aimd is an additive-increase/multiplicative-decrease rate controller. Each
// success adds addIncrease to the rate, capped at maxTPS. Each failure
// multiplies the rate by mulDecrease, floored at minTPS. tps is the current
// rate and is guarded by mu.
type aimd struct {
	addIncrease float64
	mulDecrease float64
	minTPS      float64
	maxTPS      float64

	mu  sync.Mutex
	tps float64
}

var _ control = (*aimd)(nil)

// Rate returns the current target rate.
func (v *aimd) Rate() float64 {
	v.mu.Lock()
	defer v.mu.Unlock()
	return v.tps
}

// MarkSuccess raises the rate by addIncrease, never above maxTPS.
func (v *aimd) MarkSuccess() {
	v.mu.Lock()
	defer v.mu.Unlock()

	raised := v.tps + v.addIncrease
	if raised > v.maxTPS {
		raised = v.maxTPS
	}
	v.tps = raised
}

// MarkFailure scales the rate by mulDecrease, never below minTPS.
func (v *aimd) MarkFailure() {
	v.mu.Lock()
	defer v.mu.Unlock()

	lowered := v.tps * v.mulDecrease
	if lowered < v.minTPS {
		lowered = v.minTPS
	}
	v.tps = lowered
}

// limiter is a token bucket whose refill rate is read from c at each refill.
// It also implements control by delegating to c, so request outcomes can be
// reported through the limiter itself.
type limiter struct {
	c control

	mu       sync.Mutex // guards fillTime and tokens
	fillTime time.Time  // time of the last refill; zero before the first one
	tokens   float64    // tokens currently available; may be fractional
}

var _ control = (*limiter)(nil)

// Rate returns c's current rate.
func (l *limiter) Rate() float64 { return l.c.Rate() }

// MarkSuccess forwards a success to c.
func (l *limiter) MarkSuccess() { l.c.MarkSuccess() }

// MarkFailure forwards a failure to c.
func (l *limiter) MarkFailure() { l.c.MarkFailure() }

// Allow reports whether a request may proceed, consuming one token if so.
// The bucket is refilled only when fewer than one token is available.
func (l *limiter) Allow() bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	if l.tokens < 1.0 {
		l.refillLocked(time.Now())
		if l.tokens < 1.0 {
			return false
		}
	}
	l.tokens -= 1.0
	return true
}

// refillLocked credits tokens for the time elapsed since the previous refill
// at the current Rate, and records now as the refill time. The elapsed time is
// clamped to [0, 1s]. The very first refill is credited a full second. The
// caller must hold l.mu.
func (l *limiter) refillLocked(now time.Time) {
	const maxDelay = 1 * time.Second

	elapsed := maxDelay
	if !l.fillTime.IsZero() {
		switch d := now.Sub(l.fillTime); {
		case d < 0:
			elapsed = 0
		case d < maxDelay:
			elapsed = d
		}
	}

	l.tokens += l.Rate() * elapsed.Seconds()
	l.fillTime = now
}
