package main

import (
	"fmt"
	"log"
	"regexp"
	"strings"
	"sync"

	"code.justin.tv/rhys/nursery/internal/_vendor/trace"
)

// eventString renders ev using its own %s formatting followed by a newline.
// When stack is true, one indented line per frame of ev.Stk is appended,
// giving the frame's PC (hex), function name, and file:line.
func eventString(ev *trace.Event, stack bool) string {
	var sb strings.Builder
	fmt.Fprintf(&sb, "%s\n", ev)
	if !stack {
		return sb.String()
	}
	for _, f := range ev.Stk {
		fmt.Fprintf(&sb, "  %x %s %s:%d\n", f.PC, f.Fn, f.File, f.Line)
	}
	return sb.String()
}

// region pairs a Kind label with a list of trace events.
//
// NOTE(review): region is not referenced in this part of the file; the set of
// Kind values and how Events is populated are defined elsewhere — confirm.
type region struct {
	Kind   string
	Events []*trace.Event
}

// hasFrameExact reports whether some frame in stk has a function name that,
// once any vendor directory prefix is trimmed, equals fn exactly.
func hasFrameExact(stk []*trace.Frame, fn string) bool {
	for i := range stk {
		if name := trimVendor(stk[i].Fn); name == fn {
			return true
		}
	}
	return false
}

// hasFramePrefix reports whether some frame in stk has a function name that,
// once any vendor directory prefix is trimmed, begins with fn.
func hasFramePrefix(stk []*trace.Frame, fn string) bool {
	for i := range stk {
		if name := trimVendor(stk[i].Fn); strings.HasPrefix(name, fn) {
			return true
		}
	}
	return false
}

// trimVendor strips a vendor directory prefix from the function name fn,
// memoizing the result in the shared globalProgram cache.
func trimVendor(fn string) string {
	p := &globalProgram
	return p.trimVendor(fn)
}

// hasStackRe matches stk against specs using the shared globalProgram.
// It panics if any spec is not a valid regular expression; the specs used in
// this file are fixed literals, so a failure indicates a programming error.
func hasStackRe(stk []*trace.Frame, specs ...string) bool {
	ok, err := globalProgram.hasStackRe(stk, specs...)
	if err == nil {
		return ok
	}
	panic(err)
}

// globalProgram is the shared cache behind the package-level trimVendor and
// hasStackRe helpers.
var globalProgram program

// program memoizes per-string work: vendor-prefix trimming of function names
// and regular-expression compilation. The zero value is ready to use: both
// maps are created lazily, and mu guards all access to them.
type program struct {
	mu   sync.Mutex
	re   map[string]regexpCompile // compile results (regexp or error), keyed by expression
	trim map[string]string        // trimmed function names, keyed by the original name
}

// trimVendor returns fn with its vendor directory prefix removed: everything
// up to and including the last "/vendor/" is dropped, and then a leading
// "vendor/" is dropped. Results are cached in p.trim under p.mu.
func (p *program) trimVendor(fn string) string {
	p.mu.Lock()
	defer p.mu.Unlock()

	if cached, hit := p.trim[fn]; hit {
		return cached
	}
	if p.trim == nil {
		p.trim = make(map[string]string)
	}

	trimmed := fn
	if idx := strings.LastIndex(trimmed, "/vendor/"); idx != -1 {
		trimmed = trimmed[idx+len("/vendor/"):]
	}
	trimmed = strings.TrimPrefix(trimmed, "vendor/")

	p.trim[fn] = trimmed
	return trimmed
}

// regexpCompile records the outcome of a single regexp.Compile call, so that
// failed compilations are cached alongside successful ones.
type regexpCompile struct {
	re  *regexp.Regexp
	err error
}

