package chitin

import (
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"runtime"
	"strings"
	"time"

	// Install variable export handler, per current good practice
	_ "expvar"
	// Install profiling handlers, per current good practice
	_ "net/http/pprof"
	// Install bininfo data into expvar
	_ "code.justin.tv/common/golibs/bininfo"

	"code.justin.tv/common/chitin/internal/trace"

	"golang.org/x/net/context"
)

var (
	// headerPrefix is the canonical spelling ("Chitin") of the first
	// dash-separated component that marks a header name as belonging to
	// chitin itself.
	headerPrefix = http.CanonicalHeaderKey("Chitin")
)

// reservedHeader reports whether the header name k lies in chitin's reserved
// namespace, i.e. whether the part of k before its first '-' (or all of k, if
// it has no '-') equals headerPrefix. The comparison is case-sensitive, so k
// is expected to be in canonical form.
//
// This is called for every inbound request header and every outbound
// response header, so it avoids allocating: strings.SplitN would build a
// []string on each call just to inspect its first element.
func reservedHeader(k string) bool {
	first := k
	if i := strings.IndexByte(k, '-'); i >= 0 {
		first = k[:i]
	}
	return first == headerPrefix
}

// Context returns the x/net/context.Context that chitin associated with the
// given http.ResponseWriter and *http.Request.
//
// The returned Context is cancelled once the handler returns. It is also
// cancelled earlier if the client connection goes away, provided the
// underlying http.Server hands out http.ResponseWriters that implement
// http.CloseNotifier. The Context carries values for the incoming http
// request (the "span" in Trace terms). It also carries values for the
// original user request that led to this one (the "transaction" in Trace
// terms).
//
// When the incoming request carries no Trace transaction id, a fresh one is
// generated.
//
// Context panics if w did not come from chitin. This usually means some
// middleware sits between chitin and the user's http.Handler that calls this
// function. That is an integration bug: if chitin cannot reach the correct
// context, the package is of little use.
//
// On Go 1.7 and newer this is equivalent to net/http.Request.Context. The
// difference is that Context panics when the context lacks the values a
// chitin Handler would have attached.
func Context(w http.ResponseWriter, r *http.Request) context.Context {
	ctx := fetchContext(w, r)
	return ctx
}

// wrapListener decorates l with chitin's listener type and sets its
// keepalivePeriod to three minutes. Serve documents this as enabling TCP
// keepalive where possible.
func wrapListener(l net.Listener) net.Listener {
	const keepalive = 3 * time.Minute
	return listener{Listener: l, keepalivePeriod: keepalive}
}

// Serve is a convenient entrypoint for a Context-enhanced http server. It
// does three things:
//   - wraps handler with chitin's Handler;
//   - wraps l so that TCP keepalive is enabled where the listener allows it;
//   - hands both to net/http.Serve.
//
// Its signature matches net/http.Serve.
func Serve(l net.Listener, handler http.Handler) error {
	wrapped := Handler(handler)
	return http.Serve(wrapListener(l), wrapped)
}

// A HandlerOpt configures the chitin handler built by the Handler function.
// It does not configure the user's http.Handler that Handler wraps. Applying
// a HandlerOpt mutates the chitin handler's settings. It then returns
// another HandlerOpt that, when applied, restores the settings that were in
// place before.
type HandlerOpt func(*mux) HandlerOpt

// SetLogger returns a HandlerOpt that sends chitin's request-cycle messages
// to lg. Because lg is a *log.Logger, it also controls their formatting.
// Applying the option yields another HandlerOpt that reinstates the logger
// that was configured beforehand.
func SetLogger(lg *log.Logger) HandlerOpt {
	return func(m *mux) HandlerOpt {
		undo := SetLogger(m.logger)
		m.logger = lg
		return undo
	}
}

// SetBaseContext returns a HandlerOpt that makes ctx the handler's master
// context. Every per-request context descends from the master context. As a
// result, values attached to ctx, and its cancellation, reach every HTTP
// request. Applying the option yields another HandlerOpt that reinstates the
// previous master context.
func SetBaseContext(ctx context.Context) HandlerOpt {
	return func(m *mux) HandlerOpt {
		undo := SetBaseContext(m.masterCtx)
		m.masterCtx = ctx
		return undo
	}
}

