package instrumentation

import (
	"context"
	"net/http"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"

	"goji.io/pat"

	goji "goji.io"
)

// nullWriter is an http.ResponseWriter that discards everything written to
// it: headers, status code and body alike.
type nullWriter struct{}

// Compile-time check that *nullWriter satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*nullWriter)(nil)

// Header returns a brand-new, empty header map on every call, so any header a
// handler sets is thrown away.
func (*nullWriter) Header() http.Header {
	h := http.Header{}
	return h
}

// Write claims the entire payload was written without storing any of it.
func (*nullWriter) Write(p []byte) (int, error) {
	n := len(p)
	return n, nil
}

// WriteHeader ignores the status code.
func (*nullWriter) WriteHeader(status int) {}

// testInstrumentor is a stub ContextInstrumenter. It remembers the name and
// status of the most recent measurement and calls wg.Done once per
// measurement. A test can therefore call wg.Add(1) before issuing a request
// and wg.Wait() until that request has been recorded.
type testInstrumentor struct {
	lastName   string
	lastStatus int
	wg         sync.WaitGroup
}

// Instrument stores name and status, then releases one pending wg.Add.
// The context, call duration and sample rate are ignored. The fields are
// assigned before Done so that a goroutine returning from wg.Wait observes them.
func (ti *testInstrumentor) Instrument(ctx context.Context, name string, status int, callTime time.Duration, rate float32) {
	ti.lastName, ti.lastStatus = name, status
	ti.wg.Done()
}

// returnNoContent replies with 204 No Content and writes no body.
func returnNoContent(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusNoContent)
}

// returnStatusConflict replies with 409 Conflict and writes no body.
func returnStatusConflict(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusConflict)
}

// causePanic always panics without writing a status. It is used to exercise
// the instrumentation's behavior when a handler panics.
func causePanic(w http.ResponseWriter, r *http.Request) {
	panic("verify recording")
}

// returnExpectationFailed replies with 417 Expectation Failed. It is declared
// as an http.HandlerFunc value (an http.Handler) rather than a plain func, and
// TestInstrument passes it to InstrumentHandler.
var returnExpectationFailed = http.HandlerFunc(
	func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusExpectationFailed)
	},
)

// TestInstrument checks that the instrumentation middleware reports the
// expected route name and response status for each kind of route registered
// on a goji mux:
//   - routes wrapped with InstrumentHandlerFunc / InstrumentHandler;
//   - a handler that panics;
//   - a plain route that falls back to the middleware's default name.
func TestInstrument(t *testing.T) {
	// sample Mux
	inst := &testInstrumentor{}
	dynamic := func(context.Context) []ContextInstrumenter {
		return []ContextInstrumenter{inst}
	}

	collection := NewInstrumentor(map[Frequency]float32{}, dynamic)

	mux := goji.NewMux()
	mux.Use(collection.Middleware("unknown", Frequent))
	mux.Handle(pat.Get("/override"), InstrumentHandlerFunc("override", Frequent, returnNoContent))
	mux.Handle(pat.Get("/custom"), InstrumentHandler("custom", Rare, returnExpectationFailed))
	mux.HandleFunc(pat.Get("/"), returnStatusConflict)
	mux.Handle(pat.Get("/panic"), InstrumentHandlerFunc("panic", Rare, causePanic))

	// waitForRecord blocks until inst has received the pending measurement.
	// If no measurement arrives, it fails the test instead of hanging forever.
	waitForRecord := func(path string) {
		t.Helper()
		done := make(chan struct{})
		go func() {
			inst.wg.Wait()
			close(done)
		}()
		select {
		case <-done:
		case <-time.After(5 * time.Second):
			t.Fatalf("timed out waiting for instrumentation of %s", path)
		}
	}

	cases := []struct {
		path       string
		wantName   string
		wantStatus int
		panics     bool
	}{
		// InstrumentHandlerFunc names the measurement after its own label.
		{path: "/override", wantName: "override", wantStatus: http.StatusNoContent},
		// InstrumentHandler does the same for an http.Handler.
		{path: "/custom", wantName: "custom", wantStatus: http.StatusExpectationFailed},
		// The panic propagates out of ServeHTTP but is still recorded. The
		// recorded status is 0; causePanic never writes one.
		{path: "/panic", wantName: "panic", wantStatus: 0, panics: true},
		// An unwrapped route falls back to the middleware's default name.
		{path: "/", wantName: "unknown", wantStatus: http.StatusConflict},
	}

	// Cases run sequentially: they share mux and inst.
	for _, tc := range cases {
		r, err := http.NewRequest(http.MethodGet, tc.path, nil)
		if err != nil {
			t.Fatalf("building request for %s: %v", tc.path, err)
		}

		inst.wg.Add(1)
		if tc.panics {
			assert.Panics(t, func() { mux.ServeHTTP(&nullWriter{}, r) }, "route %s", tc.path)
		} else {
			mux.ServeHTTP(&nullWriter{}, r)
		}
		waitForRecord(tc.path)

		assert.Equal(t, tc.wantName, inst.lastName, "route %s", tc.path)
		assert.Equal(t, tc.wantStatus, inst.lastStatus, "route %s", tc.path)
	}
}
