package xarth

import (
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"net/http/httptest"
	"os"
	"sync"
	"testing"
	"time"
)

// TestReservedHeader pins which header names reservedHeader reports as
// reserved: the bare prefix and names whose first '-'-separated token is the
// prefix, but not lowercase (un-normalized) spellings, names that merely
// start with the same letters, or unrelated headers.
func TestReservedHeader(t *testing.T) {
	cases := []struct {
		desc     string
		name     string
		reserved bool
	}{
		{name: "chitin", reserved: false, desc: "Headers must be normalized before calling reservedHeader"},
		{name: "Chitin", reserved: true, desc: "The prefix is reserved"},
		{name: "Chitin-Fu", reserved: true, desc: "Headers beginning with the prefix are reserved"},
		{name: "Chitintacular", reserved: false, desc: "The prefix check tokenizes on '-'"},
		{name: "Trace-Span", reserved: false, desc: "We don't reserve every header we use"},
	}

	for _, tc := range cases {
		got := reservedHeader(tc.name)
		if got == tc.reserved {
			continue
		}
		t.Errorf("reservedHeader(%q); %t != %t (%s)", tc.name, got, tc.reserved, tc.desc)
	}
}

// TestContextRetrieval checks three things for a handler served through the
// wrapped server:
//   - it can retrieve its context with Context;
//   - that context has not been cancelled while the handler is still running;
//   - the arrival time recorded on it is no later than now and at most one
//     second earlier.
func TestContextRetrieval(t *testing.T) {
	resp, err := testGet(t, func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "hi")

		ctx, ok := Context(w, r)
		if !ok {
			t.Errorf("could not find context")
			// NOTE(review): when ok is false, ctx is presumably nil/zero.
			// Calling ctx.Err or getArrivalTime on it could panic inside the
			// handler and bury the real failure, so stop here.
			return
		}

		if err := ctx.Err(); err != nil {
			t.Errorf("context is no longer active: %v", err)
		}

		// The arrival time must not be in the future, and anything older
		// than a second is implausible for a request that was just served.
		arrived := getArrivalTime(ctx)
		if delta := time.Since(arrived); delta < 0 || delta > time.Second {
			t.Errorf("improbable delta: %v", delta)
		}
	}, "/")

	if err != nil {
		t.Errorf("unexpected error: %v", err)
		return
	}
	defer resp.Body.Close()
}

// TestPanic exercises handlers that panic.
//
// A panic before anything has been written, whether with a value or with
// nil, is expected to reach the client as a request error. A panic after the
// handler has written and flushed a 200 status is allowed to look like a
// successful response. The test stops at the first case that misbehaves.
func TestPanic(t *testing.T) {
	earlyPanics := []struct {
		handler http.HandlerFunc
		msg     string
	}{
		{
			handler: func(w http.ResponseWriter, r *http.Request) { panic("derp") },
			msg:     "panic in a handler should lead to an error on the client",
		},
		{
			handler: func(w http.ResponseWriter, r *http.Request) { panic(nil) },
			msg:     "nil panic in a handler should lead to an error on the client",
		},
	}

	for _, tc := range earlyPanics {
		resp, err := testGet(t, tc.handler, "/")
		if err != nil {
			continue // expected outcome
		}
		t.Errorf("%s", tc.msg)
		resp.Body.Close()
		return
	}

	flushThenPanic := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.(http.Flusher).Flush()
		panic("derp")
	}
	resp, err := testGet(t, flushThenPanic, "/")
	if err != nil {
		t.Errorf("panic in a handler that has flushed its response need not "+
			"result in client error ... but it did: %v", err)
		return
	}
	resp.Body.Close()
}

// TestStatusCode checks that status codes written by handlers reach the
// client unchanged through the wrapped server. It also checks that requests
// falling through to http.DefaultServeMux receive the expected codes for the
// /debug paths listed below.
func TestStatusCode(t *testing.T) {
	// respondWith returns a handler that only writes the given status code.
	respondWith := func(code int) func(w http.ResponseWriter, r *http.Request) {
		return func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(code)
		}
	}

	mux := http.NewServeMux()
	// Anything not matched by a route below is delegated to the default mux;
	// the /debug cases rely on what is registered there.
	mux.Handle("/", http.DefaultServeMux)

	routes := []struct {
		path string
		code int
	}{
		{"/notfound", http.StatusNotFound},
		{"/ok", http.StatusOK},
		{"/error", http.StatusInternalServerError},
		{"/teapot", http.StatusTeapot},
	}
	for _, rt := range routes {
		mux.HandleFunc(rt.path, respondWith(rt.code))
	}

	cases := []struct {
		url    string
		status int
	}{
		{url: "/notfound", status: http.StatusNotFound},
		{url: "/ok", status: http.StatusOK},
		{url: "/error", status: http.StatusInternalServerError},
		{url: "/teapot", status: http.StatusTeapot},

		{url: "/debug", status: http.StatusNotFound},
		{url: "/debug/", status: http.StatusNotFound},
		{url: "/debug/vars/", status: http.StatusNotFound},

		{url: "/debug/vars", status: http.StatusOK},
		{url: "/debug/pprof", status: http.StatusOK},
		{url: "/debug/pprof/", status: http.StatusOK},
	}

	for _, c := range cases {
		resp, err := testGet(t, mux.ServeHTTP, c.url)
		if err != nil {
			t.Errorf("unexpected error: %v", err)
			continue
		}
		resp.Body.Close()

		if got := resp.StatusCode; got != c.status {
			t.Errorf("testGet(%q) status; %d != %d", c.url, got, c.status)
		}
	}
}