// ExperimentalTraceContext derives a context from parent that makes chitin
// emit Trace events whenever work is done on its behalf.
//
// By default, Trace events are sent to 127.0.0.1, UDP port 8943.
//
// If the trace information cannot be attached, parent is returned unchanged
// together with the error.
//
// This API is subject to change.
func ExperimentalTraceContext(parent context.Context) (context.Context, error) {
	traced, err := trace.WithInfo(parent)
	if err == nil {
		return traced, nil
	}
	return parent, err
}

// Handler adds Chitin support to an http.Handler. It intercepts inbound http
// requests and outbound responses, transforming and instrumenting them.
// Within the user-provided http.Handler, the handler author can gain access
// to an auto-generated context.Context which includes variables required for
// Trace instrumentation and whose lifecycle is attached to the incoming http
// request.
//
// Details:
//   - A nil handler is replaced with http.DefaultServeMux.
//   - options are applied in order, and nil options are ignored.
//   - If the options leave the base context nil (for example via
//     SetBaseContext(nil)), context.Background is used instead. Every
//     per-request context is derived from the base context, so it must be
//     a valid parent.
func Handler(handler http.Handler, options ...HandlerOpt) http.Handler {
	if handler == nil {
		handler = http.DefaultServeMux
	}
	h := &mux{
		base:      handler,
		masterCtx: context.Background(),
	}
	for _, opt := range options {
		// Tolerate nil entries, e.g. from conditionally-built option lists,
		// rather than panicking while constructing the handler.
		if opt == nil {
			continue
		}
		opt(h)
	}
	// A nil master context would break every request: ServeHTTP derives its
	// per-request context from it.
	if h.masterCtx == nil {
		h.masterCtx = context.Background()
	}
	return h
}

// mux is the http.Handler returned by Handler. It wraps the user's handler
// and holds the settings that HandlerOpts modify.
type mux struct {
	// base is the user's handler; ServeHTTP forwards each scrubbed request to it.
	base      http.Handler
	// logger, when non-nil, receives chitin's per-request messages via Printf.
	logger    *log.Logger
	// masterCtx is the context every per-request context is derived from.
	masterCtx context.Context
}

