package twitchhttp

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"sync"
	"testing"
	"time"

	"code.justin.tv/chat/timing"

	"github.com/cactus/go-statsd-client/statsd"
	. "github.com/smartystreets/goconvey/convey"
	"github.com/stretchr/testify/assert"
	"golang.org/x/net/context"
)

// TestNewClient verifies that NewClient rejects a configuration that has no
// Host set.
func TestNewClient(t *testing.T) {
	Convey("NewClient", t, func() {
		Convey("errors when Host is empty", func() {
			emptyConf := ClientConf{}
			_, newErr := NewClient(emptyConf)
			assert.Error(t, newErr)
		})
	})
}

// TestNewRequest checks that NewRequest builds a request aimed at the
// configured host, using the http scheme and the given method, path and body.
func TestNewRequest(t *testing.T) {
	Convey("NewRequest", t, func() {
		client, err := NewClient(ClientConf{
			Host: "localhost",
		})
		assert.NoError(t, err)

		Convey("creates a valid request", func() {
			req, err := client.NewRequest("GET", "/path", nil)
			assert.NoError(t, err)

			// testify's Equal takes (expected, actual); keeping that order makes
			// failure output label the two values correctly.
			assert.Equal(t, "GET", req.Method)
			assert.Equal(t, "http", req.URL.Scheme)
			assert.Equal(t, "localhost", req.URL.Host)
			assert.Equal(t, "/path", req.URL.Path)
			assert.Nil(t, req.Body)
		})
	})
}

// TestDo exercises Client.Do against a stubbed RoundTripper: the request must
// reach the transport, carry the expected headers, and see the caller's
// context values (merged with any context already on the request).
func TestDo(t *testing.T) {
	Convey("Do", t, func() {
		mockRT := newMockRT()
		client, err := NewClient(ClientConf{
			Host:             "localhost",
			BaseRoundTripper: mockRT,
		})
		assert.NoError(t, err)

		req, err := client.NewRequest("GET", "/path", nil)
		assert.NoError(t, err)

		Convey("should make request to RoundTripper", func() {
			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				return &resp200OK, nil
			})

			ctx := context.Background()
			resp, err := client.Do(ctx, req, ReqOpts{})
			assert.NoError(t, err)
			assert.Equal(t, 200, resp.StatusCode)
			assert.True(t, mockRT.RoundTripCalled())
			assert.Nil(t, ctx.Err())
		})

		Convey("should send log ID header in the roundtripper", func() {
			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				logIDHeader := req.Header.Get("X-Ctxlog-LogID")
				assert.NotEmpty(t, logIDHeader)
				return &resp200OK, nil
			})

			ctx := context.Background()
			resp, err := client.Do(ctx, req, ReqOpts{AuthorizationToken: "faketokendude"})
			assert.NoError(t, err)
			assert.Equal(t, 200, resp.StatusCode)
			assert.True(t, mockRT.RoundTripCalled())
			assert.Nil(t, ctx.Err())
		})

		Convey("should send the authorization token header to the roundtripper", func() {
			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				expectedHeader := req.Header.Get(TwitchAuthorizationHeader)
				assert.Equal(t, "faketokendude", expectedHeader)
				return &resp200OK, nil
			})

			ctx := context.Background()
			resp, err := client.Do(ctx, req, ReqOpts{AuthorizationToken: "faketokendude"})
			assert.NoError(t, err)
			assert.Equal(t, 200, resp.StatusCode)
			assert.True(t, mockRT.RoundTripCalled())
			assert.Nil(t, ctx.Err())
		})

		Convey("should send the client row ID header to the roundtripper", func() {
			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				expectedHeader := req.Header.Get(TwitchClientRowIDHeader)
				assert.Equal(t, "12345", expectedHeader)
				return &resp200OK, nil
			})

			ctx := context.Background()
			resp, err := client.Do(ctx, req, ReqOpts{ClientRowID: "12345"})
			assert.NoError(t, err)
			assert.Equal(t, 200, resp.StatusCode)
			assert.True(t, mockRT.RoundTripCalled())
			assert.Nil(t, ctx.Err())
		})

		Convey("should error if context deadline passes", func() {
			timeoutDuration := 1 * time.Microsecond
			sleepDuration := timeoutDuration + 10*time.Millisecond // make sure it has time to trigger the timeout (even if the CPU is busy)

			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				time.Sleep(sleepDuration)
				return &http.Response{}, nil
			})

			// Always release the context's timer; discarding cancel leaks it
			// until the deadline fires (and is flagged by go vet's lostcancel).
			ctx, cancel := context.WithTimeout(context.Background(), timeoutDuration)
			defer cancel()
			_, err := client.Do(ctx, req, ReqOpts{})
			assert.NoError(t, err) // NOTE: maybe this should actually return an error ?
			assert.True(t, mockRT.RoundTripCalled())
			assert.Equal(t, "context deadline exceeded", ctx.Err().Error())
		})

		Convey("should set request context if not set", func() {
			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				reqCtx := req.Context()
				assert.NotNil(t, reqCtx.Value("testKey"))
				assert.Equal(t, reqCtx.Value("testKey").(string), "testVal")
				return &resp200OK, nil
			})

			ctx := context.WithValue(context.Background(), "testKey", "testVal")
			_, err := client.Do(ctx, req, ReqOpts{})
			assert.NoError(t, err)
			assert.True(t, mockRT.RoundTripCalled())
		})

		Convey("chitin will merge ctx param and req.Context() if both are being set", func() {
			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				reqCtx := req.Context()
				assert.NotNil(t, reqCtx.Value("ctxKey1"))
				assert.NotNil(t, reqCtx.Value("reqKey"))
				assert.Equal(t, "ctxVal", reqCtx.Value("ctxKey1").(string))
				assert.Equal(t, "reqVal", reqCtx.Value("reqKey").(string)) // both values are present in the same, merged context
				return &resp200OK, nil
			})

			ctx := context.WithValue(context.Background(), "ctxKey1", "ctxVal")
			reqCtx := context.WithValue(context.Background(), "reqKey", "reqVal")
			req = req.WithContext(reqCtx)
			_, err := client.Do(ctx, req, ReqOpts{})
			assert.NoError(t, err)
			assert.True(t, mockRT.RoundTripCalled())
		})
	})
}

