package xarth

import (
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"runtime"
	"strings"
	"time"

	// Install variable export handler, per current good practice
	_ "expvar"
	// Install profiling handlers, per current good practice
	_ "net/http/pprof"

	"code.justin.tv/rhys/nursery/xarth/internal/trace"

	"golang.org/x/net/context"
)

// headerPrefix is the canonical form of the leading, dash-delimited header
// segment that reservedHeader compares header keys against.
var headerPrefix = http.CanonicalHeaderKey("Chitin")

// reservedHeader reports whether the header key k is reserved: the part of
// k before its first "-" (or all of k when it contains no "-") must equal
// headerPrefix exactly. The comparison is case-sensitive.
func reservedHeader(k string) bool {
	first := k
	if dash := strings.Index(k, "-"); dash >= 0 {
		first = k[:dash]
	}
	return first == headerPrefix
}

// Context returns the x/net/context.Context that this package associated
// with the request being served, plus a flag reporting whether one was
// found.
//
// When w is the ResponseWriter this package hands to user handlers, the
// result is that request's Context and true. That Context is cancelled when
// the handler returns or when the client connection is lost (provided the
// underlying http.Server generates http.ResponseWriters which implement
// http.CloseNotifier). It carries values relating to the incoming http
// request (the "span" in Trace terms) and to the initial user request which
// triggered the current span (the "transaction" in Trace terms); if the
// incoming request did not include the Trace transaction id, a new one is
// generated.
//
// For any other http.ResponseWriter the result is an already-cancelled
// Context and false. The request r is not consulted.
func Context(w http.ResponseWriter, r *http.Request) (context.Context, bool) {
	// TODO: What's the right global func via which to access the context?
	tracer, isTracer := w.(*responseTracer)
	if !isTracer {
		return cancelled, false
	}
	return tracer.ctx, true
}

// wrapServer returns a new *http.Server whose Handler is this package's
// instrumenting server wrapped around srv's Handler (http.DefaultServeMux
// when srv.Handler is nil). A nil srv is treated as a zero-value
// http.Server. The original srv is retained as the wrapper's base and is
// not modified.
//
// The caller's configuration is carried over field by field rather than by
// copying the struct, so no internal server state is duplicated. Besides
// Addr and the timeouts, MaxHeaderBytes, TLSConfig, TLSNextProto, ConnState
// and ErrorLog are propagated so that settings made on srv are honoured by
// the server that actually runs.
func wrapServer(srv *http.Server) *http.Server {
	if srv == nil {
		srv = &http.Server{}
	}

	handler := srv.Handler
	if handler == nil {
		handler = http.DefaultServeMux
	}

	return &http.Server{
		Addr:           srv.Addr,
		Handler:        &server{base: srv, mux: handler},
		ReadTimeout:    srv.ReadTimeout,
		WriteTimeout:   srv.WriteTimeout,
		MaxHeaderBytes: srv.MaxHeaderBytes,
		TLSConfig:      srv.TLSConfig,
		TLSNextProto:   srv.TLSNextProto,
		ConnState:      srv.ConnState,
		ErrorLog:       srv.ErrorLog,
	}
}

// wrapListener wraps l in this package's listener type, configured with a
// keepalive period of three minutes, and returns the wrapper.
func wrapListener(l net.Listener) net.Listener {
	return listener{
		Listener:        l,
		keepalivePeriod: 3 * time.Minute,
	}
}

// Serve is the entrypoint for the Context-enhanced http server. It serves
// connections accepted from l using a wrapped copy of srv, and returns the
// error reported by that server's Serve method.
//
// The http.Handler of srv is wrapped to allow instrumentation of the
// request/response cycle. Inside the user-provided http.Handler, the handler
// author passes the http.ResponseWriter and *http.Request to the Context
// function to obtain an auto-generated context.Context, which includes
// variables required for Trace instrumentation and whose lifecycle is
// attached to the incoming http request.
//
// The net.Listener is converted to enable TCP keepalive, if able.
func Serve(srv *http.Server, l net.Listener) error {
	wrapped := wrapServer(srv)
	return wrapped.Serve(wrapListener(l))
}

//

// server is the http.Handler that wrapServer installs on the http.Server it
// returns.
//
// base is the caller-supplied http.Server (ServeHTTP reads its ErrorLog);
// mux is the handler that ServeHTTP ultimately invokes, either the caller's
// Handler or http.DefaultServeMux.
type server struct {
	base *http.Server
	mux  http.Handler
}

