package twitchclient

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"code.justin.tv/chat/timing"
	"code.justin.tv/foundation/twitchclient/mocks"

	"context"

	"github.com/cactus/go-statsd-client/statsd"
	. "github.com/smartystreets/goconvey/convey"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)

// key is an unexported type for context keys, so values stored under these
// keys cannot collide with context keys defined in other packages.
type key int

var (
	// testKey labels the value placed on the ctx argument passed to Do.
	testKey = new(key)
	// duplicateKey labels the value placed on the request's own context.
	duplicateKey = new(key)
)

// TestNewClient checks that NewClient rejects a configuration without a Host.
func TestNewClient(t *testing.T) {
	Convey("NewClient", t, func() {
		Convey("errors when Host is empty", func() {
			emptyConf := ClientConf{} // Host deliberately left unset
			_, newErr := NewClient(emptyConf)
			assert.Error(t, newErr)
		})
	})
}

// TestClientConf checks how TransportConf resolves the idle connection timeout:
// a zero-valued TransportConf falls back to DefaultIdleConnTimeout, and an
// explicit IdleConnTimeout is returned unchanged.
func TestClientConf(t *testing.T) {
	Convey("ClientConf", t, func() {
		Convey("uses default IdleConnTimeout if none set", func() {
			conf := ClientConf{
				Host: "localhost",
			}
			// testify's Equal takes (expected, actual); keeping that order makes
			// failure output label the values correctly.
			assert.Equal(t, DefaultIdleConnTimeout, conf.Transport.idleConnTimeout())
		})

		Convey("allows overriding of IdleConnTimeout", func() {
			d := 90 * time.Second
			tConf := TransportConf{
				IdleConnTimeout: d,
			}
			conf := ClientConf{
				Host:      "localhost",
				Transport: tConf,
			}
			assert.Equal(t, d, conf.Transport.idleConnTimeout())
		})
	})
}

// TestNewRequest checks that NewRequest builds a request against the configured
// Host using the http scheme, keeps the given method and path, and leaves the
// body nil when none is supplied.
func TestNewRequest(t *testing.T) {
	Convey("NewRequest", t, func() {
		client, err := NewClient(ClientConf{
			Host: "localhost",
		})
		assert.NoError(t, err)

		Convey("creates a valid request", func() {
			req, err := client.NewRequest("GET", "/path", nil)
			assert.NoError(t, err)

			// testify's Equal takes (expected, actual).
			assert.Equal(t, "GET", req.Method)
			assert.Equal(t, "http", req.URL.Scheme)
			assert.Equal(t, "localhost", req.URL.Host)
			assert.Equal(t, "/path", req.URL.Path)
			assert.Nil(t, req.Body)
		})
	})
}

