package main

import (
	"flag"
	"fmt"
	"io"
	"log"
	"os"
	"os/signal"
	"sort"
	"syscall"
	"time"

	"code.justin.tv/common/chitin"
	_ "code.justin.tv/common/golibs/bininfo"
	"code.justin.tv/release/trace/api"
	"code.justin.tv/release/trace/clientscripts/internal/analysis/stats"
	"github.com/spenczar/tdigest"
	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
)

// Command-line flags. The query sent to the firehose and the local call
// filter (match) are both built from -svc and -host; -client, -q and
// -histogram only affect how the collected data is reported.
var (
	apiHostPort = flag.String("api", "trace-api.prod.us-west2.justin.tv:11143", "Host and port for the Trace API")
	service     = flag.String("svc", "code.justin.tv/web/web", "The service to analyze")
	host        = flag.String("host", "", "Optionally report a comparison between processes on a single host")
	clientView  = flag.Bool("client", false, "Show client's perspective of call durations")
	sampling    = flag.Float64("sample", 0.1, "Sampling rate")
	duration    = flag.Duration("t", time.Second, "How long to capture data before returning a result")
	quantile    = flag.Float64("q", 0.90, "Timing quantile to report, e.g. 0.90 for p90")
	histogram   = flag.Bool("histogram", false, "Print a histogram of the distribution of inter-request arrival durations")
)

// compression is the tdigest compression parameter passed to every digest
// this program creates (the overall digest and each per-process digest).
const (
	compression = 100 // sane default for tdigest
)

func main() {
	flag.Parse()

	err := chitin.ExperimentalTraceProcessOptIn()
	if err != nil {
		log.Fatalf("trace enable err=%q", err)
	}

	conn, err := grpc.Dial(*apiHostPort, grpc.WithInsecure())
	if err != nil {
		log.Fatalf("dial err=%q", err)
	}
	defer conn.Close()
	client := api.NewTraceClient(conn)
	ctx, cancel := context.WithTimeout(context.Background(), *duration)

	sigch := make(chan os.Signal, 1)
	signal.Notify(sigch, syscall.SIGINT)
	go func() {
		<-sigch
		cancel()
	}()

	var comparison api.Comparison
	comparison.ServiceName = &api.StringComparison{Value: *service}
	if *host != "" {
		comparison.ServiceHost = &api.StringComparison{Value: *host}
	}

	firehose, err := client.Firehose(ctx, &api.FirehoseRequest{
		Sampling: *sampling,
		Query: &api.Query{
			Comparisons: []*api.Comparison{&comparison},
		},
	})
	if err != nil {
		log.Fatalf("firehose err=%q", err)
	}

	result := consumeFirehose(firehose)
	if _, err := result.WriteTo(os.Stdout); err != nil {
		log.Fatalf("dump err=%q", err)
	}
}

// match reports whether call should be counted: it must belong to the -svc
// service and, when -host is set, run on that host. A call that carries no
// service information is accepted.
func match(call *api.Call) bool {
	svc := call.GetSvc()
	if svc == nil {
		return true
	}
	if svc.Name != *service {
		return false
	}
	return *host == "" || svc.Host == *host
}

// consumeFirehose reads transactions from firehose until the stream ends. It
// aggregates every call, at any depth of each transaction's call tree, that
// satisfies match.
//
// Calls are bucketed by host, or by host and pid when -host is set. For each
// matching call it records:
//   - the call's duration (client-side if -client, service-side otherwise,
//     in seconds) in the bucket's digest and in the overall digest;
//   - the call's service start time, under the same bucket.
//
// The stream ending normally is not an error. That covers the server closing
// it (io.EOF), the -t deadline expiring, and the context being canceled
// (main cancels it on SIGINT). Any other stream error is stored, and the
// returned value's WriteTo reports it after writing the report.
func consumeFirehose(firehose api.Trace_FirehoseClient) io.WriterTo {
	d := &data{
		procs:  make(map[key]*datum),
		starts: make(map[key][]time.Time),
		datum: datum{
			digest: tdigest.New(compression),
		},
	}

	// search walks a call tree depth-first, recording every matching call.
	var search func(*api.Call)
	search = func(call *api.Call) {
		if call == nil {
			return
		}
		if match(call) {
			var k key
			if svc := call.GetSvc(); svc != nil {
				k.host = svc.Host
				// Only split a host into processes when a single host is
				// being inspected.
				if *host != "" {
					k.pid = svc.Pid
				}
			}

			dat, ok := d.procs[k]
			if !ok {
				dat = &datum{
					digest: tdigest.New(compression),
				}
				d.procs[k] = dat
			}

			var dur time.Duration
			if *clientView {
				dur = api.ClientDuration(call)
			} else {
				dur = api.ServiceDuration(call)
			}
			secs := dur.Seconds()
			dat.digest.Add(secs, 1)
			dat.count++
			d.datum.digest.Add(secs, 1)
			d.datum.count++

			d.starts[k] = append(d.starts[k], api.ServiceStart(call))
		}
		for _, sub := range call.Subcalls {
			search(sub)
		}
	}

	var err error
	for {
		var tx *api.Transaction
		if tx, err = firehose.Recv(); err != nil {
			break
		}
		search(tx.Root)
	}

	// err is always non-nil here. io.EOF means the server finished the
	// stream cleanly.
	if err != io.EOF {
		switch grpc.Code(err) {
		case codes.DeadlineExceeded, codes.Canceled:
			// Expected: the -t capture window elapsed or the user interrupted.
		default:
			d.err = err
		}
	}
	return d
}