// compile returns the compiled form of expr, calling regexp.Compile at most
// once per distinct expression. Errors are cached too, so a bad expression
// reports the same error on every call.
func (p *program) compile(expr string) (*regexp.Regexp, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	if cached, hit := p.re[expr]; hit {
		return cached.re, cached.err
	}
	if p.re == nil {
		p.re = make(map[string]regexpCompile)
	}

	re, err := regexp.Compile(expr)
	p.re[expr] = regexpCompile{re: re, err: err}
	return re, err
}

// mustCompile is like compile but panics, with the error wrapped, when expr
// is not a valid regular expression.
func (p *program) mustCompile(expr string) *regexp.Regexp {
	re, err := p.compile(expr)
	if err == nil {
		return re
	}
	panic(fmt.Errorf("mustCompile: %w", err))
}

// hasStackRe reports whether the whole call stack stk matches the sequence of
// patterns in specs, read from the root frame (the last element of stk)
// toward the leaf (the first element).
//
// Each spec is one of:
//   - "$any":    matches zero or more consecutive frames;
//   - "$anyOne": matches exactly one frame;
//   - anything else: a regular expression that must match the vendor-trimmed
//     function name of exactly one frame.
//
// Every frame must be consumed by some spec for the stack to match. An error
// is returned if a spec is not a valid regular expression.
func (p *program) hasStackRe(stk []*trace.Frame, specs ...string) (bool, error) {
	// Sentinel matchers, recognized by pointer identity. Their patterns ("$"
	// followed by literal text) cannot match anything, and they are never
	// executed. Named to avoid shadowing the builtin any.
	var (
		anyRe    = p.mustCompile(`$any`)    // sentinel: zero or more stack frames
		anyOneRe = p.mustCompile(`$anyOne`) // sentinel: a single stack frame
	)

	res := make([]*regexp.Regexp, 0, len(specs))
	for _, spec := range specs {
		switch spec {
		case "$any":
			res = append(res, anyRe)
		case "$anyOne":
			res = append(res, anyOneRe)
		default:
			re, err := p.compile(spec)
			if err != nil {
				return false, fmt.Errorf("could not compile regexp %q: %w", spec, err)
			}
			res = append(res, re)
		}
	}

	// Simulate an NFA with states 0..len(res)-1, one per matcher, and an
	// accepting state len(res). A []bool set holds each state at most once, so
	// the work per frame is O(len(res)) no matter how many paths reach a state.
	final := len(res)

	// addSkips closes a state set over epsilon transitions. "$any" may match
	// zero frames, so being before it also means being just after it.
	// Transitions only move forward, so one ascending pass reaches the fixpoint
	// (e.g. through consecutive "$any" specs).
	addSkips := func(set []bool) {
		for s, re := range res {
			if set[s] && re == anyRe {
				set[s+1] = true
			}
		}
	}

	cur := make([]bool, final+1)
	cur[0] = true
	addSkips(cur)

	// Walk the stack starting at the root.
	for i := len(stk) - 1; i >= 0; i-- {
		fn := p.trimVendor(stk[i].Fn)
		next := make([]bool, final+1)
		live := false
		for s, re := range res {
			if !cur[s] {
				continue
			}
			switch re {
			case anyRe:
				// Consume the frame and stay; addSkips also lets us leave.
				next[s] = true
				live = true
			case anyOneRe:
				next[s+1] = true
				live = true
			default:
				if re.MatchString(fn) {
					next[s+1] = true
					live = true
				}
			}
		}
		if !live {
			// No state survived this frame, so the remaining frames can't
			// produce a match.
			return false, nil
		}
		addSkips(next)
		cur = next
	}

	// Check whether the NFA reached the accepting state.
	return cur[final], nil
}

// doesRoundTrip reports whether stk contains a net/http.(*persistConn).roundTrip
// frame, with "$any" wildcards on either side as interpreted by hasStackRe.
//
// A simpler alternative would be:
//
//	hasFrameExact(stk, "net/http.(*persistConn).roundTrip")
func doesRoundTrip(stk []*trace.Frame) bool {
	const roundTrip = "^net/http...persistConn..roundTrip$"
	return hasStackRe(stk, "$any", roundTrip, "$any")
}