// TestDo exercises client.Do against a stubbed RoundTripper. It covers:
//   - forwarding to the RoundTripper;
//   - log elevation via ElevateKey;
//   - propagation of the auth, client row ID, client ID and ctxlog headers;
//   - behaviour when the ctx deadline passes mid-request;
//   - how the ctx argument is combined with the request's own context.
func TestDo(t *testing.T) {
	Convey("Do", t, func() {
		mockRT := newMockRT()
		elevateKey := new(int)
		client, err := NewClient(ClientConf{
			Host:             "localhost",
			BaseRoundTripper: mockRT,
			ElevateKey:       elevateKey,
		})
		assert.NoError(t, err)

		req, err := client.NewRequest("GET", "/path", nil)
		assert.NoError(t, err)

		Convey("should make request to RoundTripper", func() {
			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				return &resp200OK, nil
			})

			ctx := context.Background()
			resp, err := client.Do(ctx, req, ReqOpts{})
			assert.NoError(t, err)
			assert.Equal(t, 200, resp.StatusCode)
			assert.True(t, mockRT.RoundTripCalled())
			assert.Nil(t, ctx.Err())
		})

		Convey("should add elevate log header if request is elevated", func() {
			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				elevatedLogHeader := req.Header.Get("X-Ctxlog-Elevate")
				assert.Equal(t, "true", elevatedLogHeader, "This request should have been elevated.")
				return &resp200OK, nil
			})

			// Storing any value under ElevateKey marks the request as elevated.
			ctx := context.Background()
			ctx = context.WithValue(ctx, elevateKey, "true")
			resp, err := client.Do(ctx, req, ReqOpts{AuthorizationToken: "faketokendude", ClientRowID: "12345"})
			assert.NoError(t, err)
			assert.Equal(t, 200, resp.StatusCode)
			assert.True(t, mockRT.RoundTripCalled())
			assert.Nil(t, ctx.Err())
		})

		Convey("should add standard headers in the roundtripper", func() {
			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				expectedHeader := req.Header.Get(TwitchAuthorizationHeader)
				assert.Equal(t, "faketokendude", expectedHeader)
				expectedHeader = req.Header.Get(TwitchClientRowIDHeader)
				assert.Equal(t, "123456", expectedHeader)
				expectedHeader = req.Header.Get(TwitchClientIDHeader)
				assert.Equal(t, "ABC123", expectedHeader)
				logIDHeader := req.Header.Get("X-Ctxlog-LogID")
				assert.NotEmpty(t, logIDHeader)
				elevatedLogHeader := req.Header.Get("X-Ctxlog-Elevate")
				assert.Empty(t, elevatedLogHeader, "This request should not have been elevated")
				return &resp200OK, nil
			})

			ctx := context.Background()
			resp, err := client.Do(ctx, req, ReqOpts{AuthorizationToken: "faketokendude", ClientRowID: "123456", ClientID: "ABC123"})
			assert.NoError(t, err)
			assert.Equal(t, 200, resp.StatusCode)
			assert.True(t, mockRT.RoundTripCalled())
			assert.Nil(t, ctx.Err())
		})

		Convey("should error if context deadline passes", func() {
			timeoutDuration := 1 * time.Microsecond
			sleepDuration := timeoutDuration + 10*time.Millisecond // make sure it has time to trigger the timeout (even if the CPU is busy)

			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				time.Sleep(sleepDuration)
				return &http.Response{}, nil
			})

			ctx, cancel := context.WithTimeout(context.Background(), timeoutDuration)
			defer cancel()
			_, err := client.Do(ctx, req, ReqOpts{})
			assert.NoError(t, err) // NOTE: maybe this should actually return an error ?
			assert.True(t, mockRT.RoundTripCalled())
			// Compare against the sentinel rather than its message text.
			assert.Equal(t, context.DeadlineExceeded, ctx.Err())
		})

		Convey("should set request context if not set", func() {
			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				reqCtx := req.Context()
				assert.NotNil(t, reqCtx.Value(testKey))
				// Comparing the interface value directly (no .(string) assertion)
				// reports a failure instead of panicking when the value is missing.
				assert.Equal(t, "testVal", reqCtx.Value(testKey))
				return &resp200OK, nil
			})

			ctx := context.WithValue(context.Background(), testKey, "testVal")
			_, err := client.Do(ctx, req, ReqOpts{})
			assert.NoError(t, err)
			assert.True(t, mockRT.RoundTripCalled())
		})

		Convey("chitin will merge ctx param and req.Context() if both are being set", func() {
			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				reqCtx := req.Context()
				assert.NotNil(t, reqCtx.Value(testKey))
				assert.NotNil(t, reqCtx.Value(duplicateKey))
				assert.Equal(t, "ctxVal", reqCtx.Value(testKey))
				assert.Equal(t, "reqVal", reqCtx.Value(duplicateKey)) // both values are present in the same, merged context
				return &resp200OK, nil
			})

			ctx := context.WithValue(context.Background(), testKey, "ctxVal")
			reqCtx := context.WithValue(context.Background(), duplicateKey, "reqVal")
			req = req.WithContext(reqCtx)
			_, err := client.Do(ctx, req, ReqOpts{})
			assert.NoError(t, err)
			assert.True(t, mockRT.RoundTripCalled())
		})
	})
}