// TestStatsRequestTimings checks that Client.Do reports a timing stat named
// "<StatName>.<status>" when ReqOpts.StatName is set (status 0 on transport
// error), and reports nothing when it is empty.
func TestStatsRequestTimings(t *testing.T) {
	Convey("Do", t, func() {
		mockStats := newMockStatter()
		mockRT := newMockRT()
		client, err := NewClient(ClientConf{
			Host:             "localhost",
			Stats:            mockStats,
			BaseRoundTripper: mockRT,
		})
		assert.NoError(t, err)

		req, err := client.NewRequest("GET", "/path", nil)
		assert.NoError(t, err)

		Convey("with opts.StatName", func() {
			reqOptsWithStat := ReqOpts{StatName: "request_stat_name"}

			respStatuses := []int{200, 300, 400, 500}
			// Range over the values: a single-variable range yields the indices
			// 0..3, which would never exercise the real status codes.
			for _, status := range respStatuses {
				status := status // pin per iteration for the closures below (pre-Go 1.22 loop semantics)
				Convey(fmt.Sprintf("should track request timings with response status %d", status), func() {
					mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
						return &http.Response{StatusCode: status}, nil
					})
					_, err := client.Do(context.Background(), req, reqOptsWithStat)
					assert.NoError(t, err)
					assert.True(t, mockRT.RoundTripCalled())
					assert.Equal(t, 1, mockStats.AnyTimingCounts())
					assert.Equal(t, 1, mockStats.TimingCounts(fmt.Sprintf("request_stat_name.%d", status)))
				})
			}

			Convey("should track request timings with response status 0 when there is an error", func() {
				mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
					return nil, errors.New("some error")
				})
				_, err := client.Do(context.Background(), req, reqOptsWithStat)
				assert.Error(t, err)
				assert.True(t, mockRT.RoundTripCalled())
				assert.Equal(t, 1, mockStats.AnyTimingCounts())
				assert.Equal(t, 1, mockStats.TimingCounts("request_stat_name.0"))
			})
		})
		Convey("without opts.StatName", func() {
			reqOptsWithNOStat := ReqOpts{}

			Convey("should not track request timings", func() {
				mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
					return &resp200OK, nil
				})
				_, err := client.Do(context.Background(), req, reqOptsWithNOStat) // opts.StatName is empty
				assert.NoError(t, err)
				assert.True(t, mockRT.RoundTripCalled())
				assert.Equal(t, 0, mockStats.AnyTimingCounts())
			})
		})
	})
}