// doesReadHTTPRequest reports whether stk is rooted at net/http.(*conn).serve
// and the frame directly below the root is either (*conn).readRequest or
// bufio.(*Reader).Peek, with any further frames matched by "$any".
// Presumably this identifies a server connection goroutine waiting to read its
// next request.
//
// Each alternative is tried in order, and the first match wins.
func doesReadHTTPRequest(stk []*trace.Frame) bool {
	const serve = `^net/http...conn..serve$`
	callees := []string{
		`^net/http...conn..readRequest$`,
		`^bufio...Reader..Peek$`,
	}
	for _, callee := range callees {
		if hasStackRe(stk, serve, callee, `$any`) {
			return true
		}
	}
	return false
}

// doesNydusReadHTTPRequest reports whether stk is rooted at nydus's
// rawhttp.(*Server).serveConn and the frame directly below the root is
// net/http.ReadRequest or bufio.(*Reader).Peek. Function names are compared
// after vendor trimming.
//
// Equivalent in spirit to:
//
//	hasStack("code.justin.tv/video/nydus/core/rawhttp.(*Server).serveConn", "net/http.ReadRequest", "$any")
//	hasStack("code.justin.tv/video/nydus/core/rawhttp.(*Server).serveConn", "bufio.(*Reader).Peek", "$any")
func doesNydusReadHTTPRequest(stk []*trace.Frame) bool {
	n := len(stk)
	if n < 2 {
		return false
	}
	if trimVendor(stk[n-1].Fn) != "code.justin.tv/video/nydus/core/rawhttp.(*Server).serveConn" {
		return false
	}
	switch trimVendor(stk[n-2].Fn) {
	case "net/http.ReadRequest", "bufio.(*Reader).Peek":
		return true
	default:
		return false
	}
}

// doesRedisConnWrite reports whether stk contains a go-redis v7
// internal/proto (*Writer) method and, somewhere below it, syscall.Write.
// "$any" wildcards separate and surround the two, as interpreted by hasStackRe.
//
// An older formulation was:
//
//	hasFrameExact(stk, "syscall.Write") &&
//		hasFramePrefix(stk, "github.com/go-redis/redis/v7/internal/proto.(*Writer).")
func doesRedisConnWrite(stk []*trace.Frame) bool {
	pattern := []string{
		`$any`,
		`^github.com/go-redis/redis/v7/internal/proto...Writer..`,
		`$any`,
		`^syscall\.Write$`,
		`$any`,
	}
	return hasStackRe(stk, pattern...)
}

// doesRedisConnRead reports whether stk contains an internal/poll.(*FD).Read
// frame and, anywhere in the stack, a go-redis v7 internal/proto (*Reader)
// method. Frame order is not checked.
//
// It could be phrased more strictly as:
//
//	hasStackRe(stk, "$any", "^github.com/go-redis/redis/v7/internal/proto...Reader..", "$any", "^internal/poll...FD..Read$", "$any")
func doesRedisConnRead(stk []*trace.Frame) bool {
	if !hasFrameExact(stk, "internal/poll.(*FD).Read") {
		return false
	}
	return hasFramePrefix(stk, "github.com/go-redis/redis/v7/internal/proto.(*Reader).")
}

// doesLibPqConnWrite reports whether stk contains all of syscall.Write,
// database/sql.(*DB).query and lib/pq.(*conn).query, in any order. The checks
// run in that order and stop at the first miss.
func doesLibPqConnWrite(stk []*trace.Frame) bool {
	required := [...]string{
		"syscall.Write",
		"database/sql.(*DB).query",
		"github.com/lib/pq.(*conn).query",
	}
	for _, fn := range required {
		if !hasFrameExact(stk, fn) {
			return false
		}
	}
	return true
}