// TestStatsRequestTimings checks that Do records a "<StatName>.<status>" timing
// when ReqOpts.StatName is set. The status is 0 when the round trip returns an
// error. No timing is recorded when StatName is empty.
func TestStatsRequestTimings(t *testing.T) {
	Convey("Do", t, func() {
		mockStats := newMockStatter()
		mockRT := newMockRT()
		client, err := NewClient(ClientConf{
			Host:             "localhost",
			Stats:            mockStats,
			BaseRoundTripper: mockRT,
		})
		assert.NoError(t, err)

		req, err := client.NewRequest("GET", "/path", nil)
		assert.NoError(t, err)

		Convey("with opts.StatName", func() {
			reqOptsWithStat := ReqOpts{StatName: "request_stat_name"}

			respStatuses := []int{200, 300, 400, 500}
			// Range over the values: a single-variable range over a slice yields
			// the indices (0..3), which never exercised the listed status codes.
			for _, status := range respStatuses {
				status := status // pin this iteration's value for the closures below
				Convey(fmt.Sprintf("should track request timings with response status %d", status), func() {
					mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
						return &http.Response{StatusCode: status}, nil
					})
					_, err := client.Do(context.Background(), req, reqOptsWithStat)
					assert.NoError(t, err)
					assert.True(t, mockRT.RoundTripCalled())
					assert.Equal(t, 1, mockStats.AnyTimingCounts())
					assert.Equal(t, 1, mockStats.TimingCounts(fmt.Sprintf("request_stat_name.%d", status)))
				})
			}

			Convey("should track request timings with response status 0 when there is an error", func() {
				mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
					return nil, errors.New("some error")
				})
				_, err := client.Do(context.Background(), req, reqOptsWithStat)
				assert.Error(t, err)
				assert.True(t, mockRT.RoundTripCalled())
				assert.Equal(t, 1, mockStats.AnyTimingCounts())
				assert.Equal(t, 1, mockStats.TimingCounts("request_stat_name.0"))
			})
		})
		Convey("without opts.StatName", func() {
			reqOptsWithNOStat := ReqOpts{}

			Convey("should not track request timings", func() {
				mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
					return &resp200OK, nil
				})
				_, err := client.Do(context.Background(), req, reqOptsWithNOStat) // opts.StatName is empty
				assert.NoError(t, err)
				assert.True(t, mockRT.RoundTripCalled())
				assert.Equal(t, 0, mockStats.AnyTimingCounts())
			})
		})
	})
}