// TestDNSTimings checks the custom dialer. With stats enabled, dialing a host
// name resolves it via lookupIP, dials one of the returned IPs (respecting
// tcp4/tcp6), and records a "dns.<host>.success" timing. Dialing a literal IP
// records nothing. With stats disabled, the address is passed through unchanged.
func TestDNSTimings(t *testing.T) {
	Convey("dialer", t, func() {
		ipv4 := net.ParseIP("192.168.1.1")
		ipv6 := net.ParseIP("2001:4860:0:2001::68")
		fakeDial := &mockDial{}

		// createMockDialer builds a dialer that records dials in fakeDial and
		// resolves every host to {ipv4, ipv6} without touching the network.
		createMockDialer := func(stats statsd.Statter) *dialer {
			d := newDialer(stats, "", false)
			d.dial = fakeDial.Dial
			d.lookupIP = func(host string) ([]net.IP, error) {
				return []net.IP{ipv4, ipv6}, nil
			}
			return d
		}

		// assertDialedLookupIP checks that the first dialed address is one of
		// the IPs returned by the stubbed lookupIP.
		assertDialedLookupIP := func() {
			addr, _, err := net.SplitHostPort(fakeDial.DialedAddrs()[0])
			assert.NoError(t, err)
			found := false
			for _, ip := range []net.IP{ipv4, ipv6} {
				if ip.String() == addr {
					found = true
				}
			}
			assert.True(t, found)
		}

		// testify's Equal takes (expected, actual); the assertions below keep
		// that order so failure messages label the values correctly.
		Convey("when stats are enabled", func() {
			mockStats := newMockStatter()
			dl := createMockDialer(mockStats)

			Convey("and dialing an IP", func() {
				_, err := dl.Dial("tcp", "10.12.0.12:80")
				assert.NoError(t, err)
				assert.Len(t, fakeDial.DialedAddrs(), 1)
				assert.Equal(t, "10.12.0.12:80", fakeDial.DialedAddrs()[0])
				assert.Equal(t, 0, mockStats.AnyTimingCounts())
			})
			Convey("and dialing a DNS host", func() {
				Convey("with network tcp", func() {
					_, err := dl.Dial("tcp", "test.twitch.tv:80")
					assert.NoError(t, err)
					assert.Len(t, fakeDial.DialedAddrs(), 1)
					assertDialedLookupIP()
					assert.Equal(t, 1, mockStats.AnyTimingCounts())
					assert.Equal(t, 1, mockStats.TimingCounts("dns.test_twitch_tv.success"))
				})
				Convey("with network tcp6", func() {
					_, err := dl.Dial("tcp6", "test.twitch.tv:80")
					assert.NoError(t, err)
					assert.Len(t, fakeDial.DialedAddrs(), 1)
					addr, _, err := net.SplitHostPort(fakeDial.DialedAddrs()[0])
					assert.NoError(t, err)
					assert.Equal(t, ipv6.String(), addr)
					assert.Equal(t, 1, mockStats.AnyTimingCounts())
					assert.Equal(t, 1, mockStats.TimingCounts("dns.test_twitch_tv.success"))
				})
				Convey("with network tcp4", func() {
					_, err := dl.Dial("tcp4", "test.twitch.tv:80")
					assert.NoError(t, err)
					assert.Len(t, fakeDial.DialedAddrs(), 1)
					addr, _, err := net.SplitHostPort(fakeDial.DialedAddrs()[0])
					assert.NoError(t, err)
					assert.Equal(t, ipv4.String(), addr)
					assert.Equal(t, 1, mockStats.AnyTimingCounts())
					assert.Equal(t, 1, mockStats.TimingCounts("dns.test_twitch_tv.success"))
				})
			})
		})

		Convey("when stats are disabled", func() {
			dl := createMockDialer(nil)

			Convey("and dialing an IP", func() {
				_, err := dl.Dial("tcp", "10.12.0.12:80")
				assert.NoError(t, err)
				assert.Len(t, fakeDial.DialedAddrs(), 1)
				assert.Equal(t, "10.12.0.12:80", fakeDial.DialedAddrs()[0])
			})
			Convey("and dialing a DNS host", func() {
				_, err := dl.Dial("tcp", "test.twitch.tv:80")
				assert.NoError(t, err)
				assert.Len(t, fakeDial.DialedAddrs(), 1)
				assert.Equal(t, "test.twitch.tv:80", fakeDial.DialedAddrs()[0])
			})
		})
	})
}

