/*

Command barrel shoots trace events at the remote collection daemon.  It runs
on the same host as instrumented processes and receives events emitted by
processes as UDP packets, converting them into a TCP stream for the long
journey across the network.

*/
package main

import (
	"bufio"
	"expvar"
	"flag"
	"fmt"
	"hash/fnv"
	"io/ioutil"
	"log"
	"math/rand"
	"net"
	"net/http"
	"os"
	"runtime"
	"sync/atomic"
	"time"

	"code.justin.tv/common/chitin"
	"code.justin.tv/release/trace/events"
	// "code.justin.tv/common/golibs/jitter" // TODO: enable jitter
	"code.justin.tv/common/golibs/pkgpath"
	"github.com/golang/protobuf/proto"
	"golang.org/x/net/context"
)

const (
	// Maximum possible size of an IP packet, as an upper bound for the
	// maximum size of a UDP datagram
	maxIP = 1 << 16
	// Size of memory buffer to allocate for receiving UDP datagrams, to
	// amortize allocation requests.  This is allowed as long as we don't
	// maintain long-lived references to any datagrams.
	slabSize = 1 << 20
	// Queue size for datagrams that have not yet been verified as valid
	// protobuf (pktQueue), and the number of worker slots the dispatcher
	// keeps filled to drain that queue (pktWorkers)
	pktQueue   = 1 << 10
	pktWorkers = 1

	// How often each worker flushes its buffered TCP writer
	flushInterval = 1 * time.Second

	// Default listen addresses for the -udp and -http flags
	defaultUDPHostport  = ":8943"
	defaultHTTPHostport = ":3498"

	// Read and write timeout applied to the HTTP server started in main
	httpTimeout = 5 * time.Minute

	// Upper bound on how long the dispatcher waits between rounds of worker
	// replacement; a random amount of up to half of this is subtracted from
	// each wait
	workerLifetime = 5 * time.Minute

	// Default for the -target flag: the TCP address events are forwarded to
	targetDNSName = "trace-collect.prod.us-west2.justin.tv:11143"
)

var (
	// Each backoff pair is {initial delay, maximum delay}, passed as the
	// start and max arguments of backoff.  dialBackoff is used when dialing
	// the TCP target fails, acceptBackoff when reading from the UDP socket
	// fails with a temporary error.
	dialBackoff   = [2]time.Duration{5 * time.Millisecond, 10 * time.Second}
	acceptBackoff = [2]time.Duration{5 * time.Millisecond, 1 * time.Second}

	// Deadline placed on the context handed to a replacement worker's
	// initial connect attempt by the dispatcher.
	dialAttemptsTimeout = 5 * time.Second
)

// exp holds this process's operational counters.  Its fields are updated via
// package sync/atomic throughout this file, and main publishes the result of
// exp.dup() through package expvar.  The expdata type is declared elsewhere.
var exp expdata

