package main

import (
	"bytes"
	"crypto/sha256"
	"encoding/binary"
	"encoding/hex"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"path"
	"runtime"
	"sort"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"code.justin.tv/release/trace/analysis/tx"
	"code.justin.tv/release/trace/api"
	"code.justin.tv/release/trace/scanproto"
	"code.justin.tv/rhys/nursery/trace/txstore"
	"github.com/golang/protobuf/proto"
	"github.com/syndtr/goleveldb/leveldb/table"
)

// main reads transactions from the -input file and starts one consumer
// goroutine per CPU. Each consumer hands batches of records to writeDataFn,
// which writes every batch as its own numbered ".ldb" table file under -dbdir.
// An HTTP server is started on the -http address in the background; the
// process exits if that server fails.
func main() {
	infile := flag.String("input", "/dev/stdin", "Input filename")
	hostport := flag.String("http", "127.0.0.1:8080", "HTTP debug host and port")
	dbdir := flag.String("dbdir", "", "Directory for LevelDB SSTables")

	flag.Parse()

	log.SetFlags(log.LUTC | log.Ldate | log.Ltime)

	go func() {
		err := (&http.Server{
			Addr: *hostport,
		}).ListenAndServe()
		if err != nil {
			log.Fatalf("http err=%q", err)
		}
	}()

	f, err := os.Open(*infile)
	if err != nil {
		log.Fatalf("Open err=%q", err)
	}
	defer func(c io.Closer) {
		err := c.Close()
		if err != nil {
			log.Fatalf("Close err=%q", err)
		}
	}(f)

	src := txstore.NewReaderSource(f, runtime.NumCPU())

	defer src.Stop()
	go src.Run()

	// accessed atomically
	var fileCount int64 // 13500395

	th := &thing{
		// writeDataFn writes one batch of records to a fresh table file. It
		// reorders data in place and replaces every record's key with the
		// key plus an 8-byte trailer.
		writeDataFn: func(data records) error {
			const (
				keyDel = 0
				keyVal = 1
			)

			// Sort by the bare key first so that records sharing a key are
			// adjacent; the duplicate detection below only compares each key
			// with its predecessor.
			sort.Sort(data)

			var prev []byte
			var seq uint64
			for i := range data {
				key := data[i].key

				// Equal keys get increasing sequence numbers so that their
				// full keys stay distinct.
				if i > 0 && bytes.Equal(key, prev) {
					seq++
				} else {
					seq = 0
				}
				prev = key

				// Full key = original key followed by a little-endian
				// uint64 holding (seq << 8) | keyVal.
				leveldbKey := make([]byte, len(key)+8)
				copy(leveldbKey, key)
				binary.LittleEndian.PutUint64(leveldbKey[len(key):], (seq<<8)|keyVal)

				data[i].key = leveldbKey
			}

			// Sort again on the full keys: the table writer compares whole
			// keys bytewise, and the trailer can change the relative order
			// (prefix keys, or sequence numbers past 255).
			sort.Sort(data)

			// Placeholder sink; replaced by the real file once it is open.
			var sink io.Writer = ioutil.Discard

			n := atomic.AddInt64(&fileCount, 1) - 1

			file, err := os.OpenFile(path.Join(*dbdir, fmt.Sprintf("%06d.ldb", n)),
				os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
			if err != nil {
				return err
			}
			sink = file

			writeErr := func() error {
				w := table.NewWriter(sink, nil)
				for _, r := range data {
					if err := w.Append(r.key, r.value); err != nil {
						return err
					}
				}
				return w.Close()
			}()

			// Close the file exactly once, whatever happened above.
			closeErr := file.Close()
			if writeErr != nil {
				// The write error is the one worth returning; the close
				// error is only logged so it is not lost.
				if closeErr != nil {
					log.Printf("Close err=%q", closeErr)
				}
				return writeErr
			}
			return closeErr
		},
	}

	var wg sync.WaitGroup
	for i := 0; i < runtime.NumCPU(); i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			th.consume(src)
		}()
	}

	wg.Wait()
}

// thing carries the callback that consume uses to persist batches of records.
type thing struct {
	// writeDataFn receives each batch; it may be nil, in which case batches
	// are dropped.
	writeDataFn func(data records) error
}

// writeData forwards data to writeDataFn and returns its result. With no
// callback installed it does nothing and reports success.
func (th *thing) writeData(data records) error {
	if fn := th.writeDataFn; fn != nil {
		return fn(data)
	}
	return nil
}