// TestDialerWithDNSStats checks dialWrap's behaviour:
//   - a literal IP address is dialed directly, with no DNS timing;
//   - a hostname is resolved through lookupIP, one "<prefix>.<host>.success"
//     timing is recorded, and the dial uses the address family the network
//     asks for;
//   - a nil Statter is tolerated.
func TestDialerWithDNSStats(t *testing.T) {
	Convey("dialerWithDNSStats", t, func() {
		ipv4 := "192.168.1.1"
		ipv6 := "2001:4860:0:2001::68"
		ips := []net.IP{net.ParseIP(ipv4), net.ParseIP(ipv6)}

		// dialerWrap builds a dialWrap whose DNS lookup always returns ips.
		dialerWrap := func(stats statsd.Statter, dialer ctxDialer) *dialWrap {
			return &dialWrap{
				dialer:   dialer,
				lookupIP: func(host string) ([]net.IP, error) { return ips, nil },

				stats:               stats,
				statPrefix:          "dns",
				enableExtraDNSStats: false,
			}
		}

		// testify's Equal takes (expected, actual); the assertions below follow that order.
		Convey("dial IP address", func() {
			mockStats := newMockStatter()
			fakeDialer := &fakeCtxDialer{}
			dialer := dialerWrap(mockStats, fakeDialer)

			_, err := dialer.DialContext(context.Background(), "tcp", "10.12.0.12:80")
			assert.NoError(t, err)
			assert.Equal(t, 1, fakeDialer.DialedAddrs, "DialContext should have been called once")
			assert.Equal(t, "10.12.0.12:80", fakeDialer.LastAddress)
			assert.Equal(t, 0, mockStats.AnyTimingCounts())
		})

		Convey("dial DNS host over tcp6", func() {
			mockStats := newMockStatter()
			fakeDialer := &fakeCtxDialer{}
			dialer := dialerWrap(mockStats, fakeDialer)

			_, err := dialer.DialContext(context.Background(), "tcp6", "test.twitch.tv:80")
			assert.NoError(t, err)
			assert.Equal(t, 1, fakeDialer.DialedAddrs, "DialContext should have been called once")
			assert.Contains(t, []string{ipv4 + ":80", "[" + ipv6 + "]:80"}, fakeDialer.LastAddress)

			assert.Equal(t, 1, mockStats.AnyTimingCounts())
			assert.Equal(t, 1, mockStats.TimingCounts("dns.test_twitch_tv.success"))
		})

		Convey("dial DNS host over tcp4", func() {
			mockStats := newMockStatter()
			fakeDialer := &fakeCtxDialer{}
			dialer := dialerWrap(mockStats, fakeDialer)

			_, err := dialer.DialContext(context.Background(), "tcp4", "test.twitch.tv:80")
			assert.NoError(t, err)
			assert.Equal(t, 1, fakeDialer.DialedAddrs, "DialContext should have been called once")
			assert.Equal(t, ipv4+":80", fakeDialer.LastAddress, "tcp4 should always resolve to the ipv4, never to the ipv6")

			assert.Equal(t, 1, mockStats.AnyTimingCounts())
			assert.Equal(t, 1, mockStats.TimingCounts("dns.test_twitch_tv.success"))
		})

		Convey("with disabled stats (nil)", func() {
			fakeDialer := &fakeCtxDialer{}
			dialer := dialerWrap(nil, fakeDialer)

			_, err := dialer.DialContext(context.Background(), "tcp4", "test.twitch.tv:80")
			assert.NoError(t, err)
			assert.Equal(t, 1, fakeDialer.DialedAddrs, "DialContext should have been called once")
		})
	})
}

// TestXactTimings verifies the timings recorded when a timing.Xact is carried
// in the ctx passed to Do. Ending the transaction records two timings: the
// transaction's "total" and one named after ClientConf.TimingXactName. When
// TimingXactName is empty, that second name is "twitchhttp".
func TestXactTimings(t *testing.T) {
	Convey("Do with Xact in context", t, func() {
		mockStats := newMockStatter()

		xact := &timing.Xact{Stats: mockStats}
		xact.AddName("test")
		xact.Start()
		ctx := timing.XactContext(context.Background(), xact)

		// runRequest sends one stubbed GET through a fresh client configured with
		// the given TimingXactName, then ends the transaction as "done".
		runRequest := func(name string) {
			rt := newMockRT()
			conf := ClientConf{
				Host:             "localhost",
				TimingXactName:   name,
				BaseRoundTripper: rt,
			}
			c, err := NewClient(conf)
			assert.NoError(t, err)

			r, err := c.NewRequest("GET", "/path", nil)
			assert.NoError(t, err)

			rt.OnNextRoundTrip(func(*http.Request) (*http.Response, error) {
				return &resp200OK, nil
			})

			_, err = c.Do(ctx, r, ReqOpts{})
			assert.NoError(t, err)
			assert.True(t, rt.RoundTripCalled())

			xact.End("done")
		}

		// expectTimings asserts exactly two timings were recorded: the
		// transaction total plus the given per-client stat.
		expectTimings := func(clientStat string) {
			assert.Equal(t, 2, mockStats.AnyTimingCounts())
			assert.Equal(t, 1, mockStats.TimingCounts("test.done.total"))
			assert.Equal(t, 1, mockStats.TimingCounts(clientStat))
		}

		Convey("but without explicit TimingXactName", func() {
			runRequest("")
			expectTimings("test.done.twitchhttp")
		})
		Convey("with explicit TimingXactName", func() {
			runRequest("someclient")
			expectTimings("test.done.someclient")
		})
	})
}