// doesLibPqConnRead reports whether stk contains all of internal/poll.(*FD).Read,
// database/sql.(*DB).query and lib/pq.(*conn).query, in any order. The checks
// run in that order and stop at the first miss.
func doesLibPqConnRead(stk []*trace.Frame) bool {
	required := [...]string{
		"internal/poll.(*FD).Read",
		"database/sql.(*DB).query",
		"github.com/lib/pq.(*conn).query",
	}
	for _, fn := range required {
		if !hasFrameExact(stk, fn) {
			return false
		}
	}
	return true
}

// doesHTTPGetConn reports whether net/http.(*Transport).getConn appears
// anywhere in stk (compared after vendor trimming).
func doesHTTPGetConn(stk []*trace.Frame) bool {
	const getConn = "net/http.(*Transport).getConn"
	return hasFrameExact(stk, getConn)
}

// doesHTTPRoundTrip reports whether net/http.(*Transport).roundTrip appears
// anywhere in stk (compared after vendor trimming).
func doesHTTPRoundTrip(stk []*trace.Frame) bool {
	const roundTrip = "net/http.(*Transport).roundTrip"
	return hasFrameExact(stk, roundTrip)
}

// doesDAXConnWrite reports whether stk contains both syscall.Write and the
// aws-dax-go SingleDaxClient.executeWithContext method, in any order.
func doesDAXConnWrite(stk []*trace.Frame) bool {
	const execute = "github.com/aws/aws-dax-go/dax/internal/client.(*SingleDaxClient).executeWithContext"
	if !hasFrameExact(stk, "syscall.Write") {
		return false
	}
	return hasFrameExact(stk, execute)
}

// doesDAXConnRead reports whether stk contains both internal/poll.(*FD).Read
// and the aws-dax-go SingleDaxClient.executeWithContext method, in any order.
func doesDAXConnRead(stk []*trace.Frame) bool {
	const execute = "github.com/aws/aws-dax-go/dax/internal/client.(*SingleDaxClient).executeWithContext"
	if !hasFrameExact(stk, "internal/poll.(*FD).Read") {
		return false
	}
	return hasFrameExact(stk, execute)
}

// inboundRequestState is one state of the inbound-request state machine. It
// consumes a trace event and returns the state for the next event, or nil
// once no further events should be processed.
type inboundRequestState func(*trace.Event) inboundRequestState

// track groups trace events into inbound HTTP requests, using the
// request-read stack shapes recognized by doesReadHTTPRequest and
// doesNydusReadHTTPRequest.
//
// NOTE(review): presumably fed the events of a single goroutine — confirm at
// the caller.
type track struct {
	verbose bool // when set, each state transition is logged

	flushInbound func([]*trace.Event) // receives each completed request's events; may be nil

	prev  *trace.Event   // latest request-read event; nil when no request is open
	queue []*trace.Event // events of the request currently being collected
}

// log writes a printf-style message through the standard logger, but only
// when verbose tracing is enabled.
func (t *track) log(format string, v ...interface{}) {
	if !t.verbose {
		return
	}
	log.Printf(format, v...)
}

// inboundRequestPending is the state between requests. It behaves as follows:
//   - A request-read stack records ev in t.prev, and the state stays pending.
//   - After such a read, the next event that carries a stack starts a new
//     request queue and moves to active.
//   - Anything else is logged and ignored.
func (t *track) inboundRequestPending(ev *trace.Event) inboundRequestState {
	switch {
	case doesReadHTTPRequest(ev.Stk) || doesNydusReadHTTPRequest(ev.Stk):
		t.log("pending->pending    %s", ev)
		t.prev = ev
		return t.inboundRequestPending
	case t.prev != nil && ev.Stk != nil:
		t.log("pending->active     %s", ev)
		t.queue = []*trace.Event{ev}
		return t.inboundRequestActive
	default:
		t.log("pending??           %s", ev)
		return t.inboundRequestPending
	}
}