// ServeHTTP implements http.Handler. For each request it:
//
//   - builds a per-request context from srv.masterCtx. The context holds
//     the arrival time, an optional recorder backed by srv.logger, and the
//     trace values read from r.Header. It is cancelled when ServeHTTP
//     returns, or when the client disconnects if w implements
//     http.CloseNotifier.
//   - passes r through scrubRequest. If scrubRequest rejects r, it replies
//     with the error's status and does not call the user handler.
//   - calls the user handler with the scrubbed request and a responseTracer
//     wrapping w.
//   - runs a deferred function that writes a 200 status if the handler
//     returned normally without writing one, and logs the request duration
//     and response byte count.
//
// If the user handler does not return normally, the deferred function also
// logs the recovered panic value and the current goroutine's stack, then
// re-panics with the same value.
func (srv *mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// midRequest is true only while the user handler is executing. If it is
	// still true when the deferred function below runs, the handler did not
	// return normally.
	var midRequest bool

	// Prepare the context
	ctx := markArrivalTime(srv.masterCtx, time.Now())
	if errlog := srv.logger; errlog != nil {
		ctx = setRecorder(ctx, loggerFunc(errlog.Printf))
	}
	ctx = trace.ContextFromHeader(ctx, r.Header)
	ctx, cancel := context.WithCancel(ctx)
	// Cancelling on return also stops the watchClose goroutine started below.
	defer cancel()

	// The context is ready. DO NOT add more values for internal use, some
	// consumers will miss them.

	// Filter and instrument the incoming request
	blessedReq, err := scrubRequest(ctx, r)
	if err != nil {
		printf(ctx, "scrub-request=%q", err)
		// Written straight to w: no responseTracer exists yet, so this reply
		// does not go through chitin's WriteHeader instrumentation.
		w.WriteHeader(getErrorStatus(err))
		return
	}

	// Filter and instrument the outgoing response
	blessedResp := &responseTracer{
		rw:  w,
		ctx: ctx,
	}

	if cn, ok := w.(http.CloseNotifier); ok {
		go watchClose(ctx, cancel, cn.CloseNotify())
	}

	trace.SendRequestHeadReceived(ctx, r)
	printf(ctx, "trace=start")

	defer func() {
		var problem interface{}
		if midRequest {
			// recover() only works when called directly from this deferred
			// function, so it must stay here.
			//
			// NOTE(review): if the handler exited via runtime.Goexit instead
			// of panicking, recover() returns nil here. The code below then
			// logs a nil panic and calls panic(nil), turning the Goexit into
			// a panic. Confirm whether that is intended.
			problem = recover()
			printf(ctx, "remote-addr=%q panic=%q",
				blessedReq.RemoteAddr,
				fmt.Sprint(problem))

			// Capture up to 64 KiB of this goroutine's stack (runtime.Stack
			// truncates to len(buf)). Log it one line per record.
			buf := make([]byte, 64<<10)
			stack := string(buf[:runtime.Stack(buf, false)])
			for _, line := range strings.Split(stack, "\n") {
				printf(ctx, "panic-stack=%q", line)
			}
		}

		// The handler returned without writing a status. Write the implicit
		// 200 through blessedResp so that it is recorded and traced.
		if !midRequest && !blessedResp.wroteHeader {
			blessedResp.WriteHeader(http.StatusOK)
		}

		end := time.Now()
		duration := end.Sub(getArrivalTime(ctx))

		// The trace message is sent when we write the header.
		printf(ctx, "trace=end duration=%s bytes=%d",
			micros(duration),
			blessedResp.bytes)

		if midRequest {
			// Re-panic with the original value. The panic and its stack have
			// already been logged through chitin's recorder above. The panic
			// must still propagate for two reasons:
			//   - net/http disconnects the client; we cannot reach the
			//     net.Conn from here.
			//   - users who are not watching chitin's logger, or who look for
			//     the usual panic entry, still find one in the log they read.
			// panic(nil) would also force a disconnect. Users found losing
			// the original value surprising, so it is kept.
			panic(problem)
		}
	}()

	// Show time! Call the user code.
	midRequest = true
	srv.base.ServeHTTP(blessedResp, blessedReq)
	midRequest = false
}

// watchClose blocks until either ctx is done or ch yields a value. In the
// second case it logs the closed connection and calls cancel, unless ctx
// has already ended.
//
// ServeHTTP cancels ctx when it returns, which bounds this goroutine's
// lifetime.
func watchClose(ctx context.Context, cancel context.CancelFunc, ch <-chan bool) {
	select {
	case <-ch:
		if ctx.Err() != nil {
			return
		}
		printf(ctx, "info=%q", "http connection closed")
		cancel()
	case <-ctx.Done():
	}
}

// httpError pairs an error with the HTTP status code that should be sent to
// the client when that error aborts request processing.
type httpError struct {
	status int   // HTTP status code to reply with
	err    error // underlying cause, included in Error()
}

// Error renders the underlying cause followed by the status code.
func (e *httpError) Error() string {
	return fmt.Sprintf("%v (status %d)", e.err, e.status)
}

// getErrorStatus returns the status carried by err when it is an
// *httpError. For any other error, including nil, it returns
// http.StatusInternalServerError.
func getErrorStatus(err error) int {
	switch e := err.(type) {
	case *httpError:
		return e.status
	default:
		return http.StatusInternalServerError
	}
}