// ServeHTTP implements http.Handler around the user's handler (srv.mux).
//
// For each request it:
//   - builds a Context from context.Background() that carries the arrival
//     time, a recorder which logs via srv.base.ErrorLog (or the standard log
//     package when ErrorLog is nil), and whatever trace.ContextFromHeader
//     derives from the request headers;
//   - passes the request through scrubRequest and wraps w in a
//     responseTracer;
//   - starts watchClose when w implements http.CloseNotifier, so the Context
//     is cancelled if the close notification fires;
//   - calls the user's handler, then logs a "trace=end" line with the
//     elapsed time and the number of bytes written.
//
// If scrubRequest fails, the error is logged and ServeHTTP returns without
// calling the user's handler. If the user's handler panics, the panic value
// and a stack trace are logged and the function re-panics with nil.
func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// midRequest is true only while the user's handler is running, which
	// lets the deferred function below tell a panic in user code apart from
	// a normal return.
	var midRequest bool

	// Prepare the context
	ctx := markArrivalTime(context.Background(), time.Now())
	if errlog := srv.base.ErrorLog; errlog != nil {
		ctx = setRecorder(ctx, loggerFunc(errlog.Printf))
	} else {
		ctx = setRecorder(ctx, loggerFunc(log.Printf))
	}
	ctx = trace.ContextFromHeader(ctx, r.Header)
	ctx, cancel := context.WithCancel(ctx)
	// Registered first, so it runs last: the Context is still live while the
	// deferred logging below executes.
	defer cancel()

	// The context is ready. DO NOT add more values for internal use, some
	// consumers will miss them.

	// Filter and instrument the incoming request
	blessedReq := &http.Request{}
	err := scrubRequest(ctx, blessedReq, r)
	if err != nil {
		// NOTE(review): nothing is written to w on this path, so net/http
		// will presumably answer with its implicit 200 OK. Confirm that is
		// the intended reply for a rejected request.
		printf(ctx, "scrub-request=%q", err)
		return
	}

	// Filter and instrument the outgoing response
	blessedResp := &responseTracer{
		rw:  w,
		ctx: ctx,
	}

	// watchClose returns once ctx is done, which the deferred cancel above
	// guarantees, so this goroutine does not outlive the request.
	if cn, ok := w.(http.CloseNotifier); ok {
		go watchClose(ctx, cancel, cn.CloseNotify())
	}

	printf(ctx, "trace=start")

	defer func() {
		if midRequest {
			// The user's handler did not return normally. Log the panic
			// value and the current goroutine's stack, one log line per
			// stack line.
			problem := recover()
			printf(ctx, "remote-addr=%q panic=%q",
				blessedReq.RemoteAddr,
				fmt.Sprint(problem))

			buf := make([]byte, 64<<10)
			stack := string(buf[:runtime.Stack(buf, false)])
			for _, line := range strings.Split(stack, "\n") {
				printf(ctx, "panic-stack=%q", line)
			}
		}

		// A handler that returned normally without writing anything still
		// gets an explicit (and recorded) 200 response header.
		if !midRequest && !blessedResp.wroteHeader {
			blessedResp.WriteHeader(http.StatusOK)
		}

		end := time.Now()
		duration := end.Sub(getArrivalTime(ctx))

		printf(ctx, "trace=end duration=%s bytes=%d",
			micros(duration),
			blessedResp.bytes)

		if midRequest {
			// We've dealt with the problem. We don't need package net/http to
			// print its (inferior) stack trace, so we won't re-panic with the
			// original value. However, we'd like for the user to be
			// disconnected (as the stdlib does). We can't access the net.Conn
			// they came in on here, so we'll use panic(nil) to cause the
			// stdlib to do it for us.
			//
			// NOTE(review): from Go 1.21 a panic(nil) is reported by recover
			// as a *runtime.PanicNilError, which net/http would log;
			// http.ErrAbortHandler (Go 1.8+) is the documented way to abort
			// a handler quietly. Confirm against the targeted Go version.
			panic(nil)
		}
	}()

	// Show time! Call the user code.
	midRequest = true
	srv.mux.ServeHTTP(blessedResp, blessedReq)
	midRequest = false
}

// watchClose blocks until ctx is done or a receive on ch succeeds, whichever
// happens first. When the receive wins and ctx has not already finished, it
// logs that the http connection closed and calls cancel. It returns in every
// case, so it is safe to run in its own goroutine for the life of ctx.
func watchClose(ctx context.Context, cancel context.CancelFunc, ch <-chan bool) {
	select {
	case <-ch:
		if ctx.Err() != nil {
			// The request already finished; nothing to report.
			return
		}
		printf(ctx, "info=%q", "http connection closed")
		cancel()
	case <-ctx.Done():
	}
}

