package main

import (
	"flag"
	"fmt"
	"log"
	"os"
	"regexp"
	"sort"

	"github.com/google/pprof/profile"
)

// main parses the command-line flags, loads the profile named by -heap, and
// hands it to a callFinder configured from the remaining flags. Any failure
// to open or parse the file is fatal.
func main() {
	var (
		heapPath  = flag.String("heap", "/dev/null", "Name of heap profile file")
		asDigraph = flag.Bool("digraph", false, "Format output as a space-separated directed graph")
		withRaw   = flag.Bool("raw-http", true, "Include raw HTTP calls (formed without a generated SDK)")
		withSDK   = flag.Bool("sdk", true, "Include calls built with a generated SDK")
	)
	flag.Parse()

	// Drop the timestamp prefix from everything written through package log.
	log.SetFlags(0)

	in, err := os.Open(*heapPath)
	if err != nil {
		log.Fatalf("os.Open(%q); err = %v", *heapPath, err)
	}
	defer in.Close()

	prof, err := profile.Parse(in)
	if err != nil {
		log.Fatalf("profile.Parse(%q); err = %v", *heapPath, err)
	}

	finder := &callFinder{
		digraph: *asDigraph,
		showRaw: *withRaw,
		showSDK: *withSDK,
	}
	finder.examine(prof)
}

// callFinder holds the output options used while examining a profile and
// printing the call sites found in it.
type callFinder struct {
	// digraph silences logf and makes displayCalls print one
	// "callee caller" pair per line on standard output instead.
	digraph bool

	// showRaw enables the "raw http call sites" section of the report.
	showRaw bool

	// showSDK enables the "sdk call sites" section of the report.
	showSDK bool
}

// examine reports which functions are responsible for the net/http.NewRequest
// calls recorded in heap.
//
// Samples are first narrowed to those whose stack mentions NewRequest
// (optionally under a vendor/ directory). Each remaining stack is trimmed by
// libraryCallstack so that it begins at the NewRequest frame, and is then
// walked in eachLine order for as long as the frames are recognized by
// isGeneratedClient or isRawHTTPClient. The last recognized frame becomes the
// callee and the first unrecognized frame becomes its caller. Pairs whose
// callee is a generated client are collected separately from the rest.
//
// A stack that consists only of recognized frames has no caller to attribute
// the request to; it is logged as a "rootless client call" and skipped.
//
// The two collections are printed via displayCalls, subject to showRaw and
// showSDK.
func (f *callFinder) examine(heap *profile.Profile) {
	newRequest := regexp.MustCompile(`^(|.*/vendor/)` + regexp.QuoteMeta(`net/http.NewRequest`) + `$`)

	// Filter a copy rather than heap itself.
	ours := heap.Copy()
	ours.FilterSamplesByName(newRequest, nil, nil, nil)

	ours = ours.Compact()

	httpCallers := make(map[string][]string)
	sdkCallers := make(map[string][]string)
	for _, sam := range ours.Sample {
		locs := libraryCallstack(newRequest, sam.Location)
		if locs == nil {
			continue
		}

		// prev trails here by one frame: once the walk stops, prev is the
		// last client frame and here is the first frame that is not one.
		var prev, here *funcLine
		done := false
		eachLine(locs, func(i int, loc *profile.Location, j int, line profile.Line) {
			if done {
				return
			}
			here = &funcLine{location: loc, line: line}
			if !isGeneratedClient(here) && !isRawHTTPClient(here) {
				done = true
				return
			}
			prev = here
		})

		if !done || prev == nil || here == nil {
			f.logf("\n")
			f.logf("rootless client call")
			eachLine(locs, func(i int, loc *profile.Location, j int, line profile.Line) {
				f.logf("    %s", line.Function.Name)
			})
			continue
		}

		callers := httpCallers
		if isGeneratedClient(prev) {
			callers = sdkCallers
		}
		callee := prev.line.Function.Name
		caller := here.line.Function.Name
		callers[callee] = append(callers[callee], caller)
	}

	if f.showRaw {
		f.logf("raw http call sites")
		f.displayCalls(httpCallers)
	}
	if f.showSDK {
		f.logf("sdk call sites")
		f.displayCalls(sdkCallers)
	}
}

// logf forwards to log.Printf, except in digraph mode where it prints nothing
// so that only the graph edges reach the output.
func (f *callFinder) logf(format string, v ...interface{}) {
	if f.digraph {
		return
	}
	log.Printf(format, v...)
}

