package main

import (
	"bytes"
	"fmt"
	"html/template"
	"log"
	"math"
	"net/http"
	"net/url"
	"path"
	"sort"
	"strconv"
	"strings"
	"time"

	"code.justin.tv/common/chitin"
	"code.justin.tv/release/trace/api"
	"code.justin.tv/rhys/nursery/pick"
	"github.com/ajstarks/svgo"
	"github.com/golang/protobuf/proto"
	"github.com/kr/logfmt"
	"github.com/syndtr/goleveldb/leveldb"
	"github.com/syndtr/goleveldb/leveldb/iterator"
	"github.com/syndtr/goleveldb/leveldb/util"
	"golang.org/x/net/context"
)

// levelHandler serves HTML pages describing traced transactions and services,
// read from a LevelDB database. Requests are expected under uriPrefix.
type levelHandler struct {
	db        *leveldb.DB
	uriPrefix string
}

// levelRequest carries the parameters ServeHTTP extracts from the request
// path and query for the individual serve* methods.
type levelRequest struct {
	// For /tx/
	txid string

	// For /svc/
	svcname string
	rpc     string // value of the "rpc" query parameter; "" when absent
}

// levelResponse is what a serve* method hands back to ServeHTTP for writing.
// ServeHTTP substitutes "text/plain; charset=utf-8" when contentType is empty.
type levelResponse struct {
	httpStatus  int
	contentType string
	body        string
}

// ServeHTTP routes requests below h.uriPrefix:
//
//	/tx/<txid>      one transaction (GET only)
//	/svc/<svcname>  sampled calls for one service; optional ?rpc= filter (GET only)
//	/               index of known services
//
// Any other path under the prefix is handed to serveMissing. Requests whose
// path does not start with h.uriPrefix get 404.
func (h *levelHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx := chitin.Context(w, r)

	if h.db == nil {
		http.Error(w, "no db available", http.StatusInternalServerError)
		return
	}
	// Test the prefix directly. Comparing lengths after path.Join is not a
	// reliable check: a uriPrefix of "/" made every request look unprefixed,
	// and an unprefixed path with a trailing slash looked prefixed.
	if !strings.HasPrefix(r.URL.Path, h.uriPrefix) {
		http.Error(w, "incorrect prefix", http.StatusNotFound)
		return
	}
	uri := path.Join("/", strings.TrimPrefix(r.URL.Path, h.uriPrefix))

	var (
		resp *levelResponse
		err  error
	)
	switch {
	case strings.HasPrefix(uri, "/tx/"):
		if r.Method != "GET" {
			http.Error(w, "invalid method", http.StatusMethodNotAllowed)
			return
		}
		txid := strings.TrimPrefix(uri, "/tx/")
		if !validTxid(txid) {
			http.Error(w, "invalid txid", http.StatusBadRequest)
			return
		}
		resp, err = h.serveTX(ctx, &levelRequest{
			txid: txid,
		})
	case strings.HasPrefix(uri, "/svc/"):
		if r.Method != "GET" {
			http.Error(w, "invalid method", http.StatusMethodNotAllowed)
			return
		}
		svcname := strings.TrimPrefix(uri, "/svc/")
		rpc := r.URL.Query().Get("rpc")
		req := &levelRequest{
			svcname: svcname,
			rpc:     rpc,
		}
		resp, err = h.serveSvc(ctx, req)
	case uri == "/":
		resp, err = h.serveIndex(ctx, &levelRequest{})
	default:
		resp, err = h.serveMissing(ctx, &levelRequest{})
	}

	if err != nil {
		// Record the cause; the client only sees a generic 500.
		log.Printf("serve uri=%q err=%q", uri, err)
		http.Error(w, "oops", http.StatusInternalServerError)
		return
	}
	ct := resp.contentType
	if ct == "" {
		ct = "text/plain; charset=utf-8"
	}
	status := resp.httpStatus
	if status == 0 {
		// WriteHeader accepts only valid HTTP status codes (it panics on 0), so
		// treat an unset status as success.
		status = http.StatusOK
	}
	w.Header().Set("Content-Type", ct)
	w.WriteHeader(status)
	fmt.Fprint(w, resp.body)
}