// scrubRequest builds the *http.Request that chitin hands to user code.
//
// It obtains dst and a derived ctx from requestWithContext, then:
//   - rejects any method outside a fixed allow-list, returning an *httpError
//     with http.StatusMethodNotAllowed;
//   - copies an explicit set of fields from src;
//   - rebuilds the Header map without the names in chitin's reserved
//     namespace, recording an event for each one dropped;
//   - wraps the body so that reads and Close are instrumented.
//
// This function does not itself copy Trailer (unsupported), nor Form,
// PostForm or MultipartForm (those are populated by user calls). It panics
// if src.Header holds a non-canonical key, which net/http promises never to
// produce.
func scrubRequest(ctx context.Context, src *http.Request) (*http.Request, error) {
	dst, ctx := requestWithContext(ctx, src)

	// Method: only a fixed set of methods is let through.
	switch src.Method {
	case "GET", "POST", "PUT", "HEAD", "DELETE", "PATCH", "OPTIONS":
	default:
		return nil, &httpError{
			err:    fmt.Errorf("bad method: %q", src.Method),
			status: http.StatusMethodNotAllowed,
		}
	}
	dst.Method = src.Method

	// Fields passed through verbatim.
	dst.URL = src.URL
	dst.Proto, dst.ProtoMajor, dst.ProtoMinor = src.Proto, src.ProtoMajor, src.ProtoMinor
	dst.ContentLength = src.ContentLength
	dst.TransferEncoding = src.TransferEncoding
	dst.Close = src.Close
	dst.Host = src.Host

	// haproxy's PROXY protocol could be supported by a custom net.Listener
	// whose net.Conn decodes the remote address in a separate goroutine
	// before Accept() returns it. Alternatively, the net.Conn could report a
	// unique string as its remote address. That would let http requests be
	// linked to the connections they arrived on.
	dst.RemoteAddr = src.RemoteAddr

	// If the URL is ever rewritten, RequestURI may need to follow it.
	dst.RequestURI = src.RequestURI

	// If traffic on our network becomes encrypted, access to this field may
	// need to be filtered.
	dst.TLS = src.TLS

	// Header: rebuilt without chitin's reserved names.
	dst.Header = make(http.Header, len(src.Header))
	for name, values := range src.Header {
		if name != http.CanonicalHeaderKey(name) {
			panic("net/http promises canonicalized headers")
		}
		if reservedHeader(name) {
			record(ctx, multiEvent{kind: serverRequestBadHeader, header: name})
			continue
		}
		dst.Header[name] = values
	}

	// Body: replaced with an instrumented wrapper around the original.
	dst.Body = &body{rc: src.Body, ctx: ctx}

	return dst, nil
}

// body wraps an inbound request body so that chitin can record when it
// reaches EOF and when it is closed.
type body struct {
	// rc is the original request body; reads and Close are forwarded to it.
	rc    io.ReadCloser
	// ctx is the per-request context used for recording events.
	ctx   context.Context
	// bytes is the running total of bytes returned by Read.
	bytes int64
}

// Read forwards to the underlying body and adds the bytes read to the
// running total. Each time the underlying reader reports io.EOF, it records
// a serverReadEOF event with the time since arrival and the total byte
// count.
func (b *body) Read(p []byte) (int, error) {
	n, err := b.rc.Read(p)
	b.bytes += int64(n)
	if err == io.EOF {
		elapsed := time.Now().Sub(getArrivalTime(b.ctx))
		record(b.ctx, multiEvent{kind: serverReadEOF, duration: elapsed, bytes: b.bytes})
	}
	return n, err
}

// Close records a serverRequestBodyClose event with the time since arrival,
// then closes the underlying body and returns its error.
func (b *body) Close() error {
	elapsed := time.Now().Sub(getArrivalTime(b.ctx))
	record(b.ctx, multiEvent{kind: serverRequestBodyClose, duration: elapsed})
	return b.rc.Close()
}

// responseTracer is the http.ResponseWriter handed to user code. User header
// changes go into a separate map. WriteHeader copies that map, minus
// reserved names, to the underlying writer. The type also counts body bytes
// written and records trace events.
type responseTracer struct {
	// rw is the underlying ResponseWriter supplied by net/http.
	rw          http.ResponseWriter
	// ctx is the per-request context used for recording and tracing.
	ctx         context.Context
	// header is the map returned by Header(); allocated lazily.
	header      http.Header
	// wroteHeader is set by the first WriteHeader call.
	wroteHeader bool
	// status is the status code passed to that first WriteHeader call.
	status      int
	// bytes is the running total of bytes accepted by rw.Write.
	bytes       int64
	// cn is the channel returned by CloseNotify; allocated on first use.
	cn          chan bool
}

