package header_test

import (
	"bytes"
	"context"
	"net/http"
	"net/http/httptest"
	"testing"

	"code.justin.tv/feeds/following-service/header"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/suite"
)

// baseHeaders are the supported request headers used as fixtures: the suite
// expects header.WithHeaders to expose each of them through the request
// context, and GetHeaderFromContext to read each of them back.
var baseHeaders = map[string]string{
	header.DeviceIDHeader: "device12345",
	header.ClientIDHeader: "client12",
}

// HeaderSuite groups the tests for the header package.
type HeaderSuite struct {
	suite.Suite

	// request is replaced by SetupTest before each test, so header values
	// set by one test do not carry over into the next.
	request *http.Request
}

// TestHeaderSuite is the `go test` entry point; it runs every HeaderSuite test.
func TestHeaderSuite(t *testing.T) {
	suite.Run(t, &HeaderSuite{})
}

// SetupTest builds a fresh, empty-bodied POST request before each test.
// A construction failure aborts the test immediately. Continuing would leave
// s.request nil and turn every later use into a nil-pointer panic that hides
// the real error.
func (s *HeaderSuite) SetupTest() {
	r, err := http.NewRequest(http.MethodPost, "http://localhost/", bytes.NewBuffer(nil))
	s.Require().NoError(err)

	s.request = r
}

// assertPresentHeaders sets headers on r, serves r through header.WithHeaders,
// and asserts that the wrapped handler sees every header value in the request
// context, keyed by the header name.
func assertPresentHeaders(t *testing.T, r *http.Request, headers map[string]string) {
	t.Helper()

	called := false
	assertHeaders(
		func(w http.ResponseWriter, req *http.Request) {
			called = true
			for k, v := range headers {
				// Expected value first, per testify's (expected, actual) order.
				// Comparing against the raw context value (no .(string)
				// assertion) makes a missing key fail cleanly instead of
				// panicking.
				assert.Equal(t, v, req.Context().Value(k), "context value for header %q", k)
			}
		}, r, headers,
	)

	// If the middleware never invoked the wrapped handler, the loop above
	// never ran and the check would pass vacuously.
	assert.True(t, called, "wrapped handler was not invoked")
}

// assertMissingHeaders sets headers on r, serves r through header.WithHeaders,
// and asserts that none of the header names is present as a key in the
// request context seen by the wrapped handler.
func assertMissingHeaders(t *testing.T, r *http.Request, headers map[string]string) {
	t.Helper()

	called := false
	assertHeaders(
		func(w http.ResponseWriter, req *http.Request) {
			called = true
			for k := range headers {
				assert.Nil(t, req.Context().Value(k), "unexpected context value for header %q", k)
			}
		}, r, headers,
	)

	// If the middleware never invoked the wrapped handler, "missing" would be
	// asserted vacuously.
	assert.True(t, called, "wrapped handler was not invoked")
}

// assertHeaders copies headers onto r's header map and serves r through the
// header.WithHeaders middleware, with testHandler as the wrapped handler.
// The response goes to a throwaway recorder and is not inspected; the
// assertions live in testHandler.
func assertHeaders(testHandler http.HandlerFunc, r *http.Request, headers map[string]string) {
	for name, value := range headers {
		r.Header.Set(name, value)
	}

	wrapped := header.WithHeaders(testHandler)
	recorder := httptest.NewRecorder()
	wrapped.ServeHTTP(recorder, r)
}

// TestAddingSupportedHeaderInContext checks that each supported header sent on
// the request is exposed through the request context.
func (s *HeaderSuite) TestAddingSupportedHeaderInContext() {
	t := s.T()
	assertPresentHeaders(t, s.request, baseHeaders)
}

// TestAddingRandomHeaderNotInContext checks that an unrecognized header sent on
// the request is not copied into the request context.
func (s *HeaderSuite) TestAddingRandomHeaderNotInContext() {
	const unknownHeader = "random-header-x"

	assertMissingHeaders(s.T(), s.request, map[string]string{unknownHeader: "random-value"})
}

// TestMissingHeader checks that looking up a header that was never stored in
// the context yields an empty result.
func (s *HeaderSuite) TestMissingHeader() {
	ctx := context.Background()
	s.Empty(header.GetHeaderFromContext(ctx, "random-header-x"))
}

// TestPresentHeader checks that GetHeaderFromContext returns the value stored
// in the context under each header name.
func (s *HeaderSuite) TestPresentHeader() {
	ctx := context.Background()
	for k, v := range baseHeaders {
		// NOTE(review): plain string keys are flagged by staticcheck (SA1029).
		// Presumably they mirror how the header package stores values, since
		// lookups use the header name. Confirm before changing the key type.
		ctx = context.WithValue(ctx, k, v)
	}

	for k, v := range baseHeaders {
		// testify takes (expected, actual); keeping that order makes failure
		// output label the two values correctly.
		assert.Equal(s.T(), v, header.GetHeaderFromContext(ctx, k), "header %q", k)
	}
}