// serveTX renders one transaction as an HTML page: an SVG waterfall of its
// calls (txSvg) followed by the transaction in protobuf text format. The
// records for a transaction are stored under keys beginning txid="<txid>",
// and the first record that decodes is the one shown. If no record decodes,
// the result is a 404. Iterator and template failures are returned as errors.
func (h *levelHandler) serveTX(ctx context.Context, r *levelRequest) (*levelResponse, error) {
	prefix := fmt.Sprintf("txid=%q", r.txid)
	it := h.db.NewIterator(util.BytesPrefix([]byte(prefix)), nil)
	defer it.Release()

	// Only one record is rendered, so stop at the first one that decodes
	// instead of walking every remaining record for this txid.
	var (
		tx    api.Transaction
		found bool
	)
	for it.Next() {
		if err := proto.Unmarshal(it.Value(), &tx); err == nil {
			found = true
			break
		}
	}
	// A read failure is a server error, not evidence that the txid is absent.
	if err := it.Error(); err != nil {
		return nil, err
	}

	if !found {
		return &levelResponse{
			httpStatus: http.StatusNotFound,
			body:       "transaction not found",
		}, nil
	}

	var buf bytes.Buffer

	tmpl := template.Must(template.New("").Parse(`<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="google" content="notranslate" />
<title>{{.Title}}</title>
</head>
<body>

{{.Svg}}
<br />

<pre>
{{.ProtoString}}
</pre>

</body>
</html>
`))
	err := tmpl.Execute(&buf, struct {
		Title     string
		URIPrefix string

		Svg template.HTML

		ProtoString string
	}{
		Title:     fmt.Sprintf("txid %s", r.txid),
		URIPrefix: h.uriPrefix,

		Svg: txSvg(ctx, &tx),

		ProtoString: proto.MarshalTextString(&tx),
	})
	if err != nil {
		log.Printf("Execute err=%q", err)
		return nil, err
	}

	return &levelResponse{
		httpStatus:  http.StatusOK,
		contentType: "text/html; charset=utf-8",
		body:        buf.String(),
	}, nil
}

// serveSvc renders the page for one service. It samples up to 10000 index
// entries for r.svcname, loads the calls those entries point at, and groups
// them with a copy of byServerDuration: one row per (service, rpc name), with
// calls bucketed by service duration. When r.rpc is set, only calls whose
// rpcName equals r.rpc are kept. Rows are then keyed by the de-duplicated,
// sorted set of (service, rpc, sent-to) destinations each call made subcalls to.
func (h *levelHandler) serveSvc(ctx context.Context, r *levelRequest) (*levelResponse, error) {
	prefix := fmt.Sprintf("index svcname=%q", r.svcname)
	idx := h.db.NewIterator(util.BytesPrefix([]byte(prefix)), nil)
	// The deferred call covers panics. goleveldb iterators may be released
	// more than once.
	defer idx.Release()

	spans, err := iterSpans(ctx, idx, 10000)
	// The index iterator is not needed for the full-database scan below, so
	// release it now instead of holding both open at once.
	idx.Release()
	if err != nil {
		return nil, err
	}

	// Sorted spans make iterCalls seek through the keyspace in ascending order.
	sort.Strings(spans)
	it := h.db.NewIterator(nil, nil)
	defer it.Release()

	calls, err := iterCalls(ctx, it, spans)
	if err != nil {
		return nil, err
	}

	// Work on a copy so the per-request nameFn override below does not modify
	// the shared byServerDuration.
	split := new(splitter)
	*split = *byServerDuration

	if r.rpc != "" {
		var matchCalls []*call
		for _, c := range calls {
			if rpcName(c.Call) == r.rpc {
				matchCalls = append(matchCalls, c)
			}
		}
		calls = matchCalls
		split.nameFn = func(c *api.Call) interface{} {
			deps := make(map[string]struct{})
			for _, sub := range c.Subcalls {
				if sub == nil {
					// destOf would dereference a nil call.
					continue
				}
				dest := destOf(sub)
				// dest.sentTo is blank whenever dest.svcname is set.
				deps[fmt.Sprintf("%q %q %q", dest.svcname, rpcName(sub), dest.sentTo)] = struct{}{}
			}
			deplist := make([]string, 0, len(deps))
			for k := range deps {
				deplist = append(deplist, k)
			}
			sort.Strings(deplist)
			return strings.Join(deplist, ", ")
		}
	}

	var buf bytes.Buffer

	tmpl := template.Must(template.New("").Funcs(map[string]interface{}{
		// DetailURI links a (service, rpc) row to its rpc-filtered page. Rows
		// keyed any other way get no detail link.
		"DetailURI": func(rw *row) string {
			name, ok := rw.Name.([2]string)
			if !ok {
				return ""
			}

			return (&url.URL{
				Path: name[0],
				RawQuery: url.Values{
					"rpc": []string{name[1]},
				}.Encode(),
			}).String()
		},
	}).Parse(`<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="google" content="notranslate" />
<title>{{.Title}}</title>
</head>
<body>

{{range .Rows}}
	<br />
	<a href="{{$.URIPrefix}}/svc/{{DetailURI .}}">{{printf "Name=%q Count=%d TotalCost=%d" .Name .Count .TotalCost}}</a><br />
	{{.Svg}}<br />
	{{range .Buckets}}
		Cost {{.MinCost}}–{{.MaxCost}} — Count {{.Count}}
		{{range $i, $ex := .Picker.Picked}}
			{{if lt $i 4}}
				<a href="{{$.URIPrefix}}/tx/{{$ex.Txid}}?path={{$ex.Path}}">{{printf "ex%d" $i}}</a>
			{{end}}
		{{end}}
		<br />
	{{end}}
{{end}}

</body>
</html>
`))
	err = tmpl.Execute(&buf, struct {
		Title     string
		URIPrefix string

		Rows []*row
	}{
		Title:     fmt.Sprintf("Trace—%s", r.svcname),
		URIPrefix: h.uriPrefix,

		Rows: split.slice(ctx, calls),
	})
	if err != nil {
		log.Printf("Execute err=%q", err)
		return nil, err
	}

	return &levelResponse{
		httpStatus:  http.StatusOK,
		contentType: "text/html; charset=utf-8",
		body:        buf.String(),
	}, nil
}