// displayCalls prints every callee in callers, in sorted order, together with
// the sorted, de-duplicated list of its callers.
//
// In digraph mode each pair is written to standard output as
// "callee caller". Otherwise the callee is logged on its own line, followed by
// its callers indented by four spaces and a blank separator line.
func (f *callFinder) displayCalls(callers map[string][]string) {
	callees := make([]string, 0, len(callers))
	for name := range callers {
		callees = append(callees, name)
	}
	sort.Strings(callees)

	for _, callee := range callees {
		// Collapse repeated callers before sorting them.
		seen := make(map[string]struct{})
		var unique []string
		for _, caller := range callers[callee] {
			if _, dup := seen[caller]; dup {
				continue
			}
			seen[caller] = struct{}{}
			unique = append(unique, caller)
		}
		sort.Strings(unique)

		if !f.digraph {
			f.logf("%s", callee)
			for _, caller := range unique {
				f.logf("    %s", caller)
			}
			f.logf("\n")
			continue
		}

		for _, caller := range unique {
			fmt.Printf("%s %s\n", callee, caller)
		}
	}
}

// libraryCallstack trims locs so that it begins at a line whose function name
// matches maxDepth.
//
// locs is scanned from its last element to its first, and the lines within
// each location from last to first; the first match found in that order wins.
// The result is a new slice holding locs[i:], whose first element is a
// shallow copy of locs[i] with Line cut down to Line[j:]. Neither locs nor the
// locations it points to are modified. If no line matches, the result is nil.
//
// Lines that carry no function information can never match and are skipped.
func libraryCallstack(maxDepth *regexp.Regexp, locs []*profile.Location) []*profile.Location {
	for i := len(locs) - 1; i >= 0; i-- {
		lines := locs[i].Line
		for j := len(lines) - 1; j >= 0; j-- {
			fn := lines[j].Function
			if fn == nil || !maxDepth.MatchString(fn.Name) {
				continue
			}

			// Trim a copy of the matching location so the profile's own
			// Location keeps all of its lines.
			first := *locs[i]
			first.Line = lines[j:]

			out := append([]*profile.Location(nil), locs[i:]...)
			out[0] = &first

			// Stop at the first match; there is nothing left to look for.
			return out
		}
	}

	return nil
}

// eachLine calls fn once for every line of every location in locs, going
// through both the locations and their lines in slice order. fn receives the
// location index i, the location itself, the line index j within that
// location, and a copy of the line.
func eachLine(locs []*profile.Location, fn func(i int, loc *profile.Location, j int, line profile.Line)) {
	for i := 0; i < len(locs); i++ {
		loc := locs[i]
		lines := loc.Line
		for j := range lines {
			fn(i, loc, j, lines[j])
		}
	}
}

// eachLineRev is eachLine in the opposite direction: it starts at the last
// line of the last location in locs and finishes at the first line of the
// first location. The indices handed to fn are still positions in the
// original slices.
func eachLineRev(locs []*profile.Location, fn func(i int, loc *profile.Location, j int, line profile.Line)) {
	for i := len(locs); i > 0; {
		i--
		loc := locs[i]
		for j := len(loc.Line); j > 0; {
			j--
			fn(i, loc, j, loc.Line[j])
		}
	}
}

// isGeneratedClient reports whether l is recognized by any of the generated
// client predicates, tried in order: AWS, Twirp, Alvin.
func isGeneratedClient(l *funcLine) bool {
	return isAWSClient(l) || isTwirpClient(l) || isAlvinClient(l)
}

// isRawHTTPClient reports whether l is recognized as a net/http client
// function or, failing that, as an HTTP helper.
func isRawHTTPClient(l *funcLine) bool {
	return isNetHTTPClient(l) || isHTTPHelper(l)
}

// twirpFilename matches source file names that end in ".twirp.go". It is
// compiled once here rather than on every isTwirpClient call.
var twirpFilename = regexp.MustCompile(`^.*` + regexp.QuoteMeta(`.twirp.go`) + `$`)

// isTwirpClient reports whether l is client code in a *.twirp.go file.
//
// The function's file name must end in ".twirp.go", and its name must be one
// of: a newRequest function, a do...Request function, or a method on a pointer
// receiver whose type name ends in "Client".
func isTwirpClient(l *funcLine) bool {
	if !twirpFilename.MatchString(l.line.Function.Filename) {
		return false
	}

	return funcNameMatch(l,
		`.*`+`\.newRequest`,
		`.*`+`\.do[^/\.]*Request`,
		`.*`+`\.\(\*[^/\.\)]*Client\)\.[^/\.\(\)]*`,
	)
}