// key identifies one reporting bucket: a host alone (pid == 0), or a single
// process on that host.
type key struct {
	host string
	pid  int32
}

// String renders the key as "host" when no pid is set. Otherwise it renders
// "host/pid", with the pid left-aligned in a five-character field.
func (k key) String() string {
	if k.pid != 0 {
		return fmt.Sprintf("%s/%-5d", k.host, k.pid)
	}
	return k.host
}

type sorter struct {
	Length int
	SwapFn func(i, j int)
	LessFn func(i, j int) bool
}

func (s sorter) Len() int           { return s.Length }
func (s sorter) Swap(i, j int)      { s.SwapFn(i, j) }
func (s sorter) Less(i, j int) bool { return s.LessFn(i, j) }

func MakeInterface(n int, swap func(i, j int), less func(i, j int) bool) sort.Interface {
	return sorter{Length: n, SwapFn: swap, LessFn: less}
}

// data holds everything gathered from the firehose. It implements
// io.WriterTo to print the final report.
type data struct {
	// procs holds one aggregate per bucket: a host, or a host/pid pair
	// when -host is set.
	procs map[key]*datum
	// starts holds the service start time of every recorded call, per
	// bucket, in the order the calls were recorded.
	starts map[key][]time.Time
	// datum aggregates every matching call across all buckets.
	datum datum
	// err is a terminal stream error (if any); WriteTo returns it after
	// writing the report.
	err error
}

// datum is a digest of call durations (added in seconds) together with the
// number of calls added to it.
type datum struct {
	digest tdigest.TDigest
	count  int
}

// WriteTo prints the report to w. It writes a "total" row covering every
// matching call, then one row per bucket (host, or host/pid), sorted by host
// and then pid. Each row shows the -q quantile of call duration in
// milliseconds and the number of calls. When -histogram is set, each bucket's
// row is preceded by a histogram of the gaps between its call start times.
//
// It returns the number of bytes written. A write error is returned
// immediately. Otherwise, WriteTo returns any error recorded while consuming
// the firehose, after the full report has been written. WriteTo does not
// modify d.
func (d *data) WriteTo(w io.Writer) (int64, error) {
	var nn int64

	const total = "total"
	longest := len(total)

	procs := make([]key, 0, len(d.procs))
	for proc := range d.procs {
		procs = append(procs, proc)
		if name := proc.String(); len(name) > longest {
			longest = len(name)
		}
	}
	sort.Sort(MakeInterface(len(procs), func(i, j int) { procs[i], procs[j] = procs[j], procs[i] }, func(i, j int) bool {
		if procs[i].host != procs[j].host {
			return procs[i].host < procs[j].host
		}
		return procs[i].pid < procs[j].pid
	}))

	// writeRow emits one summary line, right-aligning label to the widest
	// label in the report.
	writeRow := func(label string, dat *datum) error {
		var ms float64
		// With no samples the digest has no meaningful quantile, so report
		// 0 instead of converting an undefined value to a Duration.
		if dat.count > 0 {
			dur := time.Duration(float64(time.Second) * dat.digest.Quantile(*quantile))
			ms = float64(dur) / float64(time.Millisecond)
		}
		n, err := fmt.Fprintf(w, "% *s: %10.3fms, %5d calls\n", longest, label, ms, dat.count)
		nn += int64(n)
		return err
	}

	// writeHistogram prints the distribution of gaps (in nanoseconds)
	// between consecutive call start times for proc. When there are more
	// than 10 starts, the earliest and latest 10% are dropped.
	writeHistogram := func(proc key) error {
		// Work on a copy so that printing the report does not reorder
		// d.starts.
		starts := append([]time.Time(nil), d.starts[proc]...)
		// Sort before trimming. Starts are stored in receive order, which is
		// not guaranteed to be chronological. Trimming first could discard
		// interior points instead of the edges of the capture window.
		sort.Sort(MakeInterface(len(starts), func(i, j int) { starts[i], starts[j] = starts[j], starts[i] }, func(i, j int) bool {
			return starts[i].Before(starts[j])
		}))
		if l := len(starts); l > 10 {
			starts = starts[l/10 : l-l/10]
		}

		h := stats.NewHistogram(20)
		for i := 1; i < len(starts); i++ {
			h.Add(int64(starts[i].Sub(starts[i-1])))
		}
		it := h.Iterator()
		for it.NextPositive() {
			if it.Weight > 0 {
				n, err := fmt.Fprintf(w, "weight=%d bounds=[%d,%d)\n", it.Weight, it.LowerBound, it.UpperBound)
				nn += int64(n)
				if err != nil {
					return err
				}
			}
		}
		return nil
	}

	if err := writeRow(total, &d.datum); err != nil {
		return nn, err
	}

	for _, proc := range procs {
		if *histogram {
			if err := writeHistogram(proc); err != nil {
				return nn, err
			}
		}
		if err := writeRow(proc.String(), d.procs[proc]); err != nil {
			return nn, err
		}
	}

	return nn, d.err
}