// inboundRequestActive collects events for an in-progress request. Every
// event, including the one that ends the request, is appended to t.queue.
//
// The request is completed in two cases:
//   - EvGoEnd: the state machine stops (returns nil).
//   - A new request-read stack (keep-alive connection): ev is then replayed
//     through inboundRequestPending.
//
// Completing a request hands the queue to flushInbound (when set) and clears
// both t.prev and t.queue. This only happens when t.prev is non-nil.
func (t *track) inboundRequestActive(ev *trace.Event) inboundRequestState {
	finish := func() {
		if t.prev == nil {
			return
		}
		if t.flushInbound != nil {
			t.flushInbound(t.queue)
		}
		t.prev, t.queue = nil, nil
	}

	t.queue = append(t.queue, ev)

	switch {
	case ev.Type == trace.EvGoEnd:
		t.log("active->nil         %s", ev)
		finish()
		return nil
	case doesReadHTTPRequest(ev.Stk) || doesNydusReadHTTPRequest(ev.Stk):
		// finished request (for keepalive connection)
		t.log("active->pending     %s", ev)
		finish()
		return t.inboundRequestPending(ev)
	default:
		t.log("active->active      %s", ev)
		return t.inboundRequestActive
	}
}

// trackerState is one state of a tracker's state machine. It consumes a trace
// event and returns the state for the next event, or nil to stop.
type trackerState func(*trace.Event) trackerState

// mutexTracker collects the events from an EvGoBlockSync inside
// sync.(*Mutex).Lock through the next EvGoStart.
type mutexTracker struct {
	verbose bool // when set, each state transition is logged

	flush func([]*trace.Event) // receives each completed span of two or more events; may be nil

	queue []*trace.Event // events of the span currently being collected
}

// log writes a printf-style message through the standard logger, but only
// when verbose tracing is enabled.
func (t *mutexTracker) log(format string, v ...interface{}) {
	if !t.verbose {
		return
	}
	log.Printf(format, v...)
}

// idle waits for the goroutine to block in sync.(*Mutex).Lock. Such an event
// starts a fresh queue and moves to active; all other events are ignored.
func (t *mutexTracker) idle(ev *trace.Event) trackerState {
	blockedOnLock := ev.Type == trace.EvGoBlockSync && hasFrameExact(ev.Stk, "sync.(*Mutex).Lock")
	if !blockedOnLock {
		t.log("idle->idle     %s", ev)
		return t.idle
	}
	t.log("idle->active   %s", ev)
	t.queue = []*trace.Event{ev}
	return t.active
}

// active collects events while the goroutine is blocked on a mutex. Every
// event is appended to t.queue.
//
// On EvGoStart the span ends. The queue, including that EvGoStart, is passed
// to t.flush (when set and when it holds at least two events) and then
// cleared, and the tracker returns to idle. Any other event keeps it active.
//
// The active->idle transition is now logged, matching every other
// transition in the trackers.
func (t *mutexTracker) active(ev *trace.Event) trackerState {
	t.queue = append(t.queue, ev)

	if ev.Type != trace.EvGoStart {
		t.log("active->active %s", ev)
		return t.active
	}

	t.log("active->idle   %s", ev)
	if fn := t.flush; fn != nil && len(t.queue) >= 2 {
		fn(t.queue)
	}
	t.queue = nil
	return t.idle
}

// redisTracker collects the events of a go-redis v7 exchange. A span starts
// at a write syscall matching doesRedisConnWrite and continues through the
// events accepted by active.
type redisTracker struct {
	verbose bool // when set, each state transition is logged

	flush func([]*trace.Event) // receives each completed span of two or more events; may be nil

	queue []*trace.Event // events of the span currently being collected
}

// log writes a printf-style message through the standard logger, but only
// when verbose tracing is enabled.
func (t *redisTracker) log(format string, v ...interface{}) {
	if !t.verbose {
		return
	}
	log.Printf(format, v...)
}

