package main

import (
	"context"
	"flag"
	"fmt"
	"io"
	"log"
	"os"
	"sort"
	"sync"

	"code.justin.tv/release/trace/api"
	"code.justin.tv/release/trace/scanproto"
	"github.com/golang/protobuf/proto"
)

// main reads TransactionSet records from -input and logs summary statistics:
// the number of transactions and calls matching the optional -service filter,
// and the number of distinct services and processes (host/pid) seen among
// them. With -list-services and/or -list-processes it also logs a sorted
// per-service / per-process count. With -print-times it writes one
// "start end" line (Unix nanoseconds) to stdout for each matching call whose
// start and end times are both known.
func main() {
	inFile := flag.String("input", "/dev/null", "Name of input file containing TransactionSet")
	limit := flag.Int64("limit", 0, "Limit number of transactions processed (0 to disable)")
	printSvcs := flag.Bool("list-services", false, "Print list of services")
	printProcs := flag.Bool("list-processes", false, "Print list of processes")
	printTimes := flag.Bool("print-times", false, "Print call start and end times")
	filterSvc := flag.String("service", "", "Report on only the specified service")
	flag.Parse()

	f, err := os.Open(*inFile)
	if err != nil {
		log.Fatalf("open: %v", err)
	}
	defer f.Close()

	// Cancelling ctx tells the unmarshal pipeline to stop early (used by -limit).
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	txs := unmarshal(ctx, f)

	// By default everything is included. With -service, a transaction's root
	// call counts when the filtered service is the client, and each call in the
	// tree counts when the filtered service is the callee.
	includeCaller := func(caller *api.Service, call *api.Call) bool { return true }
	includeCallee := func(caller *api.Service, call *api.Call) bool { return true }
	if *filterSvc != "" {
		includeCaller = func(caller *api.Service, call *api.Call) bool {
			return caller.GetName() == *filterSvc
		}
		includeCallee = func(caller *api.Service, call *api.Call) bool {
			return call.GetSvc().GetName() == *filterSvc
		}
	}

	var (
		limitCount int64
		txCount    int64
		callCount  int64
		svcCount   = make(map[string]int64)
		procCount  = make(map[string]int64)
	)

	// tally records one sighting of svc, keyed both by service name and by
	// "host/pid". Services without a name are not counted.
	tally := func(svc *api.Service) {
		if svc.GetName() == "" {
			return
		}
		svcCount[svc.GetName()]++
		procCount[fmt.Sprintf("%s/%d", svc.GetHost(), svc.GetPid())]++
	}

	for tx := range txs {
		limitCount++
		if l := *limit; l > 0 && limitCount > l {
			cancel()
			break
		}

		root := tx.GetRoot()
		var includeTx bool
		if includeCaller(tx.GetClient(), root) {
			includeTx = true
			tally(tx.GetClient())
			if *printTimes {
				start, end := api.ClientStart(root), api.ClientEnd(root)
				if !start.IsZero() && !end.IsZero() {
					fmt.Printf("%d %d\n", start.UnixNano(), end.UnixNano())
				}
			}
		}
		eachCall(tx.GetClient(), root, func(caller *api.Service, call *api.Call) {
			if !includeCallee(caller, call) {
				return
			}
			includeTx = true
			callCount++
			tally(call.GetSvc())
			if *printTimes {
				// Report the service-side span of *this* call, not of the root.
				start, end := api.ServiceStart(call), api.ServiceEnd(call)
				if !start.IsZero() && !end.IsZero() {
					fmt.Printf("%d %d\n", start.UnixNano(), end.UnixNano())
				}
			}
		})
		if includeTx {
			txCount++
		}
	}

	log.Printf("transactions=%d", txCount)
	log.Printf("calls=%d", callCount)
	log.Printf("services=%d", len(svcCount))
	log.Printf("processes=%d", len(procCount))
	if *printSvcs {
		for _, svc := range sortedKeys(svcCount) {
			log.Printf("service=%q count=%d", svc, svcCount[svc])
		}
	}
	if *printProcs {
		for _, proc := range sortedKeys(procCount) {
			log.Printf("process=%q count=%d", proc, procCount[proc])
		}
	}
}

// sortedKeys returns the keys of m in ascending order, so that reports are
// deterministic despite Go's randomized map iteration order.
func sortedKeys(m map[string]int64) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}

// eachCall visits the call tree rooted at call in depth-first pre-order and
// invokes fn once for every non-nil call. The caller passed to fn is the
// service that issued the call: the supplied caller for the root, and the
// parent call's Svc for every subcall. Nil calls (and therefore their
// subtrees) are skipped.
func eachCall(caller *api.Service, call *api.Call, fn func(caller *api.Service, call *api.Call)) {
	type pendingCall struct {
		from *api.Service
		node *api.Call
	}

	stack := []pendingCall{{from: caller, node: call}}
	for len(stack) > 0 {
		next := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		if next.node == nil {
			continue
		}

		fn(next.from, next.node)

		parent := next.node.GetSvc()
		children := next.node.GetSubcalls()
		// Push in reverse so children pop (and are visited) in original order.
		for i := len(children) - 1; i >= 0; i-- {
			stack = append(stack, pendingCall{from: parent, node: children[i]})
		}
	}
}

// unmarshal reads TransactionSet records from r and returns a channel that
// yields every Transaction they contain. Decoding is spread across
// unmarshalWorkers goroutines, so transactions are not delivered in input
// order. The returned channel is closed after all workers have stopped, which
// happens once the input is exhausted or ctx is cancelled.
//
// A scan error on r is fatal (log.Fatalf). A record that fails to decode as a
// TransactionSet is logged and skipped.
func unmarshal(ctx context.Context, r io.Reader) chan *api.Transaction {
	const unmarshalWorkers = 16

	sets := make(chan []byte, 100)
	txs := make(chan *api.Transaction, 100)

	var wg sync.WaitGroup
	for i := 0; i < unmarshalWorkers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for {
				// Wait for the next record, but give up promptly on cancellation
				// instead of blocking until the producer finishes.
				var buf []byte
				select {
				case <-ctx.Done():
					return
				case b, ok := <-sets:
					if !ok {
						return
					}
					buf = b
				}

				var set api.TransactionSet
				if err := proto.Unmarshal(buf, &set); err != nil {
					log.Printf("unmarshal err=%q", err)
					continue
				}
				for _, tx := range set.Transaction {
					select {
					case txs <- tx:
					case <-ctx.Done():
						return
					}
				}
			}
		}()
	}

	// Close txs only once every worker has returned, so a consumer ranging
	// over it terminates.
	go func() {
		wg.Wait()
		close(txs)
	}()

	go func() {
		// Close sets on every exit path, including cancellation, so workers
		// waiting on it are released and txs can eventually be closed.
		defer close(sets)

		// This scanner is nominally for EventSet inputs, but it will work
		// correctly for TransactionSet inputs as well since the two message types
		// share structure.
		sc := scanproto.NewEventSetScanner(r)
		sc.Buffer(nil, 1<<20)

		for sc.Scan() {
			// Copy the bytes: like bufio.Scanner, the slice returned by Bytes is
			// presumably only valid until the next Scan.
			select {
			case sets <- append([]byte(nil), sc.Bytes()...):
			case <-ctx.Done():
				return
			}
		}

		if err := sc.Err(); err != nil {
			log.Fatalf("scan err=%q", err)
		}
	}()

	return txs
}