// serveIndex lists every service name found in the `index svcname="..."` key
// range, sorted, with each name linked to its /svc/ page. Empty names are
// emitted only as an HTML comment. An iterator failure is returned as an
// error rather than rendered as a partial list.
func (h *levelHandler) serveIndex(ctx context.Context, r *levelRequest) (*levelResponse, error) {
	it := h.db.NewIterator(util.BytesPrefix([]byte("index svcname=\"")), nil)
	defer it.Release()

	// Walk the index forward and visit one key per service. After recording
	// a service, seek to the first key past its `index svcname="<name>"`
	// prefix. Starting at First means the first key is examined too, so a
	// service with a single index entry is still listed.
	var svcs []string
	for ok := it.First(); ok && ctx.Err() == nil; {
		var key indexKey
		if err := key.parse(it.Key()); err != nil {
			ok = it.Next()
			continue
		}
		if n := len(svcs); n == 0 || svcs[n-1] != key.Svcname {
			svcs = append(svcs, key.Svcname)
		}
		next := util.BytesPrefix([]byte(fmt.Sprintf("index svcname=%q", key.Svcname))).Limit
		// Jump only when the jump moves forward. If re-quoting the parsed name
		// does not reproduce the stored bytes, step once instead, so the loop
		// always makes progress and cannot spin forever.
		if bytes.Compare(next, it.Key()) > 0 {
			ok = it.Seek(next)
		} else {
			ok = it.Next()
		}
	}
	if err := it.Error(); err != nil {
		return nil, err
	}
	sort.Strings(svcs)

	var buf bytes.Buffer

	tmpl := template.Must(template.New("").Parse(`<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="google" content="notranslate" />
<title>{{.Title}}</title>
</head>
<body>

{{range .Svcnames}}
	{{if ne . ""}}
		<a href="{{$.URIPrefix}}/svc/{{.}}">{{.}}</a><br />
	{{else}}
		<!-- <a href="{{$.URIPrefix}}/svc/">unspecified</a><br /> -->
	{{end}}
{{end}}

</body>
</html>
`))
	err := tmpl.Execute(&buf, struct {
		Title     string
		URIPrefix string

		Svcnames []string
	}{
		Title:     "Trace index",
		URIPrefix: h.uriPrefix,

		Svcnames: svcs,
	})
	if err != nil {
		log.Printf("Execute err=%q", err)
		return nil, err
	}

	return &levelResponse{
		httpStatus:  http.StatusOK,
		contentType: "text/html; charset=utf-8",
		body:        buf.String(),
	}, nil
}

// serveMissing answers paths that match no route. It sets an explicit 404
// rather than returning a zero-valued response, because a status code of 0
// is not a valid HTTP status and http.ResponseWriter.WriteHeader rejects it.
func (h *levelHandler) serveMissing(ctx context.Context, r *levelRequest) (*levelResponse, error) {
	return &levelResponse{
		httpStatus: http.StatusNotFound,
		body:       "not found",
	}, nil
}