// idle waits for an EvGoSysCall whose stack matches doesRedisConnWrite. That
// event starts a fresh queue and moves to active; all other events are ignored.
func (t *redisTracker) idle(ev *trace.Event) trackerState {
	isWrite := ev.Type == trace.EvGoSysCall && doesRedisConnWrite(ev.Stk)
	if !isWrite {
		t.log("idle->idle     %s", ev)
		return t.idle
	}
	t.log("idle->active   %s", ev)
	t.queue = []*trace.Event{ev}
	return t.active
}

// active collects events while a Redis exchange is in progress.
//
// The span continues through these events, each appended to t.queue:
//   - heap allocations;
//   - goroutine starts;
//   - preemptions;
//   - syscalls or network blocks whose stack matches doesRedisConnRead.
//
// Any other event ends the span. That event is dropped from the queue, the
// queue is passed to t.flush (when set and when it holds at least two
// events) and then cleared. The tracker returns to idle, or stops (returns
// nil) on EvGoEnd.
func (t *redisTracker) active(ev *trace.Event) trackerState {
	done := func() {
		if fn := t.flush; fn != nil && len(t.queue) >= 2 {
			fn(t.queue)
		}
		t.queue = nil
	}

	t.queue = append(t.queue, ev)

	// All transitions are decided, and logged, in this one switch. Previously
	// an early return for EvHeapAlloc/EvGoStart made the matching case
	// unreachable.
	switch {
	case ev.Type == trace.EvGoEnd:
		t.queue = t.queue[:len(t.queue)-1]
		t.log("active->nil    %s", ev)
		done()
		return nil
	case ev.Type == trace.EvHeapAlloc || ev.Type == trace.EvGoStart:
		t.log("active->active %s", ev)
		return t.active
	case ev.Type == trace.EvGoPreempt:
		// TODO: backport to others
		// TODO: factor out common parts now that they're more apparent
		t.log("active->active %s", ev)
		return t.active
	case ev.Type == trace.EvGoSysCall || ev.Type == trace.EvGoBlockNet:
		if doesRedisConnRead(ev.Stk) {
			t.log("active->active %s", ev)
			return t.active
		}
	}

	// The event does not belong to the exchange: drop it and end the span.
	t.queue = t.queue[:len(t.queue)-1]
	t.log("active->idle   %s", ev)
	done()
	return t.idle
}

// libpqTracker collects the events of a lib/pq query issued through
// database/sql. A span starts at a write syscall matching doesLibPqConnWrite
// and continues through the events accepted by active.
type libpqTracker struct {
	verbose bool // when set, each state transition is logged

	flush func([]*trace.Event) // receives each completed span of two or more events; may be nil

	queue []*trace.Event // events of the span currently being collected
}

// log writes a printf-style message through the standard logger, but only
// when verbose tracing is enabled.
func (t *libpqTracker) log(format string, v ...interface{}) {
	if !t.verbose {
		return
	}
	log.Printf(format, v...)
}

// idle waits for an EvGoSysCall whose stack matches doesLibPqConnWrite. That
// event starts a fresh queue and moves to active; all other events are ignored.
func (t *libpqTracker) idle(ev *trace.Event) trackerState {
	isWrite := ev.Type == trace.EvGoSysCall && doesLibPqConnWrite(ev.Stk)
	if !isWrite {
		t.log("idle->idle     %s", ev)
		return t.idle
	}
	t.log("idle->active   %s", ev)
	t.queue = []*trace.Event{ev}
	return t.active
}