// testGet serves handlerFunc from a test server and issues one GET for url.
//
// The server's config and listener are replaced with the results of
// wrapServer and wrapListener before it starts. url is appended to the
// server's base URL, so it should be a path such as "/".
//
// Every handler invocation is tracked with a WaitGroup. testGet waits for
// all of them, on both the success and the error path, before returning.
// Any t.Errorf calls made inside handlerFunc are therefore recorded by then.
//
// The server is closed before testGet returns. Callers may inspect the status
// and headers of the response, but reading resp.Body is not reliable. When
// err is nil, callers must still close resp.Body. t is not used by the
// current implementation.
func testGet(t *testing.T, handlerFunc http.HandlerFunc, url string) (*http.Response, error) {
	var wg sync.WaitGroup

	srv := httptest.NewUnstartedServer(http.HandlerFunc(func(
		w http.ResponseWriter, r *http.Request) {
		wg.Add(1)
		defer wg.Done()
		handlerFunc(w, r)
	}))

	srv.Config = wrapServer(srv.Config)
	srv.Listener = wrapListener(srv.Listener)

	srv.Start()
	defer srv.Close()

	// A bounded client turns a wrapper that never answers into a test error
	// rather than a hang that lasts until the go test -timeout fires.
	client := &http.Client{Timeout: 10 * time.Second}

	// Hold the group open across the request. The handler-side Add calls
	// then always happen while the counter is positive, which is what
	// sync.WaitGroup requires of an Add that races with Wait.
	wg.Add(1)
	resp, err := client.Get(srv.URL + url)
	wg.Done()

	// Wait even when the client saw an error (e.g. a panicking handler), so
	// failures recorded inside the handler land before the caller continues.
	wg.Wait()
	if err != nil {
		return nil, err
	}
	return resp, nil
}

// BenchmarkStdlibServer measures request round trips against an unwrapped
// test server whose handler does nothing.
func BenchmarkStdlibServer(b *testing.B) {
	noop := func(w http.ResponseWriter, r *http.Request) {}
	benchStdlibServer(b, noop)
}

// BenchmarkStdlibServerHello measures request round trips against an
// unwrapped test server whose handler writes a 200 status and a short body
// that echoes the request URL.
func BenchmarkStdlibServerHello(b *testing.B) {
	hello := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		fmt.Fprintf(w, "hello %q", r.URL.String())
	}
	benchStdlibServer(b, hello)
}

// BenchmarkChitinServer measures request round trips against a test server
// whose config and listener are wrapped, using a handler that does nothing.
func BenchmarkChitinServer(b *testing.B) {
	noop := func(w http.ResponseWriter, r *http.Request) {}
	benchChitinServer(b, noop)
}

// BenchmarkChitinServerHello measures request round trips against the
// wrapped test server. Its handler writes a 200 status and a short body that
// echoes the request URL.
func BenchmarkChitinServerHello(b *testing.B) {
	hello := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		fmt.Fprintf(w, "hello %q", r.URL.String())
	}
	benchChitinServer(b, hello)
}

func BenchmarkChitinServerContextLookup(b *testing.B) {
	benchChitinServer(b, func(w http.ResponseWriter, r *http.Request) {
		Context(w, r)
	})
}

// benchStdlibServer benchmarks fn behind a plain, unwrapped httptest server.
func benchStdlibServer(b *testing.B, fn http.HandlerFunc) {
	benchServer(b, httptest.NewUnstartedServer(fn))
}

// benchChitinServer benchmarks fn behind an httptest server whose config and
// listener have been passed through wrapServer and wrapListener.
func benchChitinServer(b *testing.B, fn http.HandlerFunc) {
	srv := httptest.NewUnstartedServer(fn)
	// Both wrappers must be installed before benchServer starts the server.
	// The calls run left to right: wrapServer first, then wrapListener.
	cfg, ln := wrapServer(srv.Config), wrapListener(srv.Listener)
	srv.Config, srv.Listener = cfg, ln
	benchServer(b, srv)
}

// benchServer starts srv and issues b.N sequential GET requests to its base
// URL. Each response body is drained and closed so the client transport can
// reuse the connection.
//
// Server startup happens before the benchmark timer is reset, so only the
// request loop is measured. The standard logger's output is discarded for
// the duration and pointed back at os.Stderr afterwards. A request error
// aborts the benchmark.
func benchServer(b *testing.B, srv *httptest.Server) {
	log.SetOutput(ioutil.Discard)
	defer log.SetOutput(os.Stderr)

	srv.Start()
	defer srv.Close()

	// Keep server startup out of the measurement; it would otherwise
	// dominate the small-N calibration runs.
	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		resp, err := http.Get(srv.URL)
		if err != nil {
			b.Fatalf("request error: %v", err)
		}
		// Best-effort drain: a failed read only costs connection reuse,
		// which is not what this loop is checking.
		io.Copy(ioutil.Discard, resp.Body)
		resp.Body.Close()
	}
}