// TestXactTimings checks that a request made with a timing.Xact in its context
// records a sub-timing under the client's TimingXactName, or "twitchhttp" when
// none is configured, next to the transaction's total.
func TestXactTimings(t *testing.T) {
	Convey("Do with Xact in context", t, func() {
		mockStats := newMockStatter()

		xact := &timing.Xact{Stats: mockStats}
		xact.AddName("test")
		xact.Start()

		ctx := timing.XactContext(context.Background(), xact)

		// makeRequest performs one stubbed request through a client configured
		// with xactName, then ends the transaction so its timings are emitted.
		makeRequest := func(xactName string) {
			mockRT := newMockRT()
			client, err := NewClient(ClientConf{
				Host:             "localhost",
				TimingXactName:   xactName,
				BaseRoundTripper: mockRT,
			})
			assert.NoError(t, err)

			req, err := client.NewRequest("GET", "/path", nil)
			assert.NoError(t, err)

			mockRT.OnNextRoundTrip(func(req *http.Request) (*http.Response, error) {
				return &resp200OK, nil
			})

			// Check the error: a failed Do would otherwise surface only as a
			// confusing timing-count mismatch below.
			_, err = client.Do(ctx, req, ReqOpts{})
			assert.NoError(t, err)
			assert.True(t, mockRT.RoundTripCalled())
			xact.End("done")
		}

		Convey("but without explicit TimingXactName", func() {
			makeRequest("")
			assert.Equal(t, 2, mockStats.AnyTimingCounts())
			assert.Equal(t, 1, mockStats.TimingCounts("test.done.total"))
			assert.Equal(t, 1, mockStats.TimingCounts("test.done.twitchhttp"))
		})
		Convey("with explicit TimingXactName", func() {
			makeRequest("someclient")
			assert.Equal(t, 2, mockStats.AnyTimingCounts())
			assert.Equal(t, 1, mockStats.TimingCounts("test.done.total"))
			assert.Equal(t, 1, mockStats.TimingCounts("test.done.someclient"))
		})
	})
}