// main wires the proxy together: it parses the command-line flags, opens the
// UDP and HTTP listeners, publishes the operational counters, starts the
// dispatcher that forwards datagrams over TCP, and then reads datagrams from
// the UDP socket until that fails with a non-temporary error.
func main() {
	udpAddr := flag.String("udp", defaultUDPHostport, "UDP address on which to listen.")
	httpAddr := flag.String("http", defaultHTTPHostport, "HTTP address on which to listen.")
	target := flag.String("target", targetDNSName, "TCP address to receive our proxied events.")
	flag.Parse()

	// Retain pre-go1.5 behavior of defaulting to GOMAXPROCS=1, unless
	// otherwise specified.
	if os.Getenv("GOMAXPROCS") == "" {
		runtime.GOMAXPROCS(1)
	}

	// Seed math/rand from a hash of our pid and the current time.
	hasher := fnv.New64a()
	fmt.Fprintf(hasher, "%d\n%s\n", os.Getpid(), time.Now().UTC().Format(time.RFC3339Nano))
	rand.Seed(int64(hasher.Sum64()))

	// We'll listen on UDP for Trace event datagrams from instrumented apps
	udpSock, err := net.ListenPacket("udp", *udpAddr)
	if err != nil {
		log.Fatalf("listen udp err=%q", err)
	}

	// We serve debug information about this process over HTTP on its own
	// address (the -http flag, configured independently of the UDP
	// address).  Operational variables are available via package expvar at
	// /debug/vars.
	// NOTE(review): profiling at /debug/pprof/ requires net/http/pprof to be
	// imported somewhere in the build; this file does not import it - confirm.
	tcpSock, err := net.Listen("tcp", *httpAddr)
	if err != nil {
		log.Fatalf("listen tcp err=%q", err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ctx, err = chitin.ExperimentalTraceContext(ctx)
	if err != nil {
		log.Fatalf("chitin trace-context err=%q", err)
	}

	// Publish a copy of our counters under this package's path.
	pkg, _ := pkgpath.Caller(0)
	expvar.Publish(pkg, expvar.Func(func() interface{} { return exp.dup() }))

	srv := &http.Server{
		Handler: chitin.Handler(http.DefaultServeMux, chitin.SetBaseContext(ctx)),

		ReadTimeout:  httpTimeout,
		WriteTimeout: httpTimeout,
	}
	go func() {
		if err := srv.Serve(tcpSock); err != nil {
			log.Fatalf("serve http err=%q", err)
		}
	}()

	log.Printf("listening udp addr=%q", *udpAddr)
	log.Printf("listening http addr=%q", *httpAddr)

	// Datagrams flow from accept, through this queue, to the dispatcher's
	// workers.
	pkts := make(chan []byte, pktQueue)

	disp := &dispatcher{
		target:  *target,
		workers: make([]*worker, pktWorkers),
		quit:    make(chan struct{}),
	}
	go disp.run(ctx, pkts)
	defer close(disp.quit)

	if err := accept(pkts, udpSock); err != nil {
		log.Fatalf("read udp err=%q", err)
	}
}

// dispatcher owns the workers that forward datagrams to target and, in run,
// periodically replaces them with freshly connected ones.
type dispatcher struct {
	// target is the TCP address each worker dials.
	target  string
	// workers holds the current worker for each slot.  An entry is nil until
	// a worker has connected for that slot, and again after run returns.
	workers []*worker
	// quit is closed to ask run to stop.
	quit    chan struct{}
}

// run keeps the dispatcher's worker slots populated until d.quit is closed or
// ctx is cancelled.
//
// Each round it tries to build a freshly connected replacement for every
// slot.  A replacement that connects within dialAttemptsTimeout is started
// and the worker it replaces (if any) is told to quit; if the replacement
// fails to connect, the previous worker is left running.  run then waits
// between half of and all of workerLifetime before the next round.
//
// When run returns, every remaining worker is told to quit and its slot is
// cleared.
func (d *dispatcher) run(ctx context.Context, pkts <-chan []byte) {
	defer func() {
		for i, w := range d.workers {
			if w != nil {
				close(w.quit)
				d.workers[i] = nil
			}
		}
	}()

	// The dial function does not change between rounds, so build it once.
	dial := func() (net.Conn, error) {
		// durations from (net/http).DefaultTransport
		return (&net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
		}).Dial("tcp", d.target)
	}

	for {
		select {
		case <-d.quit:
			return
		case <-ctx.Done():
			return
		default:
		}

		for i, prev := range d.workers {
			next := &worker{
				in:   pkts,
				dial: dial,
				quit: make(chan struct{}),
			}

			// Bound the initial connect attempt, and release the timeout
			// context's resources as soon as the attempt is over rather than
			// waiting for its deadline to fire.
			dialCtx, cancel := context.WithTimeout(ctx, dialAttemptsTimeout)
			next.connect(dialCtx)
			cancel()

			if !next.connected() {
				log.Printf("worker replacement failure")
				continue
			}

			// Start the replacement before retiring the old worker so the
			// slot is never left without a consumer.
			go next.run(ctx)
			if prev != nil {
				close(prev.quit)
			}
			d.workers[i] = next
		}

		// Wait out the (randomized) worker lifetime, but stay responsive to
		// shutdown requests instead of sleeping through them.
		lifetime := workerLifetime - time.Duration(rand.Int63n(int64(workerLifetime/2)))
		timer := time.NewTimer(lifetime)
		select {
		case <-d.quit:
			timer.Stop()
			return
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
		// TODO: enable jitter package as separate change (including a big Godeps update)
		// <-jitter.NewTimer(jitter.Linear{
		// 	Min: workerLifetime / 2,
		// 	Max: workerLifetime,
		// }).C
	}
}

// worker forwards datagrams received on in to a TCP connection obtained from
// dial, buffering writes through bw.
type worker struct {
	// in is the queue of datagrams to validate and forward.
	in   <-chan []byte
	// dial opens a new connection to the downstream collector.
	dial func() (net.Conn, error)
	// quit is closed to ask the worker to stop.
	quit chan struct{}

	// conn is the current connection, or nil while disconnected.
	conn net.Conn
	// bw buffers writes to conn.  connect creates it on first use and
	// resets it onto each new connection.
	bw   *bufio.Writer
}

// connected reports whether the worker currently holds a connection, i.e.
// whether w.conn is non-nil.  It does not probe the connection's health.
func (w *worker) connected() bool {
	return w.conn != nil
}

// connect dials the downstream collector until a connection is established,
// the worker is told to quit, or ctx is done.  It is a no-op if the worker is
// already connected.  On success w.conn is set and w.bw is reset to write to
// the new connection; on early return the worker is left disconnected, which
// callers can detect with connected.
//
// Failed dials are counted in exp.DialError and logged.  Temporary network
// errors are retried with exponential backoff between dialBackoff[0] and
// dialBackoff[1]; any other error is retried after dialBackoff[1].  The waits
// between attempts are abandoned as soon as w.quit is closed or ctx is done,
// so a caller's deadline is not overshot by a full backoff period.
func (w *worker) connect(ctx context.Context) {
	if w.connected() {
		return
	}

	if w.bw == nil {
		// Placeholder destination; the writer is re-targeted on success.
		w.bw = bufio.NewWriter(ioutil.Discard)
	}

	var tempDelay time.Duration
	for {
		select {
		case <-w.quit:
			return
		case <-ctx.Done():
			return
		default:
		}

		conn, err := w.dial()
		if err == nil {
			w.conn = conn
			w.bw.Reset(conn)
			return
		}

		atomic.AddInt64(&exp.DialError, 1)
		log.Printf("tcp dial err=%q", err)

		// For non-temporary errors there's not much we can do but try again
		// later.
		delay := dialBackoff[1]
		if ne, ok := err.(net.Error); ok && ne.Temporary() {
			// Same schedule as backoff(): start small, double each time, and
			// cap at the maximum.  It is computed here rather than calling
			// backoff() because that helper sleeps uninterruptibly.
			if tempDelay == 0 {
				tempDelay = dialBackoff[0]
			} else {
				tempDelay *= 2
			}
			if tempDelay > dialBackoff[1] {
				tempDelay = dialBackoff[1]
			}
			delay = tempDelay
		}

		timer := time.NewTimer(delay)
		select {
		case <-w.quit:
			timer.Stop()
			return
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}

// disconnect flushes any buffered data to the current connection and closes
// it, leaving the worker disconnected.  Flush and close failures are logged
// but otherwise ignored.  It does nothing if the worker has no connection.
func (w *worker) disconnect() {
	if !w.connected() {
		return
	}

	if flushErr := w.bw.Flush(); flushErr != nil {
		log.Printf("tcp flush err=%q", flushErr)
	}

	// Clear the connection before reporting the close result so the worker
	// counts as disconnected whether or not Close succeeded.
	closeErr := w.conn.Close()
	w.conn = nil
	if closeErr != nil {
		log.Printf("tcp close err=%q", closeErr)
	}
}

// reconnect handles a failed write to the collector: it counts the failure in
// exp.WriteError, logs reason, drops the current connection, and then tries
// to establish a new one via connect (which may return without a connection
// if the worker is told to quit or ctx is done).
//
// The worker's run loop also calls this for flush failures, so those are
// counted and logged as write errors too.
func (w *worker) reconnect(ctx context.Context, reason error) {
	atomic.AddInt64(&exp.WriteError, 1)
	log.Printf("tcp write err=%q", reason)

	w.disconnect()
	w.connect(ctx)
}

// run is the worker's main loop.  It ensures the worker is connected, then
// forwards datagrams from w.in to the collector and flushes the buffered
// writer every flushInterval, until w.quit is closed or ctx is done.  The
// connection is flushed and closed on the way out.
func (w *worker) run(ctx context.Context) {
	ticker := time.NewTicker(flushInterval)
	defer ticker.Stop()

	w.connect(ctx)
	defer w.disconnect()

	// forward validates one datagram and, if it is well-formed, queues it on
	// the buffered writer.
	forward := func(pkt []byte) {
		// We need to confirm that the datagram is a valid and complete
		// EventSet protobuf message.  EventSet has a single repeated Event
		// field, so concatenating multiple EventSet buffers will result in a
		// single EventSet buffer with even more Event fields.  But in order
		// for this to work, we can never send any partial or corrupted data
		// to our downstream collector.
		var set events.EventSet
		if err := proto.Unmarshal(pkt, &set); err != nil {
			atomic.AddInt64(&exp.UnmarshalError, 1)
			return
		}
		atomic.AddInt64(&exp.Events, int64(len(set.Event)))

		if _, err := w.bw.Write(pkt); err != nil {
			w.reconnect(ctx, err)
		}
	}

	for {
		select {
		case <-w.quit:
			return
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := w.bw.Flush(); err != nil {
				w.reconnect(ctx, err)
			}
		case pkt := <-w.in:
			forward(pkt)
		}
	}
}

// backoff advances *tempDelay to the next step of an exponential backoff
// schedule and then sleeps for that long.  A zero *tempDelay becomes start;
// otherwise it is doubled.  The result never exceeds max.  Callers reset
// *tempDelay to zero after a success to restart the schedule.
//
// The algorithm follows (*net/http.Server).Serve.
func backoff(tempDelay *time.Duration, start, max time.Duration) {
	next := start
	if *tempDelay != 0 {
		next = *tempDelay * 2
	}
	if next > max {
		next = max
	}
	*tempDelay = next
	time.Sleep(next)
}

// accept reads datagrams from uSock and sends each one on pkts, looping until
// a read fails with an error that is not a temporary network error; that
// error is returned.  Temporary read errors are retried after a backoff
// between acceptBackoff[0] and acceptBackoff[1].  Every read error is counted
// in exp.ReadError and logged.
//
// Datagrams are carved out of a large shared slab to amortize allocations; a
// new slab is allocated whenever the remainder could not hold a maximum-size
// datagram.  Each packet's capacity is clipped to its length so that appends
// by a consumer cannot spill into a neighboring packet.
func accept(pkts chan<- []byte, uSock net.PacketConn) error {
	var (
		slab  []byte
		delay time.Duration
	)
	for {
		if len(slab) < maxIP {
			slab = make([]byte, slabSize)
		}

		n, _, err := uSock.ReadFrom(slab)
		if err != nil {
			atomic.AddInt64(&exp.ReadError, 1)
			log.Printf("udp read err=%q", err)
			ne, isNetErr := err.(net.Error)
			if !isNetErr || !ne.Temporary() {
				return err
			}
			backoff(&delay, acceptBackoff[0], acceptBackoff[1])
			continue
		}
		delay = 0

		pkt := slab[:n:n]
		slab = slab[n:]

		// Fast path: there is room in the queue.
		select {
		case pkts <- pkt:
			continue
		default:
		}

		// The queue is full.  Record the backlog, then block anyway so we
		// don't waste CPU.  Something's going to drop packets, and the OS
		// can do it more efficiently.
		atomic.AddInt64(&exp.DecodeBacklogCount, 1)
		atomic.StoreInt64(&exp.DecodeBacklogNow, 1)
		pkts <- pkt
		atomic.StoreInt64(&exp.DecodeBacklogNow, 0)
	}
}
