package twitchclient

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"

	"code.justin.tv/chat/timing"
	"code.justin.tv/foundation/twitchclient/mocks"
)

// key is an unexported context-key type, so values stored under it cannot
// collide with context keys declared in other packages.
type key int

// testKey is the context key TestDo uses to check that values on the ctx
// passed to Do are visible on the outgoing request's context.
var testKey *key = new(key)

func TestNewClient_RequiresHost(t *testing.T) {
	_, err := NewClient(ClientConf{})
	assert.Error(t, err)
}

// TestApplyDefaults checks how applyDefaults normalizes each transport
// timeout: unset (zero) values get the package default, positive values are
// kept, and negative values become zero (which disables the timeout).
func TestApplyDefaults(t *testing.T) {
	t.Parallel()

	const (
		positive = 1 * time.Second
		negative = -1 * time.Second
		zero     = 0 * time.Second
	)
	t.Run("IdleConnTimeout", func(t *testing.T) {
		t.Run(fmt.Sprintf("defaults to %s when unset", defaultIdleConnTimeout), func(t *testing.T) {
			conf := ClientConf{}
			applyDefaults(&conf)
			assert.Equal(t, defaultIdleConnTimeout, conf.Transport.IdleConnTimeout)
		})

		t.Run("keeps positive values", func(t *testing.T) {
			conf := ClientConf{Transport: TransportConf{IdleConnTimeout: positive}}
			applyDefaults(&conf)
			assert.Equal(t, positive, conf.Transport.IdleConnTimeout)
		})

		t.Run("coerces negative values to zero to disable timeout", func(t *testing.T) {
			conf := ClientConf{Transport: TransportConf{IdleConnTimeout: negative}}
			applyDefaults(&conf)
			assert.Equal(t, zero, conf.Transport.IdleConnTimeout)
		})
	})

	t.Run("TLSHandshakeTimeout", func(t *testing.T) {
		t.Run(fmt.Sprintf("defaults to %s when unset", defaultTLSHandshakeTimeout), func(t *testing.T) {
			conf := ClientConf{}
			applyDefaults(&conf)
			assert.Equal(t, defaultTLSHandshakeTimeout, conf.Transport.TLSHandshakeTimeout)
		})

		t.Run("keeps positive values", func(t *testing.T) {
			conf := ClientConf{Transport: TransportConf{TLSHandshakeTimeout: positive}}
			applyDefaults(&conf)
			assert.Equal(t, positive, conf.Transport.TLSHandshakeTimeout)
		})

		t.Run("coerces negative values to zero to disable timeout", func(t *testing.T) {
			conf := ClientConf{Transport: TransportConf{TLSHandshakeTimeout: negative}}
			applyDefaults(&conf)
			assert.Equal(t, zero, conf.Transport.TLSHandshakeTimeout)
		})
	})

	t.Run("ExpectContinueTimeout", func(t *testing.T) {
		// The subtest name must report the default actually being checked.
		t.Run(fmt.Sprintf("defaults to %s when unset", defaultExpectContinueTimeout), func(t *testing.T) {
			conf := ClientConf{}
			applyDefaults(&conf)
			assert.Equal(t, defaultExpectContinueTimeout, conf.Transport.ExpectContinueTimeout)
		})

		t.Run("keeps positive values", func(t *testing.T) {
			conf := ClientConf{Transport: TransportConf{ExpectContinueTimeout: positive}}
			applyDefaults(&conf)
			assert.Equal(t, positive, conf.Transport.ExpectContinueTimeout)
		})

		t.Run("coerces negative values to zero to disable timeout", func(t *testing.T) {
			conf := ClientConf{Transport: TransportConf{ExpectContinueTimeout: negative}}
			applyDefaults(&conf)
			assert.Equal(t, zero, conf.Transport.ExpectContinueTimeout)
		})
	})
}

// TestNewRequest checks that NewRequest builds a plain-http URL against the
// configured host and leaves the body nil when none is supplied.
func TestNewRequest(t *testing.T) {
	client, err := NewClient(ClientConf{Host: "localhost"})
	if !assert.NoError(t, err) {
		return // client is unusable; continuing would only panic
	}

	req, err := client.NewRequest("GET", "/path", nil)
	if !assert.NoError(t, err) {
		return
	}

	// testify takes (expected, actual); keeping that order makes failure
	// messages label the two values correctly.
	assert.Equal(t, "GET", req.Method)
	assert.Equal(t, "http", req.URL.Scheme)
	assert.Equal(t, "localhost", req.URL.Host)
	assert.Equal(t, "/path", req.URL.Path)
	assert.Nil(t, req.Body)
}