// txSvg draws a waterfall of every call in tx as inline SVG, one row per call
// in path order. For each call, a black line spans its client timestamps and
// a translucent green line spans its service timestamps. Both are scaled into
// the leftmost waterfallWidth pixels and followed by a text summary. The
// canvas is at least minHeight tall and grows so the last row is not clipped.
func txSvg(ctx context.Context, tx *api.Transaction) template.HTML {
	const (
		width          = 1500
		minHeight      = 1000
		waterfallWidth = 300
		textX          = 306

		// Row i is drawn at y = rowTop + rowStride*i.
		rowTop    = 20
		rowStride = 10
	)

	// Gather every non-nil call in the tree breadth-first, then order the
	// calls by path so each parent precedes its children.
	var calls []*api.Call
	for grey := []*api.Call{tx.GetRoot()}; len(grey) > 0; {
		l := len(grey)
		for i := 0; i < l; i++ {
			c := grey[i]
			if c == nil {
				continue
			}
			calls = append(calls, c)
			grey = append(grey, c.Subcalls...)
		}
		grey = grey[l:]
	}
	sort.Sort(byPath(calls))

	// A fixed height clipped everything past roughly the 98th row. Leave one
	// spare row of room below the last one.
	height := minHeight
	if need := rowTop + rowStride*(len(calls)+1); need > height {
		height = need
	}

	var buf bytes.Buffer
	g := svg.New(&buf)
	g.Start(width, height)

	// Overall time range across all nonzero timestamps, used for x scaling.
	var start, end int64
	for _, c := range calls {
		for _, ts := range [][]*api.Timestamp{c.ClientTimestamps, c.ServiceTimestamps} {
			for _, t := range ts {
				if t.Time == 0 {
					continue
				}
				if t.Time < start || start == 0 {
					start = t.Time
				}
				if t.Time > end || end == 0 {
					end = t.Time
				}
			}
		}
	}
	if end == start {
		end = start + 1
	}

	xval := func(t int64) int {
		p := float64(t-start) / float64(end-start)
		return int(p * float64(waterfallWidth))
	}

	for i, c := range calls {
		y := rowTop + rowStride*i
		if i%2 == 1 {
			x1, x2 := 0, width
			g.Line(x1, y, x2, y, "stroke:black; stroke-opacity:0.05; stroke-width:10")
		}
		// TODO: confirm that the timestamps cover a start event and an end
		// event.
		if ts := c.ClientTimestamps; len(ts) >= 1 {
			x1 := xval(ts[0].Time)
			x2 := xval(ts[len(ts)-1].Time)
			if x2 == x1 {
				x2 = x1 + 1
			}
			g.Line(x1, y, x2, y, "stroke:black; stroke-width:2")
		}
		if ts := c.ServiceTimestamps; len(ts) >= 1 {
			x1 := xval(ts[0].Time)
			x2 := xval(ts[len(ts)-1].Time)
			if x2 == x1 {
				x2 = x1 + 1
			}
			g.Line(x1, y, x2, y, "stroke:green; stroke-opacity:0.5; stroke-width:6")
		}

		// Earliest nonzero timestamp of this call, shown as its offset.
		var t0 int64
		for _, ts := range [][]*api.Timestamp{c.ClientTimestamps, c.ServiceTimestamps} {
			for _, t := range ts {
				if t.Time == 0 {
					continue
				}
				if t.Time < t0 || t0 == 0 {
					t0 = t.Time
				}
			}
		}
		cd, sd := api.ClientDuration(c), api.ServiceDuration(c)
		at := time.Duration(t0-start) * time.Nanosecond
		if at < 0 {
			at = 0
		}

		desc := rpcSummary(c)

		d := destOf(c)
		g.Text(textX, y+2,
			fmt.Sprintf("%s @ %0.3fms — c %0.3fms / s %0.3fms — %s%s — %s",
				callpath(c.Path),
				float64(at/time.Microsecond)/1000,
				float64(cd/time.Microsecond)/1000,
				float64(sd/time.Microsecond)/1000,
				d.svcname, d.sentTo,
				desc),
			"font-size:8; fill:black")
	}

	g.End()
	return template.HTML(buf.String())
}

// rpcSummary describes c's request parameters in one human-readable line.
// It covers the HTTP, gRPC, SQL and Memcached variants, checked in that
// order, and returns "" when none is present.
func rpcSummary(c *api.Call) string {
	params := c.GetParams()

	if h := params.GetHttp(); h != nil {
		return fmt.Sprintf("HTTP %s (%d): %s", h.Method, h.Status, h.Route)
	}
	if g := params.GetGrpc(); g != nil {
		return fmt.Sprintf("gRPC: %s", g.Method)
	}
	if q := params.GetSql(); q != nil {
		return fmt.Sprintf("SQL %s@%s: %s", q.Dbuser, q.Dbname, q.StrippedQuery)
	}
	if m := params.GetMemcached(); m != nil {
		return fmt.Sprintf("Memcached: %s %d/%d", m.Command, m.NKeysResponse, m.NKeysRequest)
	}

	return ""
}