// active collects events while a lib/pq query is in progress.
//
// The span continues through these events, each appended to t.queue:
//   - heap allocations;
//   - goroutine starts;
//   - syscalls or network blocks whose stack matches doesLibPqConnRead.
//
// Any other event ends the span. That event is dropped from the queue, the
// queue is passed to t.flush (when set and when it holds at least two
// events) and then cleared. The tracker returns to idle, or stops (returns
// nil) on EvGoEnd.
func (t *libpqTracker) active(ev *trace.Event) trackerState {
	done := func() {
		if fn := t.flush; fn != nil && len(t.queue) >= 2 {
			fn(t.queue)
		}
		t.queue = nil
	}

	t.queue = append(t.queue, ev)

	// All transitions are decided, and logged, in this one switch. Previously
	// an early return for EvHeapAlloc/EvGoStart made the matching case
	// unreachable.
	switch {
	case ev.Type == trace.EvGoEnd:
		t.queue = t.queue[:len(t.queue)-1]
		t.log("active->nil    %s", ev)
		done()
		return nil
	case ev.Type == trace.EvHeapAlloc || ev.Type == trace.EvGoStart:
		t.log("active->active %s", ev)
		return t.active
	case ev.Type == trace.EvGoSysCall || ev.Type == trace.EvGoBlockNet:
		if doesLibPqConnRead(ev.Stk) {
			t.log("active->active %s", ev)
			return t.active
		}
	}

	// The event does not belong to the query: drop it and end the span.
	t.queue = t.queue[:len(t.queue)-1]
	t.log("active->idle   %s", ev)
	done()
	return t.idle
}

// httpClientTracker collects the events of an outbound net/http client
// request. The span start is chosen by idle and its continuation by active.
type httpClientTracker struct {
	verbose bool // when set, each state transition is logged

	flush func([]*trace.Event) // receives each completed span of two or more events; may be nil

	queue []*trace.Event // events of the span currently being collected
}

// log writes a printf-style message through the standard logger, but only
// when verbose tracing is enabled.
func (t *httpClientTracker) log(format string, v ...interface{}) {
	if !t.verbose {
		return
	}
	log.Printf(format, v...)
}

// idle waits for the start of an outbound HTTP request. A span starts on
// either of these events:
//   - EvGoCreate whose stack passes through net/http.(*Transport).getConn;
//   - EvGoUnblock whose stack, read from the leaf (stk[0]) upward through
//     runtime.* and net/http.* frames only, reaches
//     net/http.(*Transport).roundTrip.
//
// A starting event begins a fresh queue and moves to active; all other events
// are ignored.
//
// Frame names are vendor-trimmed before comparison, like every other matcher
// in this file (including doesHTTPRoundTrip for the same function).
func (t *httpClientTracker) idle(ev *trace.Event) trackerState {
	start := false
	switch ev.Type {
	case trace.EvGoCreate:
		start = doesHTTPGetConn(ev.Stk)
	case trace.EvGoUnblock:
		for _, frame := range ev.Stk {
			name := trimVendor(frame.Fn)
			if name == "net/http.(*Transport).roundTrip" {
				start = true
				break
			}
			// Look only through runtime and net/http internals; any other
			// caller means this wake-up did not come from roundTrip.
			if !strings.HasPrefix(name, "runtime.") && !strings.HasPrefix(name, "net/http.") {
				break
			}
		}
	}

	if !start {
		t.log("idle->idle     %s", ev)
		return t.idle
	}

	t.log("idle->active   %s", ev)
	t.queue = []*trace.Event{ev}
	return t.active
}