// TestDoJSON drives Client.DoJSON through a stubbed transport. It covers
// successful decoding, transport errors, 5xx and 4xx handling, malformed
// bodies, and 204 No Content, and checks each time whether the body was closed.
func TestDoJSON(t *testing.T) {
	Convey("DoJSON", t, func() {
		transport := newMockRT()
		client, err := NewClient(ClientConf{
			Host:             "localhost",
			BaseRoundTripper: transport,
		})
		assert.NoError(t, err)

		req, err := client.NewRequest("GET", "/path", nil)
		assert.NoError(t, err)

		var (
			// decoded is the destination DoJSON unmarshals the body into.
			decoded *Error
			// respBody is the body handed out by the latest stub; tests inspect
			// it to see whether DoJSON closed it.
			respBody *mockCloser
		)

		// stubDoJSON makes the transport answer the next request with the given
		// status code, raw body and error, then calls DoJSON and returns its result.
		stubDoJSON := func(code int, payload string, rtErr error) (*http.Response, error) {
			respBody = &mockCloser{Reader: bytes.NewBufferString(payload)}
			transport.OnNextRoundTrip(func(*http.Request) (*http.Response, error) {
				return &http.Response{
					Status:     strconv.Itoa(code),
					StatusCode: code,
					Body:       respBody,
				}, rtErr
			})
			resp, doErr := client.DoJSON(context.Background(), &decoded, req, ReqOpts{})
			assert.True(t, transport.RoundTripCalled())
			return resp, doErr
		}

		Convey("should decode a successful response", func() {
			resp, err := stubDoJSON(200, `{"message":"foobar"}`, nil)

			assert.NoError(t, err)
			assert.Equal(t, "foobar", decoded.Message)
			assert.True(t, respBody.closed)
			assert.Equal(t, 200, resp.StatusCode)
		})

		Convey("should error if Do errors", func() {
			resp, err := stubDoJSON(0, `{"message":"nope"}`, errors.New("error in Do"))

			assert.Error(t, err)
			assert.False(t, respBody.closed)
			assert.Nil(t, resp)
		})

		Convey("should error for 500's", func() {
			resp, err := stubDoJSON(500, `{"message":"my bad"}`, nil)

			assert.Error(t, err)
			assert.Equal(t, 500, resp.StatusCode)
			assert.Equal(t, `500: {"message":"my bad"}`, err.Error())
			assert.True(t, respBody.closed)
		})

		Convey("should *Error for parsable 400's", func() {
			resp, err := stubDoJSON(400, `{"message":"your bad"}`, nil)

			assert.Error(t, err)
			apiErr := err.(*Error)
			assert.Equal(t, 400, apiErr.StatusCode)
			assert.Equal(t, 400, resp.StatusCode)
			assert.Equal(t, "your bad", apiErr.Message)
			assert.True(t, respBody.closed)
		})

		Convey("should *Error for unparsable 400's", func() {
			resp, err := stubDoJSON(400, `{`, nil)

			assert.Error(t, err)
			apiErr := err.(*Error)
			assert.Equal(t, 400, apiErr.StatusCode)
			assert.Equal(t, 400, resp.StatusCode)
			assert.Equal(t, "Unable to read response body: unexpected EOF", apiErr.Message)
			assert.True(t, respBody.closed)
		})

		Convey("should error for unparsable 200's", func() {
			resp, err := stubDoJSON(200, `{`, nil)

			assert.Error(t, err)
			assert.Equal(t, 200, resp.StatusCode)
			assert.Equal(t, "Unable to read response body: unexpected EOF", err.Error())
			assert.True(t, respBody.closed)
		})

		Convey("should not attempt to deserialize if 204 No Content", func() {
			resp, err := stubDoJSON(204, "", nil)

			assert.NoError(t, err)
			assert.Equal(t, 204, resp.StatusCode)
			assert.Nil(t, decoded)
			assert.True(t, respBody.closed)
		})
	})
}

// newMockRT returns a mockRT with no handler installed. It lets a test easily
// stub the next request and then verify that the request was made: call
// OnNextRoundTrip before the request, then RoundTripCalled afterwards.
func newMockRT() *mockRT {
	return &mockRT{}
}

// httpHandler produces the response (or error) for a single stubbed round trip.
type httpHandler func(req *http.Request) (*http.Response, error)

// mockRT is an http.RoundTripper whose behavior is stubbed one request at a time.
type mockRT struct {
	nextRoundTrip   httpHandler // handler for the next request; cleared once it has run
	roundTripCalled bool        // true after stubbed next RoundTrip was done
}

// resp200OK is a canned 200 response with no body. Handlers return its
// address, so every stubbed 200 response shares this one value.
var resp200OK = http.Response{StatusCode: 200}

// OnNextRoundTrip installs handler for the next request and resets the
// RoundTripCalled flag.
func (rt *mockRT) OnNextRoundTrip(handler httpHandler) {
	rt.nextRoundTrip = handler
	rt.roundTripCalled = false
}