// TestDoJSON exercises client.DoJSON against a stubbed RoundTripper, covering:
//   - decoding a 2xx JSON body into the destination;
//   - propagating a round-trip error;
//   - turning 4xx/5xx responses into errors (a *Error for 400s);
//   - rejecting an unparsable body;
//   - skipping decoding on 204;
//   - checking the response body is closed;
//   - logging (not returning) a failure to close the body.
func TestDoJSON(t *testing.T) {
	Convey("DoJSON", t, func() {
		mockRT := newMockRT()
		logger := &mocks.Logger{}
		// Accept any Log call so the close-failure case below can assert it happened.
		logger.On("Log", mock.Anything, mock.Anything).Return()
		client, err := NewClient(ClientConf{
			Host:             "localhost",
			BaseRoundTripper: mockRT,
			Logger:           logger,
			DimensionKey:     &logger,
		})
		assert.NoError(t, err)

		req, err := client.NewRequest("GET", "/path", nil)
		assert.NoError(t, err)

		// data is the DoJSON decode target (*Error is used here as a convenient
		// struct with a Message field). body is the most recent stubbed response
		// body, kept so subcases can inspect whether it was closed.
		// NOTE(review): the 204 case relies on data starting nil; this assumes
		// goconvey re-runs this setup for each leaf — confirm.
		var data *Error
		var body *mockCloser
		// fakeDoJSON stubs the next round trip to return statusResp with jsonResp
		// as an application/json body, and errorResp as the round-trip error. It
		// then calls DoJSON into &data and asserts the stub was actually used.
		fakeDoJSON := func(statusResp int, jsonResp string, errorResp error) (*http.Response, error) {
			body = &mockCloser{bytes.NewBufferString(jsonResp), false}
			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				resp := &http.Response{
					Status:     strconv.Itoa(statusResp),
					StatusCode: statusResp,
					Body:       body,
					Header:     http.Header{"Content-Type": []string{"application/json"}},
				}
				return resp, errorResp
			})
			resp, err := client.DoJSON(context.Background(), &data, req, ReqOpts{})
			assert.True(t, mockRT.RoundTripCalled())
			return resp, err
		}

		Convey("should decode a successful response", func() {
			resp, err := fakeDoJSON(200, `{"message":"foobar"}`, nil)

			assert.NoError(t, err)
			assert.Equal(t, "foobar", data.Message)
			assert.Equal(t, true, body.closed)
			assert.Equal(t, 200, resp.StatusCode)
		})

		// When the round trip itself errors, DoJSON returns no response and the
		// stubbed body is expected to stay unclosed.
		Convey("should error if Do errors", func() {
			expectedError := errors.New("error in Do")
			resp, err := fakeDoJSON(0, `{"message":"nope"}`, expectedError)

			assert.Error(t, err)
			assert.Equal(t, false, body.closed)

			assert.Equal(t, (*http.Response)(nil), resp)
		})

		Convey("should error for 500's", func() {
			resp, err := fakeDoJSON(500, `{"message":"my bad"}`, nil)
			assertFailure(t, http.StatusInternalServerError, `500: {"message":"my bad"}`, resp, err)
			assert.Equal(t, true, body.closed)
		})

		Convey("should *Error for parsable 400's", func() {
			resp, err := fakeDoJSON(400, `{"message":"your bad"}`, nil)

			assert.Error(t, err)
			assert.Equal(t, 400, err.(*Error).StatusCode)
			assert.Equal(t, 400, resp.StatusCode)
			assert.Equal(t, "your bad", err.(*Error).Message)
			assert.Equal(t, true, body.closed)
		})

		Convey("should *Error for unparsable 400's", func() {
			resp, err := fakeDoJSON(400, `{`, nil)
			assertFailure(t, http.StatusBadRequest, "Unable to read response body: unexpected EOF", resp, err)
			assert.Equal(t, true, body.closed)
		})

		Convey("should error for unparsable 200's", func() {
			resp, err := fakeDoJSON(200, `{`, nil)

			assert.Error(t, err)
			assert.Equal(t, 200, resp.StatusCode)
			assert.Equal(t, "Unable to read response body: unexpected EOF", err.Error())
			assert.Equal(t, true, body.closed)
		})

		// A 204 body must not be decoded, so data stays nil.
		Convey("should not attempt to deserialize if 204 No Content", func() {
			resp, err := fakeDoJSON(204, "", nil)

			assert.NoError(t, err)
			assert.Equal(t, 204, resp.StatusCode)
			assert.Nil(t, data)
			assert.Equal(t, true, body.closed)
		})

		// A failure to close the body is expected to be logged, not returned.
		// NOTE(review): closeErrorReadCloser is defined outside this view;
		// presumably its Close always returns an error — confirm.
		Convey("should log when response body fails to close", func() {
			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				resp := &http.Response{
					Body:       &closeErrorReadCloser{strings.NewReader("{}")},
					StatusCode: http.StatusOK,
				}
				return resp, nil
			})

			_, err := client.DoJSON(context.Background(), &data, req, ReqOpts{})
			assert.Nil(t, err)
			logger.AssertCalled(t, "Log", mock.Anything)
		})
	})
}