// destination names where a call was sent: the service name when the call
// records one, and otherwise the raw RequestSentTo value. At most one of the
// two fields is non-empty.
type destination struct {
	svcname string
	sentTo  string
}

// destOf returns the destination of c, preferring the service name and
// falling back to c.RequestSentTo when the name is missing or empty.
func destOf(c *api.Call) destination {
	name := ""
	if svc := c.GetSvc(); svc != nil {
		name = svc.Name
	}
	if name != "" {
		return destination{svcname: name}
	}
	return destination{sentTo: c.RequestSentTo}
}

//

// call is one sampled call together with where it was found. iterCalls
// fills in the transaction ID and the call's path within that transaction.
type call struct {
	Txid string
	Path callpath
	Call *api.Call
}

// callpath is the numeric path of a call within a transaction, in the same
// form as api.Call.Path.
type callpath []uint32

// String encodes p as each element written in decimal and preceded by a dot,
// e.g. ".1.2.3". An empty path encodes as "".
func (p callpath) String() string {
	var b strings.Builder
	for _, seg := range p {
		fmt.Fprintf(&b, ".%d", seg)
	}
	return b.String()
}

// splitter describes how splitCalls groups calls. nameFn keys the row a call
// belongs to, and costFn measures the call. bucketFn maps a cost to the
// half-open range [min, max) of its bucket. pickCount is passed to pick.New
// for each bucket's example picker (presumably the number of examples kept;
// confirm against the pick package).
type splitter struct {
	nameFn   func(*api.Call) interface{}
	costFn   func(*api.Call) int64
	bucketFn func(int64) (min, max int64)

	pickCount int
}

// name returns the row key for c, or nil when no nameFn is set. The key is
// compared with itself so that a key with an uncomparable dynamic type
// panics here, where it is produced, rather than later when used as a map key.
func (s *splitter) name(c *api.Call) interface{} {
	fn := s.nameFn
	if fn == nil {
		return nil
	}
	key := fn(c)
	_ = key == key
	return key
}

// cost returns costFn(c), or 0 when no costFn is set.
func (s *splitter) cost(c *api.Call) int64 {
	if fn := s.costFn; fn != nil {
		return fn(c)
	}
	return 0
}

// splitCalls drains calls and groups them for display. Each call goes into
// the row keyed by s.name, and within that row into the bucket
// [MinCost, MaxCost) that contains s.cost. A missing bucket is created from
// s.bucketFn. If the range bucketFn returns does not contain the cost, the
// call is skipped. Draining stops when calls is closed or ctx is done. Rows
// are returned in descending TotalCost order.
func (s *splitter) splitCalls(ctx context.Context, calls <-chan *call) []*row {
	rows := make(map[interface{}]*row)
	var rs []*row

	// A bare break inside a select only leaves the select, so the receive
	// loop is exited through a labeled break.
recv:
	for ctx.Err() == nil {
		var c *call
		select {
		case <-ctx.Done():
			break recv
		case next, ok := <-calls:
			if !ok {
				break recv
			}
			c = next
		}

		name := s.name(c.Call)
		cost := s.cost(c.Call)

		r, ok := rows[name]
		if !ok {
			r = &row{Name: name}
			rows[name] = r
			rs = append(rs, r)
		}

		// Buckets are kept sorted by MinCost. Take the last one starting at
		// or below cost, and use it only if cost is also below its MaxCost.
		var bk *bucket
		n := sort.Search(len(r.Buckets), func(i int) bool {
			return cost < r.Buckets[i].MinCost
		}) - 1
		if n >= 0 && cost < r.Buckets[n].MaxCost {
			bk = r.Buckets[n]
		}
		if bk == nil {
			lo, hi := s.bucketFn(cost)
			if cost < lo || cost >= hi {
				continue
			}
			bk = &bucket{
				Picker:  pick.New(s.pickCount, nil),
				MinCost: lo,
				MaxCost: hi,
			}
			r.Buckets = append(r.Buckets, bk)
			sort.Sort(byMinCost(r.Buckets))
		}

		// Prep work is done, record the call.
		r.Count++
		r.TotalCost += cost

		bk.Count++
		bk.Picker.Add(c)
	}

	sort.Sort(sort.Reverse(byTotalCost(rs)))
	return rs
}