// scrubRequest overwrites *dst with a filtered, instrumented copy of src for
// handing to user code.
//
// Only an explicit set of fields is carried over. The method must be one of
// GET, POST, PUT, HEAD or DELETE, otherwise an error is returned (dst has
// been partially populated by then and should be discarded). Headers whose
// key is reserved (see reservedHeader) are dropped and reported through
// record; all other headers share their value slices with src. The body is
// replaced by a *body wrapper around src.Body bound to ctx.
//
// scrubRequest panics if src carries a header key that is not in canonical
// form.
func scrubRequest(ctx context.Context, dst, src *http.Request) error {
	*dst = http.Request{
		// Method, Header and Body are deliberately absent here: each is
		// validated, filtered or replaced further down.
		//
		// Form, PostForm, and MultipartForm are populated by user calls,
		// and Trailer is not supported.

		Proto:      src.Proto,
		ProtoMajor: src.ProtoMajor,
		ProtoMinor: src.ProtoMinor,

		URL:  src.URL,
		Host: src.Host,

		ContentLength:    src.ContentLength,
		TransferEncoding: src.TransferEncoding,
		Close:            src.Close,

		// If we decide to modify the URL, we may want to update this to
		// match.
		RequestURI: src.RequestURI,

		// If we want to support haproxy's PROXY protocol, we'd do it by
		// having a custom net.Listener generate a custom net.Conn which
		// decodes the remote address in a separate goroutine before being
		// returned by Accept(). Alternatively, we could have the net.Conn
		// return a unique string as the remote address, which would allow us
		// to better link http requests to the connections on which they
		// arrived.
		RemoteAddr: src.RemoteAddr,

		// If we start encrypting traffic on our network, we may want to
		// filter access to this field.
		TLS: src.TLS,
	}

	// Only a fixed set of methods is let through.
	switch method := src.Method; method {
	case "GET", "POST", "PUT", "HEAD", "DELETE":
		dst.Method = method
	default:
		return fmt.Errorf("bad method: %q", method)
	}

	// Copy headers, withholding (and reporting) the reserved ones.
	headers := make(http.Header, len(src.Header))
	dst.Header = headers
	for key, values := range src.Header {
		if key != http.CanonicalHeaderKey(key) {
			panic("net/http promises canonicalized headers")
		}
		if reservedHeader(key) {
			record(ctx, multiEvent{kind: serverRequestBadHeader, header: key})
			continue
		}
		headers[key] = values
	}

	dst.Body = &body{rc: src.Body, ctx: ctx}

	return nil
}

// body wraps a request body so that reads and closes can be recorded
// against the request's Context.
//
// rc is the wrapped body, ctx is the Context passed to record and
// getArrivalTime, and bytes is the running total of bytes returned by Read.
type body struct {
	rc    io.ReadCloser
	ctx   context.Context
	bytes int64
}

// Read reads from the wrapped body into p and adds the number of bytes read
// to the running total. When the wrapped reader reports io.EOF, a
// serverReadEOF event is recorded carrying the time elapsed since the
// request's arrival and the total bytes read so far. The wrapped reader's
// results are returned unchanged.
func (b *body) Read(p []byte) (int, error) {
	n, err := b.rc.Read(p)
	b.bytes += int64(n)
	if err == io.EOF {
		finished := time.Now()
		elapsed := finished.Sub(getArrivalTime(b.ctx))
		record(b.ctx, multiEvent{kind: serverReadEOF, duration: elapsed, bytes: b.bytes})
	}
	return n, err
}

// Close records a serverRequestBodyClose event carrying the time elapsed
// since the request's arrival, then closes the wrapped body and returns its
// error.
func (b *body) Close() error {
	now := time.Now()
	arr := getArrivalTime(b.ctx)
	record(b.ctx, multiEvent{kind: serverRequestBodyClose, duration: now.Sub(arr)})
	// Close the wrapped ReadCloser; calling b.Close here would recurse
	// without end.
	return b.rc.Close()
}