// assertFailure checks that a DoJSON call failed and that the response carries
// statusCode.
//   - For 400 responses, err must be a *Error whose StatusCode and Message
//     equal statusCode and message.
//   - For any other status, message is compared against err.Error().
//
// Unlike unchecked assertions, it reports a nil error, a nil response or an
// unexpected error type as a test failure instead of panicking.
func assertFailure(t *testing.T, statusCode int, message string, resp *http.Response, err error) {
	if !assert.Error(t, err) {
		// Nothing further can be checked without an error value.
		return
	}
	if assert.NotNil(t, resp) {
		assert.Equal(t, statusCode, resp.StatusCode)
	}
	if statusCode != http.StatusBadRequest {
		assert.Equal(t, message, err.Error())
		return
	}
	twitchErr, ok := err.(*Error)
	if !assert.True(t, ok, "expected *Error, got %T", err) {
		return
	}
	assert.Equal(t, statusCode, twitchErr.StatusCode)
	assert.Equal(t, message, twitchErr.Message)
}

// Mock RoundTripper: lets a test stub the next request and verify it was made.
func newMockRT() *mockRT {
	return &mockRT{}
}

// httpHandler is the stub signature used to answer a single round trip.
type httpHandler func(req *http.Request) (*http.Response, error)

// mockRT is an http.RoundTripper that answers exactly one request with the
// handler registered through OnNextRoundTrip.
type mockRT struct {
	nextRoundTrip   httpHandler // stub for the next RoundTrip; cleared once used
	roundTripCalled bool        // true after stubbed next RoundTrip was done
}

var _ http.RoundTripper = (*mockRT)(nil)

// resp200OK is a canned response with only StatusCode set (no body, no headers).
// It is a single package-level value, so every test returning &resp200OK
// shares it.
var resp200OK = http.Response{StatusCode: 200}

// OnNextRoundTrip registers handler to answer the next RoundTrip and resets the
// RoundTripCalled flag.
func (rt *mockRT) OnNextRoundTrip(handler httpHandler) {
	rt.nextRoundTrip = handler
	rt.roundTripCalled = false
}

// RoundTrip answers req with the registered handler, marks the stub as used and
// clears it. A second request therefore needs a new OnNextRoundTrip call.
// RoundTrip panics if no handler is registered.
func (rt *mockRT) RoundTrip(req *http.Request) (*http.Response, error) {
	if rt.nextRoundTrip == nil {
		panic("MockRT can not round trip a request because it has no request handler stub (please call rt.OnNextRoundTrip(handler) before making the request)")
	}
	resp, err := rt.nextRoundTrip(req)
	rt.roundTripCalled = true
	rt.nextRoundTrip = nil
	return resp, err
}

// RoundTripCalled reports whether the handler registered by the most recent
// OnNextRoundTrip call has been used.
func (rt *mockRT) RoundTripCalled() bool {
	return rt.roundTripCalled
}

// Mock Statter: wraps a no-op statsd client and counts TimingDuration calls per
// stat name so tests can assert which timings were emitted.
func newMockStatter() *mockStatter {
	// NOTE(review): the constructor error is discarded; this presumably cannot
	// fail for a no-op client — confirm against the statsd package.
	noop, _ := statsd.NewNoopClient()
	m := &mockStatter{timings: make(map[string]int)}
	m.Statter = noop
	return m
}