// row aggregates the calls that share one splitter name: how many there are,
// their summed cost, and how they are distributed across cost buckets.
// splitCalls keeps Buckets sorted by MinCost.
type row struct {
	Name      interface{}
	Count     int
	TotalCost int64
	Buckets   []*bucket
}

// Svg draws the bucket counts of r as a small orange area chart whose width
// grows with the number of buckets.
func (r *row) Svg() template.HTML {
	var counts []int
	for _, b := range r.Buckets {
		// TODO: insert zeros for empty buckets
		// TODO: x axis magnitude indicators
		// TODO: translate along x axis appropriately
		// TODO: let row struct keep track of the bucketFn that created it
		counts = append(counts, b.Count)
	}

	plot := chart{
		tick:   4,
		height: 60,

		yscale: (log10{logmin: -0.2, logmax: 6, perDecade: 10}).scale,
	}

	var out bytes.Buffer
	canvas := svg.New(&out)
	canvas.Start(20+4*len(r.Buckets), 60)
	plot.plotArea(canvas, counts, "fill:orange; fill-opacity:1.0")
	canvas.End()

	return template.HTML(out.String())
}

// byPath sorts calls by Path, comparing element by element. When one path is
// a prefix of the other, the shorter one sorts first, so parents precede
// their children.
type byPath []*api.Call

func (s byPath) Len() int { return len(s) }

func (s byPath) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

func (s byPath) Less(i, j int) bool {
	a, b := s[i].Path, s[j].Path
	n := len(a)
	if len(b) < n {
		n = len(b)
	}
	for k := 0; k < n; k++ {
		if a[k] != b[k] {
			return a[k] < b[k]
		}
	}
	return len(a) < len(b)
}

// byTotalCost orders rows by ascending TotalCost. splitCalls wraps it in
// sort.Reverse to put the most expensive rows first.
type byTotalCost []*row

func (s byTotalCost) Len() int {
	return len(s)
}

func (s byTotalCost) Swap(i, j int) {
	s[i], s[j] = s[j], s[i]
}

func (s byTotalCost) Less(i, j int) bool {
	a, b := s[i], s[j]
	return a.TotalCost < b.TotalCost
}

// bucket is one cost range [MinCost, MaxCost) within a row. Count is the
// number of calls recorded in it, and Picker holds example calls: splitCalls
// adds them with Picker.Add and the svc page lists them via Picker.Picked.
type bucket struct {
	Picker *pick.Stream

	Count   int
	MinCost int64
	MaxCost int64
}

// byMinCost orders buckets by ascending MinCost.
type byMinCost []*bucket

func (s byMinCost) Len() int {
	return len(s)
}

func (s byMinCost) Swap(i, j int) {
	s[i], s[j] = s[j], s[i]
}

func (s byMinCost) Less(i, j int) bool {
	a, b := s[i], s[j]
	return a.MinCost < b.MinCost
}

// rpcName returns a short label for the operation c performed. For HTTP
// this is the route, or "HTTP <method>" when there is no route. For gRPC it
// is the method, and SQL and Memcached get fixed labels. It returns "" when
// c has none of these parameter kinds.
func rpcName(c *api.Call) string {
	params := c.GetParams()

	if h := params.GetHttp(); h != nil {
		if route := h.Route; route != "" {
			return route
		}
		return fmt.Sprintf("HTTP %s", h.Method.String())
	}
	if g := params.GetGrpc(); g != nil {
		return g.Method
	}

	switch {
	case params.GetSql() != nil:
		return "SQL"
	case params.GetMemcached() != nil:
		return "Memcached"
	default:
		return ""
	}
}

// logScaleBucket returns the half-open range [lo, hi) of the logarithmic
// bucket that contains cost, using perDecade buckets per power of ten.
// Bucket k spans [floor(10^(k/perDecade)), floor(10^((k+1)/perDecade))).
//
// The first guess for k, taken from math.Log10, can be off by one near a
// boundary. Truncating a non-integer boundary can also leave cost equal to
// the guessed bucket's hi (e.g. cost 125 against 10^2.1 ≈ 125.9). The guess
// is therefore adjusted until lo <= cost < hi holds, which skips any buckets
// that truncation made empty. math.Pow keeps exact powers of ten exact.
// Non-positive costs, or a non-positive perDecade, give (0, 0).
func logScaleBucket(cost int64, perDecade int) (lo, hi int64) {
	if cost <= 0 || perDecade <= 0 {
		return 0, 0
	}
	bound := func(k int) int64 {
		f := math.Pow(10, float64(k)/float64(perDecade))
		if f >= math.MaxInt64 {
			// Saturate instead of overflowing the int64 conversion.
			return math.MaxInt64
		}
		return int64(f)
	}
	k := int(math.Log10(float64(cost)) * float64(perDecade))
	for k > 0 && cost < bound(k) {
		k--
	}
	for cost >= bound(k+1) && bound(k+1) < math.MaxInt64 {
		k++
	}
	return bound(k), bound(k+1)
}