// TestDo exercises Client.Do against a stubbed RoundTripper. It checks that the
// request reaches the transport, that ReqOpts populate the standard Twitch
// headers, how a context deadline is (not) surfaced, and that values on the ctx
// passed to Do are visible on the outgoing request's context.
func TestDo(t *testing.T) {
	mockRT := newMockRT()
	client, err := NewClient(ClientConf{
		Host:          "localhost",
		BaseTransport: mockRT,
	})
	if !assert.NoError(t, err) {
		return // client is unusable; continuing would only panic
	}

	req, err := client.NewRequest("GET", "/path", nil)
	if !assert.NoError(t, err) {
		return
	}

	t.Run("calls the roundtripper", func(t *testing.T) {
		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			return &resp200OK, nil
		})

		ctx := context.Background()
		resp, err := client.Do(ctx, req, ReqOpts{})
		if assert.NoError(t, err) && assert.NotNil(t, resp) {
			assert.Equal(t, 200, resp.StatusCode)
		}
		assert.True(t, mockRT.RoundTripCalled())
		assert.Nil(t, ctx.Err())
	})

	t.Run("adds standard headers in the roundtripper", func(t *testing.T) {
		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			assert.Equal(t, "faketokendude", req.Header.Get(TwitchAuthorizationHeader))
			assert.Equal(t, "123456", req.Header.Get(TwitchClientRowIDHeader))
			assert.Equal(t, "ABC123", req.Header.Get(TwitchClientIDHeader))
			return &resp200OK, nil
		})

		ctx := context.Background()
		resp, err := client.Do(ctx, req, ReqOpts{AuthorizationToken: "faketokendude", ClientRowID: "123456", ClientID: "ABC123"})
		if assert.NoError(t, err) && assert.NotNil(t, resp) {
			assert.Equal(t, 200, resp.StatusCode)
		}
		assert.True(t, mockRT.RoundTripCalled())
		assert.Nil(t, ctx.Err())
	})

	t.Run("errors if context deadline passes", func(t *testing.T) {
		timeoutDuration := 1 * time.Microsecond
		sleepDuration := timeoutDuration + 10*time.Millisecond // make sure it has time to trigger the timeout (even if the CPU is busy)

		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			time.Sleep(sleepDuration)
			return &http.Response{}, nil
		})

		ctx, cancel := context.WithTimeout(context.Background(), timeoutDuration)
		defer cancel()
		_, err := client.Do(ctx, req, ReqOpts{})
		assert.NoError(t, err) // NOTE: maybe this should actually return an error ?
		assert.True(t, mockRT.RoundTripCalled())
		assert.Equal(t, "context deadline exceeded", ctx.Err().Error())
	})

	t.Run("set request context if not set", func(t *testing.T) {
		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			// Compare the raw interface value: a missing key yields nil and is
			// reported as a failure instead of panicking in a type assertion.
			assert.Equal(t, "testVal", req.Context().Value(testKey))
			return &resp200OK, nil
		})

		ctx := context.WithValue(context.Background(), testKey, "testVal")
		_, err := client.Do(ctx, req, ReqOpts{})
		assert.NoError(t, err)
		assert.True(t, mockRT.RoundTripCalled())
	})
}