// mockStatter forwards every Statter method to the embedded client except
// TimingDuration, which it records instead. mu guards timings.
type mockStatter struct {
	statsd.Statter
	mu      sync.Mutex
	timings map[string]int // stat name -> number of TimingDuration calls
}

var _ statsd.Statter = (*mockStatter)(nil)

// TimingDuration increments the call count for stat. The duration and sample
// rate are ignored. It always returns nil.
func (s *mockStatter) TimingDuration(stat string, _ time.Duration, _ float32) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.timings[stat] = s.timings[stat] + 1
	return nil
}

// TimingCounts returns how many times TimingDuration was called for stat.
func (s *mockStatter) TimingCounts(stat string) int {
	s.mu.Lock()
	defer s.mu.Unlock()
	n := s.timings[stat]
	return n
}

// AnyTimingCounts returns the total number of TimingDuration calls across all stats.
func (s *mockStatter) AnyTimingCounts() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	var total int
	for name := range s.timings {
		total += s.timings[name]
	}
	return total
}

// Fake ctxDialer: records each DialContext call and returns the configured Err,
// always with a nil net.Conn, instead of opening a connection.
type fakeCtxDialer struct {
	Err error // returned by every DialContext call; leave nil to fake success

	DialedAddrs int    // number of times DialContext was called
	LastAddress string // address passed to the most recent DialContext call
}

// DialContext remembers address, bumps the call count and returns (nil, d.Err).
// The context and network arguments are ignored.
func (d *fakeCtxDialer) DialContext(_ context.Context, _ string, address string) (net.Conn, error) {
	d.LastAddress = address
	d.DialedAddrs++
	return nil, d.Err
}

// Mock RoundTripperWrapper: an http.RoundTripper that counts calls, lets a test
// observe each request through onRT, and then forwards the request to base.
type mockRoundTripperWrapper struct {
	base  http.RoundTripper   // RoundTripper every request is forwarded to
	calls int                 // number of RoundTrip invocations so far
	onRT  func(*http.Request) // optional hook run before forwarding
}

var _ http.RoundTripper = (*mockRoundTripperWrapper)(nil)

// RoundTrip counts the call, runs onRT if set, and returns base's result.
func (rt *mockRoundTripperWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
	rt.calls++
	if hook := rt.onRT; hook != nil {
		hook(req)
	}
	return rt.base.RoundTrip(req)
}

// Mock Closer: an io.ReadCloser that reads from the embedded Reader and records
// whether Close was called, so tests can check that a body was closed.
type mockCloser struct {
	io.Reader
	closed bool // set by Close
}

var _ io.ReadCloser = (*mockCloser)(nil)

// Close marks the closer as closed and never fails.
func (m *mockCloser) Close() error {
	m.closed = true
	return nil
}

// TestModifyReqOpts checks how modifyReqOpts rewrites ReqOpts.StatName.
// A configured ClientConf.StatNamePrefix is prepended, joined with ".". With no
// prefix, StatName is left unchanged.
func TestModifyReqOpts(t *testing.T) {
	Convey("modifyReqOpts", t, func() {
		for _, scenario := range []struct {
			name             string
			conf             *ClientConf
			ro               ReqOpts
			expectedStatName string
		}{
			{
				name:             "without prefix",
				conf:             &ClientConf{},
				ro:               ReqOpts{StatName: "request_name"},
				expectedStatName: "request_name",
			},
			{
				name:             "with prefix",
				conf:             &ClientConf{StatNamePrefix: "service.test"},
				ro:               ReqOpts{StatName: "request_name"},
				expectedStatName: "service.test.request_name",
			},
		} {
			scenario := scenario // each Convey closure gets its own copy to mutate
			Convey(scenario.name, func() {
				c := client{
					conf: scenario.conf,
				}

				c.modifyReqOpts(&scenario.ro)

				// Report through testify like the rest of this file. t.Fatalf would
				// stop the test goroutine (runtime.Goexit) and skip any remaining
				// Convey scenarios.
				assert.Equal(t, scenario.expectedStatName, scenario.ro.StatName)
			})
		}
	})
}