// byServerDuration groups calls by (service name, rpcName). Calls without a
// service record all share the nil row. Each call is costed by
// api.ServiceDuration in nanoseconds, with ten logarithmic buckets per
// decade.
var byServerDuration = &splitter{
	nameFn: func(c *api.Call) interface{} {
		if c.GetSvc() == nil {
			return nil
		}
		svcname := c.GetSvc().Name
		rpc := rpcName(c)

		return [2]string{svcname, rpc}
	},
	costFn: func(c *api.Call) int64 {
		return api.ServiceDuration(c).Nanoseconds()
	},
	bucketFn: func(cost int64) (min, max int64) {
		return logScaleBucket(cost, 10)
	},

	pickCount: 4,
}

// bySubcallCount groups calls like byServerDuration but costs them by their
// number of direct subcalls. Below 10 each count gets its own bucket, below
// 100 the buckets are 10 wide, and from 100 up there are ten logarithmic
// buckets per decade.
var bySubcallCount = &splitter{
	nameFn: func(c *api.Call) interface{} {
		if c.GetSvc() == nil {
			return nil
		}
		svcname := c.GetSvc().Name
		rpc := rpcName(c)

		return [2]string{svcname, rpc}
	},
	costFn: func(c *api.Call) int64 {
		return int64(len(c.Subcalls))
	},
	bucketFn: func(cost int64) (min, max int64) {
		switch {
		case cost < 0:
			return 0, 0
		case cost < 10:
			return cost, cost + 1
		case cost < 100:
			return (cost / 10) * 10, (cost/10 + 1) * 10
		}
		return logScaleBucket(cost, 10)
	},

	pickCount: 4,
}

// slice runs calls through splitCalls and returns the resulting rows. The
// channel can hold every call and is closed before splitCalls starts, so
// draining it never blocks.
func (s *splitter) slice(ctx context.Context, calls []*call) []*row {
	queue := make(chan *call, len(calls))
	for i := range calls {
		queue <- calls[i]
	}
	close(queue)

	return s.splitCalls(ctx, queue)
}

//

// iterSpans walks it, offering each entry to a pick.Stream created with
// pickCount. For the entries the stream keeps, it returns the span
// (txid followed by the encoded call path) parsed from the entry's index
// key. Keys that fail to parse, or whose txid is invalid, contribute no span.
// Iteration stops early if ctx is done. An iterator error is returned along
// with the spans gathered so far.
func iterSpans(ctx context.Context, it iterator.Iterator, pickCount int) ([]string, error) {
	pickTxid := pick.New(pickCount, nil)

	for it.Next() {
		if ctx.Err() != nil {
			break
		}

		// NOTE(review): fn reads it.Key(), so this assumes AddFn calls fn
		// before the iterator advances. Confirm against the pick package.
		pickTxid.AddFn(func() interface{} {
			var key indexKey
			if err := key.parse(it.Key()); err != nil {
				return nil
			}
			if !validTxid(key.Txid) {
				return nil
			}
			return key.Txid + key.Path
		})
	}

	var spans []string
	for _, elt := range pickTxid.Picked() {
		// fn returns nil for unusable keys. If the stream keeps such a value,
		// skip it; a plain type assertion would panic on it.
		if span, ok := elt.(string); ok {
			spans = append(spans, span)
		}
	}

	if err := it.Error(); err != nil {
		return spans, err
	}
	return spans, nil
}

// indexKey holds the logfmt fields of an index key. Index keys begin
// `index svcname="..."`, and the bare "index" key sets Index.
type indexKey struct {
	Index   bool   `logfmt:"index"`
	Svcname string `logfmt:"svcname"`
	Time    string `logfmt:"time"`
	Txid    string `logfmt:"txid"`
	Path    string `logfmt:"path"`
}

// parse decodes the logfmt-encoded key b into k. This presumably goes through
// HandleLogfmt, which matches kr/logfmt's Handler interface.
func (k *indexKey) parse(b []byte) error {
	if err := logfmt.Unmarshal(b, k); err != nil {
		return err
	}
	return nil
}