// isAlvinClient reports whether l is a method of the alvin restclient
// roundTripper or of the usher alvinUsherClient.
func isAlvinClient(l *funcLine) bool {
	patterns := []string{
		regexp.QuoteMeta(`code.justin.tv/common/alvin/restclient.(*roundTripper).`) + `.*`,
		regexp.QuoteMeta(`code.justin.tv/video/usherapi/rpc/usher.(*alvinUsherClient).`) + `.*`,
	}
	return funcNameMatch(l, patterns...)
}

// isNetHTTPClient reports whether l is one of the net/http functions that
// build or send a client request: NewRequest, the Get/Head/Post/PostForm
// helpers (package-level or on *Client), and the send/do/Do functions listed
// below.
func isNetHTTPClient(l *funcLine) bool {
	patterns := []string{
		regexp.QuoteMeta(`net/http.NewRequest`),
		regexp.QuoteMeta(`net/http.`) + `(\(\*Client\)\.|)(Get|Head|Post|PostForm)`,

		// Alvin-related net/http functions:
		regexp.QuoteMeta(`net/http.send`),
		regexp.QuoteMeta(`net/http.(*Client).send`),
		regexp.QuoteMeta(`net/http.(*Client).do`),
		regexp.QuoteMeta(`net/http.(*Client).Do`),
	}
	return funcNameMatch(l, patterns...)
}

// isHTTPHelper reports whether l is a NewRequest function whose name starts
// with the code.justin.tv/foundation/twitchclient import path, with no further
// "/" between that prefix and ".NewRequest".
func isHTTPHelper(l *funcLine) bool {
	pattern := regexp.QuoteMeta(`code.justin.tv/foundation/twitchclient`) +
		`[^/]*` +
		regexp.QuoteMeta(`.NewRequest`)
	return funcNameMatch(l, pattern)
}

// isAWSClient reports whether l belongs to the parts of aws-sdk-go listed
// below (request, client, the service packages, s3manager, and the EC2
// metadata GetMetadata method), or to the mwsclient method that follows the
// same style.
func isAWSClient(l *funcLine) bool {
	patterns := []string{
		regexp.QuoteMeta(`github.com/aws/aws-sdk-go/aws/request`) + `[^/]*`,
		regexp.QuoteMeta(`github.com/aws/aws-sdk-go/aws/client`) + `[^/]*`,
		regexp.QuoteMeta(`github.com/aws/aws-sdk-go/service/`) + `[^/]*`,
		regexp.QuoteMeta(`github.com/aws/aws-sdk-go/service/s3/s3manager`) + `[^/]*`,
		regexp.QuoteMeta(`github.com/aws/aws-sdk-go/aws/ec2metadata.(*EC2Metadata).GetMetadata`),

		// Handwritten in the style of aws-sdk-go
		regexp.QuoteMeta(`code.justin.tv/video/mwsclient.(*MWS).PutMetricDataForAggregation`),
	}

	// For now, ignore places where the SDK calls endpoints on the way to
	// calling other endpoints:
	//
	//     github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds.
	return funcNameMatch(l, patterns...)
}

// funcNameRegexps caches the compiled form of each pattern passed to
// funcNameMatch, keyed by the pattern text as the caller supplied it (before
// anchoring). The map is not protected by a lock, so funcNameMatch must not be
// called from multiple goroutines at once.
var funcNameRegexps = make(map[string]*regexp.Regexp)

// funcNameMatch reports whether the name of l's function matches any of the
// given regular expression fragments.
//
// Each fragment is anchored to the whole name and allowed an optional
// ".../vendor/" prefix, so a pattern written for an import path also matches
// a vendored copy of that package. Fragments are compiled the first time they
// are seen and reused afterwards; an invalid fragment panics.
func funcNameMatch(l *funcLine, regexps ...string) bool {
	for _, pattern := range regexps {
		re, ok := funcNameRegexps[pattern]
		if !ok {
			re = regexp.MustCompile(`^(|.*/vendor/)` + pattern + `$`)
			funcNameRegexps[pattern] = re
		}
		if re.MatchString(l.line.Function.Name) {
			return true
		}
	}

	return false
}

// funcLine is one line entry of a profile location together with the location
// it was taken from. It is the unit the is...Client predicates classify.
type funcLine struct {
	// location is the profile location the line was read from.
	location *profile.Location

	// line is the entry itself; the predicates look at its Function's Name
	// and Filename.
	line profile.Line
}