// Compile-time checks that *responseTracer provides the ResponseWriter
// methods and the optional interfaces callers may probe for.
var _ http.ResponseWriter = (*responseTracer)(nil)
var _ http.CloseNotifier = (*responseTracer)(nil)
var _ http.Flusher = (*responseTracer)(nil)

// Header returns the header map that user code edits. It is allocated on
// first use. Its contents reach the underlying ResponseWriter only when
// WriteHeader runs.
func (tr *responseTracer) Header() http.Header {
	if tr.header != nil {
		return tr.header
	}
	tr.header = make(http.Header)
	return tr.header
}

// WriteHeader sends the response status, at most once.
//
// Repeat calls are logged and otherwise ignored. On the first call it:
//   - remembers status;
//   - copies the user's headers to the underlying writer under their
//     canonical names, skipping and recording reserved ones;
//   - records a serverWriteHeader event with the time since arrival;
//   - emits the trace "response head prepared" event;
//   - forwards the status to the underlying writer.
func (tr *responseTracer) WriteHeader(status int) {
	if tr.wroteHeader {
		printf(tr.ctx, "err=%q", "WriteHeader called multiple times")
		return
	}
	tr.wroteHeader = true
	tr.status = status

	// Only touch the underlying header map if the user set any headers.
	if tr.header != nil {
		out := tr.rw.Header()
		for key, vals := range tr.header {
			key = http.CanonicalHeaderKey(key)
			if reservedHeader(key) {
				record(tr.ctx, multiEvent{kind: serverResponseBadHeader, header: key})
				continue
			}
			for _, val := range vals {
				out.Add(key, val)
			}
		}
	}

	elapsed := time.Now().Sub(getArrivalTime(tr.ctx))
	record(tr.ctx, multiEvent{kind: serverWriteHeader, status: status, duration: elapsed})

	trace.SendResponseHeadPrepared(tr.ctx, status)
	tr.rw.WriteHeader(status)
}

// Write sends an implicit 200 status if none has been written yet. It then
// forwards p to the underlying writer and adds the bytes written to the
// running total.
func (tr *responseTracer) Write(p []byte) (n int, err error) {
	if !tr.wroteHeader {
		tr.WriteHeader(http.StatusOK)
	}
	n, err = tr.rw.Write(p)
	tr.bytes += int64(n)
	return n, err
}

// CloseNotify returns a channel that receives true once the request context
// is done. The channel and its watcher goroutine are created on the first
// call; later calls return the same channel. No locking is performed.
func (tr *responseTracer) CloseNotify() <-chan bool {
	if tr.cn != nil {
		return tr.cn
	}
	tr.cn = make(chan bool, 1)
	go tr.watchDone()
	return tr.cn
}

// watchDone waits for the request context to finish and then sends true on
// tr.cn. CloseNotify allocates tr.cn with a buffer of one, so this send
// does not block.
func (tr *responseTracer) watchDone() {
	done := tr.ctx.Done()
	<-done
	tr.cn <- true
}

// Flush sends an implicit 200 status if none has been written yet. It then
// flushes the underlying writer if that writer implements http.Flusher.
func (tr *responseTracer) Flush() {
	if !tr.wroteHeader {
		tr.WriteHeader(http.StatusOK)
	}
	fl, ok := tr.rw.(http.Flusher)
	if !ok {
		return
	}
	fl.Flush()
}

// cancelled is a context that has already been cancelled by the time init
// finishes. It is not referenced elsewhere in this file; presumably other
// code in the package uses it.
var cancelled context.Context

func init() {
	ctx, cancelFn := context.WithCancel(context.Background())
	cancelFn()
	cancelled = ctx
}