// HandleLogfmt stores one key/value pair into k. The "index" key sets Index,
// the string fields are copied from val, and unknown keys are ignored.
func (k *indexKey) HandleLogfmt(key, val []byte) error {
	var dst *string
	switch string(key) {
	case "index":
		k.Index = true
		return nil
	case "svcname":
		dst = &k.Svcname
	case "time":
		dst = &k.Time
	case "txid":
		dst = &k.Txid
	case "path":
		dst = &k.Path
	default:
		return nil
	}
	*dst = string(val)
	return nil
}

// iterCalls resolves each span (a txid followed by an encoded call path, as
// built by iterSpans) to calls in the transaction records stored under keys
// beginning txid="<txid>". Every such record is decoded, and each one that
// contains a call at the span's path contributes an entry, so one span may
// yield several calls. Records that fail to decode, and spans that
// decodeSpan rejects, are skipped. Iteration stops early when ctx is done.
// An iterator error is returned along with the calls found so far.
func iterCalls(ctx context.Context, it iterator.Iterator, spans []string) ([]*call, error) {
	var calls []*call
	for _, span := range spans {
		if ctx.Err() != nil {
			break
		}

		txid, p := decodeSpan(span)
		if txid == "" {
			// decodeSpan rejected the span.
			continue
		}

		key := []byte(fmt.Sprintf("txid=%q", txid))
		// Advancing in the loop's post statement means an undecodable record
		// is stepped over rather than re-read forever.
		for ok := it.Seek(key); ok && bytes.HasPrefix(it.Key(), key); ok = it.Next() {
			var tx api.Transaction
			if err := proto.Unmarshal(it.Value(), &tx); err != nil {
				continue
			}

			if c := findCall(&tx, p); c != nil {
				calls = append(calls, &call{
					Txid: txid,
					Path: p,
					Call: c,
				})
			}
		}
	}

	if err := it.Error(); err != nil {
		return calls, err
	}
	return calls, nil
}

// validTxid reports whether txid is exactly 32 lowercase hexadecimal
// characters ([0-9a-f]).
func validTxid(txid string) bool {
	const hexLen = 32
	if len(txid) != hexLen {
		return false
	}
	for _, b := range []byte(txid) {
		switch {
		case b >= '0' && b <= '9':
		case b >= 'a' && b <= 'f':
		default:
			return false
		}
	}
	return true
}

// decodePath parses an encoded call path such as ".1.2.3", where each segment
// is a decimal uint32 preceded by a dot. It returns nil both for the empty
// string and for malformed input (a missing leading dot, an empty segment, or
// a non-numeric or out-of-range segment). Callers that must tell these apart
// compare against the input length, as decodeSpan does.
func decodePath(path string) callpath {
	if path == "" {
		return nil
	}
	body := strings.TrimPrefix(path, ".")
	if len(body) == len(path) {
		// No leading dot.
		return nil
	}

	var segs callpath
	for _, field := range strings.Split(body, ".") {
		v, err := strconv.ParseUint(field, 10, 32)
		if err != nil {
			return nil
		}
		segs = append(segs, uint32(v))
	}
	return segs
}

// decodeSpan splits a span into its 32-character txid and its decoded call
// path. It returns ("", nil) when the span is shorter than a txid, when the
// txid is invalid, or when a non-empty path suffix fails to decode. A span
// that is only a txid decodes to that txid with a nil path.
func decodeSpan(span string) (string, callpath) {
	const txidLen = 32
	reject := func() (string, callpath) { return "", nil }

	if len(span) < txidLen || !validTxid(span[:txidLen]) {
		return reject()
	}
	txid, encoded := span[:txidLen], span[txidLen:]

	decoded := decodePath(encoded)
	if len(decoded) == 0 && len(encoded) != 0 {
		// decodePath signals malformed input with nil.
		return reject()
	}
	return txid, decoded
}

// findCall returns the call in tx whose Path equals path, or nil if there is
// none. It starts at the root and, at each level, descends into the first
// non-nil subcall whose next path element matches. It stops with nil as soon
// as no subcall matches. Each step moves to a strictly longer Path that is
// bounded by len(path), so the search always terminates.
func findCall(tx *api.Transaction, path callpath) *api.Call {
	cur := tx.GetRoot()
	for cur != nil {
		if len(cur.Path) > len(path) {
			return nil
		}
		for i := range cur.Path {
			if cur.Path[i] != path[i] {
				return nil
			}
		}
		if len(cur.Path) == len(path) {
			return cur
		}

		// Without resetting cur when nothing matches, the loop would revisit
		// the same call forever.
		l := len(cur.Path)
		var next *api.Call
		for _, sub := range cur.Subcalls {
			if sub != nil && len(sub.Path) > l && sub.Path[l] == path[l] {
				next = sub
				break
			}
		}
		cur = next
	}
	return nil
}