// TestStatsRequestTimings checks the per-request timing stat that Client.Do
// records when ReqOpts.StatName is set. The stat is named
// "<StatName>.<status code>", with status 0 when the transport returns an
// error. No stat is recorded when StatName is empty.
func TestStatsRequestTimings(t *testing.T) {
	mockStats := mocks.NewStatter()
	mockRT := newMockRT()
	client, err := NewClient(ClientConf{
		Host:          "localhost",
		Stats:         mockStats,
		BaseTransport: mockRT,
	})
	if !assert.NoError(t, err) {
		return // client is unusable; continuing would only panic
	}

	req, err := client.NewRequest("GET", "/path", nil)
	if !assert.NoError(t, err) {
		return
	}

	reqOptsWithStat := ReqOpts{
		StatName:       "request_stat_name",
		StatSampleRate: 0.22,
	}

	t.Run("tracks request timings with common response status codes", func(t *testing.T) {
		respStatuses := []int{200, 300, 400, 500}
		// Range over the values: a single-variable range would yield the
		// indices 0..3 and never exercise real status codes.
		for _, status := range respStatuses {
			status := status // captured by the stub below; pre-Go 1.22 safety
			mockStats.Reset()
			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				return &http.Response{StatusCode: status}, nil
			})
			_, err := client.Do(context.Background(), req, reqOptsWithStat)
			assert.NoError(t, err)
			assert.True(t, mockRT.RoundTripCalled())
			statName := fmt.Sprintf("request_stat_name.%d", status)
			assert.Equal(t, 1, mockStats.TimingCounts(statName))
			assert.Equal(t, float32(0.22), mockStats.TimingSample(statName)) // sample rate from reqOptsWithStat
		}
	})

	t.Run("tracks request timings with response status 0 when there is an error", func(t *testing.T) {
		mockStats.Reset()
		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			return nil, errors.New("some error")
		})
		_, err := client.Do(context.Background(), req, reqOptsWithStat)
		assert.Error(t, err)
		assert.True(t, mockRT.RoundTripCalled())
		assert.Equal(t, 1, mockStats.AnyTimingCounts())
		assert.Equal(t, 1, mockStats.TimingCounts("request_stat_name.0"))
	})

	t.Run("does not track request timings if opts.StatName is empty", func(t *testing.T) {
		reqOptsWithNOStat := ReqOpts{}
		mockStats.Reset()
		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			return &resp200OK, nil
		})
		_, err := client.Do(context.Background(), req, reqOptsWithNOStat) // opts.StatName is empty
		assert.NoError(t, err)
		assert.True(t, mockRT.RoundTripCalled())
		assert.Equal(t, 0, mockStats.AnyTimingCounts())
	})
}

// TestXactTimings checks what a request made inside a timing.Xact records.
// The xact records its own total, plus one timing for the HTTP call. That
// timing is named after ClientConf.TimingXactName, or "twitchhttp" when the
// name is left empty.
func TestXactTimings(t *testing.T) {
	mockStats := mocks.NewStatter()

	// doRequest runs one stubbed GET inside an xact named "test" using a fresh
	// client configured with timingName, then ends the xact as "done".
	doRequest := func(timingName string) {
		xact := &timing.Xact{Stats: mockStats}
		xact.AddName("test")
		xact.Start()
		xactCtx := timing.XactContext(context.Background(), xact)

		rt := newMockRT()
		client, err := NewClient(ClientConf{
			Host:           "localhost",
			TimingXactName: timingName,
			BaseTransport:  rt,
		})
		assert.NoError(t, err)

		req, err := client.NewRequest("GET", "/path", nil)
		assert.NoError(t, err)

		rt.OnNextRoundTrip(func(*http.Request) (*http.Response, error) {
			return &resp200OK, nil
		})

		_, err = client.Do(xactCtx, req, ReqOpts{})
		assert.NoError(t, err)
		assert.True(t, rt.RoundTripCalled())
		xact.End("done")
	}

	cases := []struct {
		timingName string // ClientConf.TimingXactName under test
		httpStat   string // timing expected for the HTTP call
	}{
		// Without explicit TimingXactName
		{timingName: "", httpStat: "test.done.twitchhttp"},
		// With explicit TimingXactName
		{timingName: "someclient", httpStat: "test.done.someclient"},
	}
	for _, c := range cases {
		mockStats.Reset()
		doRequest(c.timingName)
		assert.Equal(t, 2, mockStats.AnyTimingCounts())
		assert.Equal(t, 1, mockStats.TimingCounts("test.done.total"))
		assert.Equal(t, 1, mockStats.TimingCounts(c.httpStat))
	}
}