// consume drains src until both its transaction channel and its error channel
// are closed. Every transaction becomes one record holding its marshaled
// bytes, keyed by transaction id and SHA-256 of those bytes, plus the index
// records produced by appendCalls. Records are buffered and passed to
// th.writeData in batches; whatever remains is written once the source is
// exhausted. Errors from the source, from marshaling and from writeData are
// logged and do not stop consumption.
func (th *thing) consume(src scanproto.TransactionSource) {
	// lim is the size of each batch handed to writeData. A batch is cut only
	// once more than 2*lim records are buffered.
	const lim = 100000

	txs := src.Transactions()
	errs := src.Errors()

	var recs records
	var i, n int64
	for txs != nil || errs != nil {
		if len(recs) > 2*lim {
			var data records
			// Cap the batch's capacity so the callee cannot append into the
			// records that stay buffered.
			data, recs = recs[:lim:lim], recs[lim:]
			if err := th.writeData(data); err != nil {
				log.Printf("writeData err=%q", err)
			}
		}

		select {
		case txn, ok := <-txs:
			if !ok {
				// A nil channel is never ready, so this case is disabled
				// from now on.
				txs = nil
				break
			}

			i++
			if i > n {
				// Progress logging with doubling intervals: fires at
				// i = 1, 2, 4, 8, ...
				n += n + 1

				name := ""
				if svc := txn.GetRoot().GetSvc(); svc != nil {
					name = svc.Name
				}
				log.Printf("i=%d svc=%q", i, name)
			}

			buf, err := proto.Marshal(txn)
			if err != nil {
				// Skip the transaction, but leave a trace of why.
				log.Printf("Marshal err=%q", err)
				break
			}

			txid := IDForTx(txn)
			// TODO: consider using a cheaper hash function
			hash := sha256.Sum256(buf)

			recs = append(recs, record{
				key: []byte(fmt.Sprintf("txid=%q sha256=%q",
					txid, hex.EncodeToString(hash[:]))),
				value: buf,
			})
			recs = appendCalls(recs, txid, txn.GetRoot())
		case err, ok := <-errs:
			if !ok {
				errs = nil
				break
			}
			log.Printf("consume err=%q", err)
		}
	}

	// Flush the remainder.
	if err := th.writeData(recs); err != nil {
		log.Printf("writeData err=%q", err)
	}
}

// appendCalls walks call depth-first, visiting subcalls before the call
// itself, and appends one value-less index record per call that has a
// timestamp. The record key has the form
//
//	index svcname="…" time="…" txid="…" path="…"
//
// where time is the first service timestamp, or the first client timestamp
// when there is no service timestamp, rendered in UTC with nanosecond
// precision, and path is the call path written as ".a.b.c". Calls without
// any timestamp contribute no record. A nil call returns recs unchanged.
func appendCalls(recs []record, txid tx.TransactionID, call *api.Call) []record {
	if call == nil {
		return recs
	}
	for _, child := range call.Subcalls {
		recs = appendCalls(recs, txid, child)
	}

	const (
		stampLayout = "2006-01-02T15:04:05.000000000Z07:00"

		// withHostPid gates the host/pid key segment, which is currently
		// switched off.
		withHostPid = false
	)

	// TODO: optimize this hot spot
	var stamp string
	switch {
	case len(call.ServiceTimestamps) > 0:
		stamp = time.Unix(0, call.ServiceTimestamps[0].Time).UTC().Format(stampLayout)
	case len(call.ClientTimestamps) > 0:
		stamp = time.Unix(0, call.ClientTimestamps[0].Time).UTC().Format(stampLayout)
	}
	if stamp == "" {
		return recs
	}

	var svcname, host string
	var pid int32
	if svc := call.Svc; svc != nil {
		svcname, host, pid = svc.Name, svc.Host, svc.Pid
	}

	var dotted []byte
	for _, p := range call.Path {
		dotted = append(dotted, fmt.Sprintf(".%d", p)...)
	}

	// TODO: further optimize this hot spot
	key := append([]byte(nil), "index"...)
	key = append(key, " svcname="...)
	key = strconv.AppendQuote(key, svcname)
	if withHostPid {
		key = append(key, " host="...)
		key = strconv.AppendQuote(key, host)
		key = append(key, " pid="...)
		key = strconv.AppendInt(key, int64(pid), 10)
	}
	key = append(key, " time="...)
	key = strconv.AppendQuote(key, stamp)
	key = append(key, " txid="...)
	key = strconv.AppendQuote(key, txid.String())
	key = append(key, " path="...)
	key = strconv.AppendQuote(key, string(dotted))

	return append(recs, record{key: key})
}

// record is a single key/value pair destined for a table file.
type record struct {
	key   []byte
	value []byte
}

// records is a slice of record that implements sort.Interface, ordering
// elements by bytewise comparison of their keys.
type records []record

// Len reports the number of records.
func (rs records) Len() int {
	return len(rs)
}

// Swap exchanges the records at positions i and j.
func (rs records) Swap(i, j int) {
	rs[j], rs[i] = rs[i], rs[j]
}

// Less reports whether the key at position i sorts strictly before the key
// at position j.
func (rs records) Less(i, j int) bool {
	left, right := rs[i].key, rs[j].key
	return bytes.Compare(left, right) < 0
}

// IDForTx builds a tx.TransactionID from the first two elements of
// t.TransactionId. When only one element is present it fills the second
// position and the first is zero; when none are present both are zero.
// Elements beyond the second are ignored.
func IDForTx(t *api.Transaction) tx.TransactionID {
	words := t.TransactionId
	if len(words) >= 2 {
		return tx.TransactionID{words[0], words[1]}
	}
	if len(words) == 1 {
		return tx.TransactionID{0, words[0]}
	}
	return tx.TransactionID{0, 0}
}