// RoundTrip runs the installed handler exactly once and then clears it. It
// panics if no handler is installed, so an unexpected request fails loudly.
func (rt *mockRT) RoundTrip(req *http.Request) (*http.Response, error) {
	if rt.nextRoundTrip == nil {
		panic("MockRT can not round trip a request because it has no request handler stub (please call rt.OnNextRoundTrip(handler) before making the request)")
	}
	resp, err := rt.nextRoundTrip(req)
	rt.roundTripCalled = true
	rt.nextRoundTrip = nil
	return resp, err
}

// RoundTripCalled reports whether the handler installed by the last
// OnNextRoundTrip call has run.
func (rt *mockRT) RoundTripCalled() bool {
	return rt.roundTripCalled
}

// newMockStatter returns a mockStatter that counts TimingDuration calls per
// stat name and delegates every other statsd.Statter method to a no-op client.
func newMockStatter() *mockStatter {
	statter, err := statsd.NewNoopClient()
	if err != nil {
		// A nil embedded Statter would only blow up later, on some unrelated
		// method call; fail loudly at construction instead.
		panic(fmt.Sprintf("creating noop statsd client: %v", err))
	}
	return &mockStatter{
		Statter: statter,
		timings: make(map[string]int),
	}
}

// mockStatter is a statsd.Statter that records how many times each timing stat
// was reported through TimingDuration. Access to the counts is mutex-guarded.
type mockStatter struct {
	statsd.Statter
	mu      sync.Mutex
	timings map[string]int
}

var _ statsd.Statter = (*mockStatter)(nil)

// TimingDuration increments the counter for stat; the duration and sample rate
// are ignored.
func (s *mockStatter) TimingDuration(stat string, d time.Duration, sample float32) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.timings[stat]++
	return nil
}

// TimingCounts returns how many times stat has been reported.
func (s *mockStatter) TimingCounts(stat string) int {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.timings[stat]
}

// AnyTimingCounts returns the total number of timings reported across all stats.
func (s *mockStatter) AnyTimingCounts() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	sum := 0
	for _, count := range s.timings {
		sum += count
	}
	return sum
}

// mockDial is a fake dial function that records every address it is asked to
// dial and returns a nil connection. Its methods are mutex-guarded.
type mockDial struct {
	mu    sync.Mutex
	addrs []string
}

// Dial records addr and returns (nil, nil); network is ignored.
func (d *mockDial) Dial(network string, addr string) (net.Conn, error) {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.addrs = append(d.addrs, addr)
	return nil, nil
}

// DialedAddrs returns a snapshot of the addresses dialed so far, in call order.
// It returns a copy: handing out the internal slice would let callers read it
// outside the lock, or mutate the recorded history.
func (d *mockDial) DialedAddrs() []string {
	d.mu.Lock()
	defer d.mu.Unlock()
	return append([]string(nil), d.addrs...)
}

// mockRoundTripperWrapper decorates another http.RoundTripper. It counts the
// requests that pass through it and can inspect each one before forwarding it.
type mockRoundTripperWrapper struct {
	// base is the transport that actually serves the request.
	base http.RoundTripper
	// calls is the number of RoundTrip invocations so far.
	calls int
	// onRT, when non-nil, is invoked with each request before it is forwarded.
	onRT func(*http.Request)
}

var _ http.RoundTripper = (*mockRoundTripperWrapper)(nil)

// RoundTrip bumps the call counter, runs the onRT hook if one is set, and then
// delegates to base.
func (w *mockRoundTripperWrapper) RoundTrip(r *http.Request) (*http.Response, error) {
	w.calls++
	if hook := w.onRT; hook != nil {
		hook(r)
	}
	return w.base.RoundTrip(r)
}

// mockCloser wraps an io.Reader and records whether Close has been called, so
// tests can check that a response body was closed. Field order matters: tests
// build it with a positional literal {reader, closed}.
type mockCloser struct {
	io.Reader
	closed bool
}

var _ io.ReadCloser = (*mockCloser)(nil)

// Close marks the reader as closed. It never returns an error.
func (m *mockCloser) Close() error {
	m.closed = true
	return nil
}