// TestDoJSON covers Client.DoJSON:
//   - decoding of 2xx JSON bodies
//   - closing the response body
//   - skipping decoding on 204
//   - propagation of transport errors
//   - the error values produced for 4xx and 5xx responses
//
// Failed type assertions and nil responses are reported as test failures
// rather than dereferenced, so one broken case cannot panic the whole run.
func TestDoJSON(t *testing.T) {
	mockRT := newMockRT()
	logger := &mocks.Logger{}
	logger.On("Log", mock.Anything, mock.Anything).Return()
	client, err := NewClient(ClientConf{
		Host:          "localhost",
		BaseTransport: mockRT,
		Logger:        logger,
	})
	if !assert.NoError(t, err) {
		return // client is unusable; continuing would only panic
	}

	type RespData struct {
		Foo string `json:"foo"`
		Coo int    `json:"coo"`
	}

	ctx := context.Background()

	req, err := client.NewRequest("GET", "/path", nil)
	if !assert.NoError(t, err) {
		return
	}

	// fakeJSONResp builds a response with the given status code and body,
	// labelled with an application/json Content-Type.
	fakeJSONResp := func(status int, jsonResp string) *http.Response {
		return &http.Response{
			Status:     fmt.Sprintf("%d %s", status, http.StatusText(status)),
			StatusCode: status,
			Body:       ioutil.NopCloser(strings.NewReader(jsonResp)),
			Header:     http.Header{"Content-Type": []string{"application/json"}},
		}
	}

	// assertStatus checks resp's status code, reporting (not dereferencing) a
	// nil response.
	assertStatus := func(t *testing.T, want int, resp *http.Response) {
		if assert.NotNil(t, resp, "expected a non-nil response") {
			assert.Equal(t, want, resp.StatusCode)
		}
	}

	// assertTwitchclientError checks that err is a *Error carrying the given
	// status code and message, without dereferencing a failed type assertion.
	assertTwitchclientError := func(t *testing.T, err error, status int, msg string) {
		tcErr, ok := err.(*Error)
		if assert.True(t, ok, "should be a *twitchclient.Error") {
			assert.Equal(t, status, tcErr.StatusCode)
			assert.Equal(t, msg, tcErr.Message)
		}
	}

	t.Run("closes the body", func(t *testing.T) {
		body := &mockCloser{bytes.NewBufferString(`{}`), false}
		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			return &http.Response{Status: "200 OK", StatusCode: 200, Body: body, Header: http.Header{"Content-Type": []string{"application/json"}}}, nil
		})

		var data RespData
		_, err := client.DoJSON(ctx, &data, req, ReqOpts{})
		assert.NoError(t, err)
		assert.True(t, mockRT.RoundTripCalled())
		assert.True(t, body.closed)
	})

	t.Run("successful response", func(t *testing.T) {
		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			return fakeJSONResp(200, `{"foo": "bar", "coo": 666}`), nil
		})

		var data RespData
		resp, err := client.DoJSON(ctx, &data, req, ReqOpts{})
		assert.NoError(t, err)
		assertStatus(t, 200, resp)
		assert.Equal(t, "bar", data.Foo)
		assert.Equal(t, 666, data.Coo)
	})

	t.Run("returns parsing error if a successful 2xx response is not valid JSON", func(t *testing.T) {
		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			return fakeJSONResp(200, `{`), nil
		})

		var data RespData
		resp, err := client.DoJSON(ctx, &data, req, ReqOpts{})
		assert.EqualError(t, err, `Unable to read response body: unexpected EOF`)
		assertStatus(t, 200, resp)
		_, ok := err.(*Error)
		assert.False(t, ok, "should be a plain error, not a *twitchclient.Error")
	})

	t.Run("ignores content if response status is 204 No Content", func(t *testing.T) {
		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			return fakeJSONResp(204, `{`), nil
		})

		var data RespData
		resp, err := client.DoJSON(ctx, &data, req, ReqOpts{})
		assert.NoError(t, err)
		assertStatus(t, 204, resp)
	})

	t.Run("returns error if the http.Client returns an error", func(t *testing.T) {
		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			return nil, errors.New("entropy underflow lol")
		})

		var data RespData
		_, err := client.DoJSON(ctx, &data, req, ReqOpts{})
		expectedSuffix := "entropy underflow lol"
		// Only inspect the message when there is an error to inspect.
		if assert.Error(t, err) && !strings.HasSuffix(err.Error(), expectedSuffix) {
			t.Errorf("Did not find expected error string in http response error\ngot: %s\nwanted a suffix of: %s\n", err, expectedSuffix)
		}
	})

	t.Run("returns *twitchclient.Error with parsed Message if status is 4xx", func(t *testing.T) {
		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			return fakeJSONResp(400, `{"message":"your bad"}`), nil
		})

		var data RespData
		_, err := client.DoJSON(ctx, &data, req, ReqOpts{})
		assert.EqualError(t, err, `400: your bad`)
		assertTwitchclientError(t, err, 400, "your bad")
	})

	t.Run("returns *twitchclient.Error with full body Message if status is 4xx but Content-Type is not json", func(t *testing.T) {
		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			resp := fakeJSONResp(400, `{"message":"your bad"}`)
			resp.Header.Set("Content-Type", "text/plain")
			return resp, nil
		})

		var data RespData
		_, err := client.DoJSON(ctx, &data, req, ReqOpts{})
		assert.EqualError(t, err, `400: {"message":"your bad"}`)                // parsed as plain text
		assertTwitchclientError(t, err, 400, `{"message":"your bad"}`) // parsed as plain text
	})

	t.Run("returns *twitchclient.Error with parsing error message if 4xx response is unparsable", func(t *testing.T) {
		// NOTE: this behavior may result in security issues because it leaks implementation details.
		// The proper behaviour would be to return a plain parsing error (status 0), so the Edge (Visage)
		// can properly send it to Rollbar instead of returning it as a 4xx error to the public.

		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			return fakeJSONResp(404, `{`), nil
		})

		var data RespData
		_, err := client.DoJSON(ctx, &data, req, ReqOpts{})
		assert.EqualError(t, err, `404: Unable to decode JSON response: unexpected EOF`)
		assertTwitchclientError(t, err, 404, "Unable to decode JSON response: unexpected EOF")
	})

	t.Run("returns plain error if status is 5xx", func(t *testing.T) {
		// NOTE: this behavior makes handling 5xx errors in the Edge (Visage) very hard, which may result
		// in security issues (by mistakenly leaking internal data). This is because the error
		// message needs to be parsed in order to tell if this error is a hand-made error or a 5xx.
		// The proper behaviour would be to return a *twitchclient.Error just like with 4xx responses.

		mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
			return fakeJSONResp(500, `{"message":"my bad"}`), nil
		})

		var data RespData
		resp, err := client.DoJSON(ctx, &data, req, ReqOpts{})
		assert.EqualError(t, err, `500 Internal Server Error: {"message":"my bad"}`)
		assertStatus(t, 500, resp)
		_, ok := err.(*Error)
		assert.False(t, ok, "should be a plain error, not a *twitchclient.Error")
	})
}