// responseTracer is the http.ResponseWriter handed to user handlers. It
// forwards to the underlying writer while filtering reserved headers and
// keeping per-response bookkeeping.
//
// Fields:
//   - rw: the underlying writer that receives headers, status and bytes.
//   - ctx: the request's Context, returned by the Context function.
//   - header: staging header map returned by Header, allocated on first
//     use; copied to rw when WriteHeader runs.
//   - wroteHeader: set by the first WriteHeader call.
//   - status: the status code passed to the first WriteHeader call.
//   - bytes: running total of bytes accepted by rw via Write.
//   - cn: channel returned by CloseNotify, created on first use.
type responseTracer struct {
	rw          http.ResponseWriter
	ctx         context.Context
	header      http.Header
	wroteHeader bool
	status      int
	bytes       int64
	cn          chan bool
}

// Compile-time checks that *responseTracer satisfies these net/http
// interfaces.
var (
	_ http.ResponseWriter = (*responseTracer)(nil)
	_ http.CloseNotifier  = (*responseTracer)(nil)
	_ http.Flusher        = (*responseTracer)(nil)
)

// Header returns the tracer's staging header map, allocating it on the
// first call. Every call returns the same map.
func (tr *responseTracer) Header() http.Header {
	if tr.header != nil {
		return tr.header
	}
	tr.header = make(http.Header)
	return tr.header
}

// WriteHeader sends the response header with the given status through the
// underlying writer.
//
// Only the first call has any effect; later calls are logged and ignored.
// On the first call the status is remembered, every staged header (see
// Header) is copied to the underlying writer under its canonical key, except
// reserved keys, which are dropped and reported through record. A
// serverWriteHeader event with the status and the time elapsed since the
// request's arrival is then recorded, and the status is forwarded.
func (tr *responseTracer) WriteHeader(status int) {
	if tr.wroteHeader {
		printf(tr.ctx, "err=%q", "WriteHeader called multiple times")
		return
	}
	tr.wroteHeader, tr.status = true, status

	if staged := tr.header; staged != nil {
		out := tr.rw.Header()
		for key, values := range staged {
			canon := http.CanonicalHeaderKey(key)
			if reservedHeader(canon) {
				record(tr.ctx, multiEvent{kind: serverResponseBadHeader, header: canon})
				continue
			}
			for _, value := range values {
				out.Add(canon, value)
			}
		}
	}

	record(tr.ctx, multiEvent{
		kind:     serverWriteHeader,
		status:   status,
		duration: time.Since(getArrivalTime(tr.ctx)),
	})

	tr.rw.WriteHeader(status)
}

// Write sends b to the underlying writer, first issuing WriteHeader with
// http.StatusOK if no header has been written yet. The number of bytes the
// underlying writer accepted is added to the tracer's byte total, and the
// underlying writer's results are returned unchanged.
func (tr *responseTracer) Write(b []byte) (int, error) {
	if headerPending := !tr.wroteHeader; headerPending {
		tr.WriteHeader(http.StatusOK)
	}

	written, err := tr.rw.Write(b)
	tr.bytes += int64(written)
	return written, err
}

// CloseNotify implements http.CloseNotifier. It returns a channel that
// receives a single true once the tracer's Context is done. The channel and
// the goroutine feeding it are created on the first call; later calls
// return the same channel.
//
// The channel has capacity one so that the feeding goroutine can deliver its
// value and exit even when the handler never receives from the channel.
// With an unbuffered channel that goroutine would block on the send forever.
func (tr *responseTracer) CloseNotify() <-chan bool {
	if tr.cn == nil {
		tr.cn = make(chan bool, 1)
		go tr.watchDone()
	}
	return tr.cn
}

// watchDone waits for the tracer's Context to finish and then sends true on
// the tracer's close-notification channel. It is started by CloseNotify.
func (tr *responseTracer) watchDone() {
	finished := tr.ctx.Done()
	<-finished
	tr.cn <- true
}

// Flush implements http.Flusher. It issues WriteHeader with http.StatusOK if
// no header has been written yet, then flushes the underlying writer when
// that writer is itself an http.Flusher; otherwise it does nothing more.
func (tr *responseTracer) Flush() {
	if !tr.wroteHeader {
		tr.WriteHeader(http.StatusOK)
	}

	flusher, canFlush := tr.rw.(http.Flusher)
	if !canFlush {
		return
	}
	flusher.Flush()
}

//

var (
	// cancelled is a Context that has already been cancelled. It is set by
	// init and returned by Context when no request Context is available.
	cancelled context.Context
)

// init populates cancelled with a Context derived from
// context.Background() and cancels it immediately.
func init() {
	var cancel context.CancelFunc
	cancelled, cancel = context.WithCancel(context.Background())
	cancel()
}