// active collects events while an outbound HTTP request is in progress.
//
// The span continues through these events, each appended to t.queue:
//   - heap allocations;
//   - goroutine starts;
//   - EvGoBlockSelect whose stack passes through
//     net/http.(*Transport).roundTrip.
//
// Any other event ends the span. That event is dropped from the queue, the
// queue is passed to t.flush (when set and when it holds at least two
// events) and then cleared. The tracker returns to idle, or stops (returns
// nil) on EvGoEnd.
func (t *httpClientTracker) active(ev *trace.Event) trackerState {
	done := func() {
		if fn := t.flush; fn != nil && len(t.queue) >= 2 {
			fn(t.queue)
		}
		t.queue = nil
	}

	t.queue = append(t.queue, ev)

	// All transitions are decided, and logged, in this one switch. Previously
	// an early return for EvHeapAlloc/EvGoStart made the matching case
	// unreachable.
	switch {
	case ev.Type == trace.EvGoEnd:
		t.queue = t.queue[:len(t.queue)-1]
		t.log("active->nil    %s", ev)
		done()
		return nil
	case ev.Type == trace.EvHeapAlloc || ev.Type == trace.EvGoStart:
		t.log("active->active %s", ev)
		return t.active
	case ev.Type == trace.EvGoBlockSelect:
		if doesHTTPRoundTrip(ev.Stk) {
			t.log("active->active %s", ev)
			return t.active
		}
	}

	// The event does not belong to the request: drop it and end the span.
	t.queue = t.queue[:len(t.queue)-1]
	t.log("active->idle   %s", ev)
	done()
	return t.idle
}

// daxTracker collects the events of an aws-dax-go request. A span starts at a
// write syscall matching doesDAXConnWrite and continues through the events
// accepted by active.
type daxTracker struct {
	verbose bool // when set, each state transition is logged

	flush func([]*trace.Event) // receives each completed span of two or more events; may be nil

	queue []*trace.Event // events of the span currently being collected
}

// log writes a printf-style message through the standard logger, but only
// when verbose tracing is enabled.
func (t *daxTracker) log(format string, v ...interface{}) {
	if !t.verbose {
		return
	}
	log.Printf(format, v...)
}

// idle waits for an EvGoSysCall whose stack matches doesDAXConnWrite. That
// event starts a fresh queue and moves to active; all other events are ignored.
func (t *daxTracker) idle(ev *trace.Event) trackerState {
	isWrite := ev.Type == trace.EvGoSysCall && doesDAXConnWrite(ev.Stk)
	if !isWrite {
		t.log("idle->idle     %s", ev)
		return t.idle
	}
	t.log("idle->active   %s", ev)
	t.queue = []*trace.Event{ev}
	return t.active
}

// active collects events while a DAX request is in progress.
//
// The span continues through these events, each appended to t.queue:
//   - heap allocations;
//   - goroutine starts;
//   - syscalls or network blocks whose stack matches doesDAXConnRead.
//
// Any other event ends the span. That event is dropped from the queue. Then,
// when t.flush is set, trailing EvHeapAlloc events are trimmed, and the queue
// is flushed if at least two events remain. The queue is always cleared. The
// tracker returns to idle, or stops (returns nil) on EvGoEnd.
func (t *daxTracker) active(ev *trace.Event) trackerState {
	done := func() {
		if fn := t.flush; fn != nil {
			// TODO: backport trailing-allocation trimming to the other trackers
			// TODO: factor out common parts now that they're more apparent
			n := len(t.queue)
			for n > 0 && t.queue[n-1].Type == trace.EvHeapAlloc {
				n--
			}
			t.queue = t.queue[:n]
			if len(t.queue) >= 2 {
				fn(t.queue)
			}
		}
		t.queue = nil
	}

	t.queue = append(t.queue, ev)

	// All transitions are decided, and logged, in this one switch. Previously
	// an early return for EvHeapAlloc/EvGoStart made the matching case
	// unreachable.
	switch {
	case ev.Type == trace.EvGoEnd:
		t.queue = t.queue[:len(t.queue)-1]
		t.log("active->nil    %s", ev)
		done()
		return nil
	case ev.Type == trace.EvHeapAlloc || ev.Type == trace.EvGoStart:
		t.log("active->active %s", ev)
		return t.active
	case ev.Type == trace.EvGoSysCall || ev.Type == trace.EvGoBlockNet:
		if doesDAXConnRead(ev.Stk) {
			t.log("active->active %s", ev)
			return t.active
		}
	}

	// The event does not belong to the request: drop it and end the span.
	t.queue = t.queue[:len(t.queue)-1]
	t.log("active->idle   %s", ev)
	done()
	return t.idle
}