// Test Helpers
// ------------

// newMockRT returns a mockRT with no handler stubbed. Tests stub the next
// request with OnNextRoundTrip and verify it was issued with RoundTripCalled.
func newMockRT() *mockRT {
	return &mockRT{}
}

// httpHandler produces the response (or error) for one stubbed round trip.
type httpHandler func(req *http.Request) (*http.Response, error)

// mockRT is an http.RoundTripper whose responses are stubbed one request at a
// time. Its fields are accessed without synchronization, so it is not safe for
// concurrent use.
type mockRT struct {
	nextRoundTrip   httpHandler // handler for the next RoundTrip; nil when none is stubbed
	roundTripCalled bool        // true after stubbed next RoundTrip was done
}

// Compile-time check that mockRT can be used as a transport.
var _ http.RoundTripper = (*mockRT)(nil)

// resp200OK is a minimal 200 response shared by tests.
// NOTE(review): it is returned by pointer, so any mutation made by the code
// under test persists into later tests.
var resp200OK = http.Response{StatusCode: 200}

// OnNextRoundTrip stubs the handler for the next request and resets the flag
// reported by RoundTripCalled.
func (rt *mockRT) OnNextRoundTrip(handler httpHandler) {
	rt.nextRoundTrip = handler
	rt.roundTripCalled = false
}

// RoundTrip implements http.RoundTripper by invoking the stubbed handler. The
// stub is consumed, so every request needs its own OnNextRoundTrip call.
// RoundTrip panics when no handler is stubbed.
func (rt *mockRT) RoundTrip(req *http.Request) (*http.Response, error) {
	if rt.nextRoundTrip == nil {
		panic("MockRT can not round trip a request because it has no request handler stub (please call rt.OnNextRoundTrip(handler) before making the request)")
	}
	resp, err := rt.nextRoundTrip(req)
	rt.roundTripCalled = true
	rt.nextRoundTrip = nil
	return resp, err
}

// RoundTripCalled reports whether the most recently stubbed handler has run.
func (rt *mockRT) RoundTripCalled() bool {
	return rt.roundTripCalled
}

// mockCloser wraps an io.Reader and records whether Close was called, so a
// test can assert that the code under test closed a response body.
type mockCloser struct {
	io.Reader
	closed bool // set by Close
}

// Compile-time check that mockCloser satisfies io.ReadCloser.
var _ io.ReadCloser = &mockCloser{}

// Close marks m as closed. It never fails.
func (m *mockCloser) Close() error {
	m.closed = true
	return nil
}

// readErrorReadCloser is a fake io.ReadCloser whose Read always fails with an
// "unexpected EOF" error and whose Close always succeeds.
type readErrorReadCloser struct{}

// Read reads nothing and always reports an error.
func (r *readErrorReadCloser) Read(p []byte) (n int, err error) {
	err = errors.New("unexpected EOF")
	return 0, err
}

// Close is a no-op.
func (r *readErrorReadCloser) Close() error { return nil }

// bodyStr returns str as an io.ReadCloser whose Close is a no-op, suitable as
// a stubbed HTTP body.
func bodyStr(str string) io.ReadCloser {
	buf := bytes.NewBufferString(str)
	return ioutil.NopCloser(buf)
}
