package twitchhttp

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"sync"
	"testing"
	"time"

	"code.justin.tv/chat/timing"
	"github.com/cactus/go-statsd-client/statsd"
	. "github.com/smartystreets/goconvey/convey"
	"github.com/stretchr/testify/assert"
	"golang.org/x/net/context"
)

// TestNewClient covers the ClientConf validation performed by NewClient.
func TestNewClient(t *testing.T) {
	Convey("NewClient", t, func() {
		Convey("errors when Host is empty", func() {
			var conf ClientConf // zero value, so Host is ""
			_, err := NewClient(conf)
			assert.Error(t, err)
		})
	})
}

// TestNewRequest verifies that client.NewRequest builds an *http.Request whose
// method, URL (scheme, host, path) and body reflect the client's configured
// Host and the arguments passed in.
func TestNewRequest(t *testing.T) {
	Convey("NewRequest", t, func() {
		client, err := newClient(ClientConf{
			Host: "localhost",
		})
		assert.NoError(t, err)

		params := struct {
			Method string
			Path   string
			Body   io.Reader
		}{
			Method: "GET",
			Path:   "/path",
		}
		newRequest := func() (*http.Request, error) {
			return client.NewRequest(params.Method, params.Path, params.Body)
		}

		Convey("creates a valid request", func() {
			req, err := newRequest()
			// Stop here rather than panic on a nil req below.
			if !assert.NoError(t, err) {
				return
			}

			// testify's argument order is (expected, actual); keeping it right
			// makes failure messages point the correct way.
			assert.Equal(t, params.Method, req.Method)

			assert.Equal(t, "http", req.URL.Scheme)
			assert.Equal(t, "localhost", req.URL.Host)
			assert.Equal(t, params.Path, req.URL.Path)

			assert.Equal(t, params.Body, req.Body)
		})
	})
}

// TestDo covers client.Do against a stubbed RoundTripper. It checks a plain
// successful request, forwarding of ReqOpts.AuthorizationToken as the
// TwitchAuthorizationHeader, and cancellation of the in-flight request when
// the context deadline expires.
func TestDo(t *testing.T) {
	Convey("Do", t, func() {
		client, err := newClient(ClientConf{
			Host: "localhost",
		})
		assert.NoError(t, err)

		req, err := client.NewRequest("GET", "/path", nil)
		assert.NoError(t, err)

		// ctx is captured by doRequest and reassigned by the deadline case.
		// GoConvey re-runs this outer function for each leaf, so every leaf
		// starts from a fresh Background context.
		ctx := context.Background()
		rt := newMockRoundTripper()

		doRequest := func(opts ReqOpts) (*http.Response, error) {
			client.transport = rt
			return client.Do(ctx, req, opts)
		}

		Convey("should make request to RoundTripper", func() {
			expectedResponse := &http.Response{StatusCode: 200}
			err := rt.StubRequest(req, func() <-chan responseErr {
				result := make(chan responseErr, 1)
				result <- responseErr{expectedResponse, nil}
				return result
			})
			assert.NoError(t, err)

			response, err := doRequest(ReqOpts{})
			// Guard the dereference so a failing Do is reported, not a panic.
			if assert.NoError(t, err) {
				assert.Equal(t, expectedResponse.StatusCode, response.StatusCode)
			}
			assert.Equal(t, 1, rt.RoundTripCount())
			assert.False(t, rt.RequestCancelled(req))
		})

		Convey("should send the correct header to the roundtripper", func() {
			expectedResponse := &http.Response{StatusCode: 200}
			// NOTE(review): the stub inspects the outer req, not the request
			// handed to RoundTrip. It only passes if Do sets the header on req
			// itself rather than on a copy.
			err := rt.StubRequest(req, func() <-chan responseErr {
				result := make(chan responseErr, 1)
				if len(req.Header.Get(TwitchAuthorizationHeader)) == 0 {
					result <- responseErr{nil, fmt.Errorf("fail")}
				} else {
					result <- responseErr{expectedResponse, nil}
				}
				return result
			})
			assert.NoError(t, err)
			response, err := doRequest(ReqOpts{AuthorizationToken: "thisisatoken"})
			if assert.NoError(t, err) {
				assert.Equal(t, expectedResponse.StatusCode, response.StatusCode)
			}
			assert.Equal(t, 1, rt.RoundTripCount())
			assert.False(t, rt.RequestCancelled(req))
		})

		Convey("should error if context deadline passes", func() {
			stubbedResponse := &http.Response{}
			err := rt.StubRequest(req, func() <-chan responseErr {
				result := make(chan responseErr, 1)
				go func() {
					// Answer long after the deadline below. The buffered
					// channel lets this goroutine finish even if nobody reads.
					time.Sleep(1 * time.Second)
					result <- responseErr{stubbedResponse, nil}
				}()
				return result
			})
			assert.NoError(t, err)

			// Keep the CancelFunc and release the context's timer when the leaf
			// ends. Discarding it is flagged by go vet's lostcancel check.
			var cancel context.CancelFunc
			ctx, cancel = context.WithTimeout(ctx, 10*time.Millisecond)
			defer cancel()

			_, err = doRequest(ReqOpts{})
			assert.Error(t, err)
			assert.Equal(t, 1, rt.RoundTripCount())
			assert.True(t, rt.RequestCancelled(req))
		})
	})
}

// TestRequestTimings checks how Do reports timing stats. When ReqOpts.StatName
// is set, Do reports a timing named "<StatName>.<status>", using status 0 when
// the round trip fails. When StatName is empty, Do reports no timings.
func TestRequestTimings(t *testing.T) {
	Convey("Do", t, func() {
		stats := newMockStatter()

		client, err := newClient(ClientConf{
			Host:  "localhost",
			Stats: stats,
		})
		assert.NoError(t, err)

		req, err := client.NewRequest("GET", "/path", nil)
		assert.NoError(t, err)

		// The stub reads respErr and doRequest reads opts only when called.
		// Nested Convey blocks can therefore override either one before
		// issuing the request.
		rt := newMockRoundTripper()
		respErr := responseErr{
			response: &http.Response{
				StatusCode: 200,
			},
			err: nil,
		}
		err = rt.StubRequest(req, func() <-chan responseErr {
			result := make(chan responseErr, 1)
			result <- respErr
			return result
		})
		assert.NoError(t, err)

		ctx := context.Background()
		opts := ReqOpts{}

		doRequest := func() (*http.Response, error) {
			client.transport = rt
			return client.Do(ctx, req, opts)
		}

		Convey("should track request timings if opts.StatName exists", func() {
			opts = ReqOpts{
				StatName: "request_stat_name",
			}

			Convey("with the response status if there is no error", func() {
				_, err := doRequest()
				assert.NoError(t, err)
				assert.Equal(t, 1, stats.AnyTimingCounts())
				assert.Equal(t, 1, stats.TimingCounts("request_stat_name.200"))
			})

			Convey("with response status 0 if there is an error", func() {
				respErr = responseErr{
					response: nil,
					err:      errors.New("some error"),
				}
				_, err := doRequest()
				assert.Error(t, err)
				assert.Equal(t, 1, stats.AnyTimingCounts())
				assert.Equal(t, 1, stats.TimingCounts("request_stat_name.0"))
			})
		})

		Convey("should not track request timings if opts.StatName is empty", func() {
			_, err := doRequest()
			assert.NoError(t, err)
			assert.Equal(t, 0, stats.AnyTimingCounts())
		})
	})
}

// TestDNSTimings exercises the dialer's DNS handling with stubbed dial and
// lookup functions. With stats enabled, dialing a host name resolves it
// through lookupIP, dials a resolved IP matching the network family, and
// records a "dns.<host>.success" timing. Dialing a literal IP records nothing.
// With stats disabled, the address is dialed unchanged.
func TestDNSTimings(t *testing.T) {
	Convey("dialer", t, func() {
		stats := newMockStatter()
		dial := &mockDial{}

		ipv4 := net.ParseIP("192.168.1.1")
		ipv6 := net.ParseIP("2001:4860:0:2001::68")

		// createDialer builds a dialer whose network dials are recorded by
		// dial and whose DNS lookup always yields {ipv4, ipv6}.
		createDialer := func(stats statsd.Statter) *dialer {
			d := newDialer(stats, "")
			d.dial = dial.Dial
			d.lookupIP = func(host string) ([]net.IP, error) {
				return []net.IP{ipv4, ipv6}, nil
			}
			return d
		}

		// assertDialedLookupIP checks that the first dialed address uses one
		// of the stubbed lookup IPs. It bails out instead of panicking when
		// nothing was dialed.
		assertDialedLookupIP := func() {
			addrs := dial.DialedAddrs()
			if !assert.NotEmpty(t, addrs, "expected at least one dial") {
				return
			}
			host, _, err := net.SplitHostPort(addrs[0])
			assert.NoError(t, err)
			found := false
			for _, ip := range []net.IP{ipv4, ipv6} {
				if ip.String() == host {
					found = true
					break
				}
			}
			assert.True(t, found, "dialed %q, want one of the looked-up IPs", host)
		}

		Convey("when stats are enabled", func() {
			// Named d rather than dialer to avoid shadowing the dialer type.
			d := createDialer(stats)

			Convey("and dialing an IP", func() {
				_, err := d.Dial("tcp", "10.12.0.12:80")
				assert.NoError(t, err)
				if assert.Len(t, dial.DialedAddrs(), 1) {
					assert.Equal(t, "10.12.0.12:80", dial.DialedAddrs()[0])
				}
				assert.Equal(t, 0, stats.AnyTimingCounts())
			})
			Convey("and dialing a DNS host", func() {
				Convey("with network tcp", func() {
					_, err := d.Dial("tcp", "test.twitch.tv:80")
					assert.NoError(t, err)
					assert.Len(t, dial.DialedAddrs(), 1)
					assertDialedLookupIP()
					assert.Equal(t, 1, stats.AnyTimingCounts())
					assert.Equal(t, 1, stats.TimingCounts("dns.test_twitch_tv.success"))
				})
				Convey("with network tcp6", func() {
					_, err := d.Dial("tcp6", "test.twitch.tv:80")
					assert.NoError(t, err)
					if assert.Len(t, dial.DialedAddrs(), 1) {
						host, _, err := net.SplitHostPort(dial.DialedAddrs()[0])
						assert.NoError(t, err)
						assert.Equal(t, ipv6.String(), host)
					}
					assert.Equal(t, 1, stats.AnyTimingCounts())
					assert.Equal(t, 1, stats.TimingCounts("dns.test_twitch_tv.success"))
				})
				Convey("with network tcp4", func() {
					_, err := d.Dial("tcp4", "test.twitch.tv:80")
					assert.NoError(t, err)
					if assert.Len(t, dial.DialedAddrs(), 1) {
						host, _, err := net.SplitHostPort(dial.DialedAddrs()[0])
						assert.NoError(t, err)
						assert.Equal(t, ipv4.String(), host)
					}
					assert.Equal(t, 1, stats.AnyTimingCounts())
					assert.Equal(t, 1, stats.TimingCounts("dns.test_twitch_tv.success"))
				})
			})
		})

		Convey("when stats are disabled", func() {
			d := createDialer(nil)

			Convey("and dialing an IP", func() {
				_, err := d.Dial("tcp", "10.12.0.12:80")
				assert.NoError(t, err)
				if assert.Len(t, dial.DialedAddrs(), 1) {
					assert.Equal(t, "10.12.0.12:80", dial.DialedAddrs()[0])
				}
			})
			Convey("and dialing a DNS host", func() {
				_, err := d.Dial("tcp", "test.twitch.tv:80")
				assert.NoError(t, err)
				if assert.Len(t, dial.DialedAddrs(), 1) {
					assert.Equal(t, "test.twitch.tv:80", dial.DialedAddrs()[0])
				}
			})
		})
	})
}

// TestXactTimings checks that Do records its duration on a timing.Xact found
// in the context. The timing is named after the client's TimingXactName, or
// "twitchhttp" when that is empty, and is recorded alongside the Xact's own
// "total" timing.
func TestXactTimings(t *testing.T) {
	Convey("Do", t, func() {
		Convey("with Xact in context", func() {
			stats := newMockStatter()
			xact := &timing.Xact{
				Stats: stats,
			}
			xact.AddName("test")
			xact.Start()
			ctx := timing.XactContext(context.Background(), xact)

			// makeRequest performs one successful stubbed request through a
			// client configured with xactName, then ends the Xact as "done".
			makeRequest := func(xactName string) {
				client, err := newClient(ClientConf{
					Host:           "localhost",
					TimingXactName: xactName,
				})
				assert.NoError(t, err)

				req, err := client.NewRequest("GET", "/path", nil)
				assert.NoError(t, err)

				rt := newMockRoundTripper()
				respErr := responseErr{
					response: &http.Response{
						StatusCode: 200,
					},
					err: nil,
				}
				err = rt.StubRequest(req, func() <-chan responseErr {
					result := make(chan responseErr, 1)
					result <- respErr
					return result
				})
				assert.NoError(t, err)

				client.transport = rt

				// The stub always succeeds, so a Do error means the timing
				// assertions below are not testing what they claim.
				_, err = client.Do(ctx, req, ReqOpts{})
				assert.NoError(t, err)
				xact.End("done")
			}

			Convey("but without explicit TimingXactName", func() {
				makeRequest("")
				assert.Equal(t, 2, stats.AnyTimingCounts())
				assert.Equal(t, 1, stats.TimingCounts("test.done.total"))
				assert.Equal(t, 1, stats.TimingCounts("test.done.twitchhttp"))
			})
			Convey("with explicit TimingXactName", func() {
				makeRequest("someclient")
				assert.Equal(t, 2, stats.AnyTimingCounts())
				assert.Equal(t, 1, stats.TimingCounts("test.done.total"))
				assert.Equal(t, 1, stats.TimingCounts("test.done.someclient"))
			})
		})
	})
}

// TestResolveHTTPClient checks that a request sent through the HTTP client
// returned by resolveHTTPClient invokes both the configured
// RoundTripperWrapper and the client's underlying transport.
func TestResolveHTTPClient(t *testing.T) {
	Convey("resolveHTTPClient", t, func() {
		// wrapRT remembers the wrapper it builds so its call count can be read.
		var wrapper *mockRoundTripperWrapper
		wrapRT := func(rt http.RoundTripper) http.RoundTripper {
			wrapper = &mockRoundTripperWrapper{rt, 0}
			return wrapper
		}

		ctx := context.Background()
		client, err := newClient(ClientConf{
			Host:                 "localhost",
			RoundTripperWrappers: []func(http.RoundTripper) http.RoundTripper{wrapRT},
		})
		assert.NoError(t, err)

		rt := newMockRoundTripper()
		client.transport = rt

		c, err := client.resolveHTTPClient(ctx)
		assert.NoError(t, err)

		req, err := client.NewRequest("GET", "/path", nil)
		assert.NoError(t, err)

		// No stub is registered for req, so the mock transport fails the round
		// trip. Only the call counts below matter, so the result is
		// deliberately ignored.
		c.Do(req)

		if assert.NotNil(t, wrapper, "wrapper should have been built") {
			assert.Equal(t, 1, wrapper.calls, "Wrapper RoundTripper should be called")
		}
		assert.Equal(t, 1, rt.RoundTripCount(), "Underlying RoundTripper should be called")
	})
}

// TestDoJSON exercises client.DoJSON against a stubbed transport. It covers:
//   - JSON decoding of successful bodies;
//   - propagation of transport errors;
//   - the errors produced for 5xx and 4xx statuses;
//   - the errors produced for malformed bodies.
//
// Each leaf also checks whether the response body was closed and that exactly
// one uncancelled round trip happened.
func TestDoJSON(t *testing.T) {
	Convey("DoJSON", t, func() {
		client, err := newClient(ClientConf{
			Host: "localhost",
		})
		assert.NoError(t, err)

		req, err := client.NewRequest("GET", "/path", nil)
		assert.NoError(t, err)

		var (
			expectedResponse *http.Response
			expectedError    error
			data             *Error
			body             *mockCloser
		)

		ctx := context.Background()
		rt := newMockRoundTripper()

		// doRequest stubs req to answer with the given status and raw body,
		// plus whatever expectedError currently holds. An empty msg becomes a
		// default JSON message. It then runs DoJSON, decoding into data.
		doRequest := func(status int, msg string) (*http.Response, error) {
			payload := msg
			if payload == "" {
				payload = `{"message":"test_message"}`
			}
			body = &mockCloser{bytes.NewBufferString(payload), false}
			expectedResponse = &http.Response{
				Status:     strconv.Itoa(status),
				StatusCode: status,
				Body:       body,
			}

			stubErr := rt.StubRequest(req, func() <-chan responseErr {
				ch := make(chan responseErr, 1)
				ch <- responseErr{expectedResponse, expectedError}
				return ch
			})
			assert.NoError(t, stubErr)

			client.transport = rt
			return client.DoJSON(ctx, &data, req, ReqOpts{})
		}

		// assertSingleRoundTrip checks that the transport saw exactly one
		// request and never had it cancelled.
		assertSingleRoundTrip := func() {
			assert.Equal(t, 1, rt.RoundTripCount())
			assert.Equal(t, false, rt.RequestCancelled(req))
		}

		Convey("should decode a successful response", func() {
			response, err := doRequest(200, "")
			assert.NoError(t, err)
			assert.Equal(t, "test_message", data.Message)
			assert.True(t, body.closed)
			assert.Equal(t, expectedResponse.StatusCode, response.StatusCode)
			assertSingleRoundTrip()
		})

		Convey("should error if Do errors", func() {
			expectedError = errors.New("error in Do")
			response, err := doRequest(0, "")
			assert.Error(t, err)
			assert.False(t, body.closed)
			assert.Equal(t, (*http.Response)(nil), response)
			assertSingleRoundTrip()
		})

		Convey("should error for 500's", func() {
			response, err := doRequest(500, "")
			assert.Error(t, err)
			assert.Equal(t, "500", err.Error())
			assert.True(t, body.closed)
			assert.Equal(t, expectedResponse.StatusCode, response.StatusCode)
			assertSingleRoundTrip()
		})

		Convey("should *Error for parsable 400's", func() {
			response, err := doRequest(400, "")
			assert.Error(t, err)
			apiErr := err.(*Error)
			assert.Equal(t, 400, apiErr.StatusCode)
			assert.Equal(t, "test_message", apiErr.Message)
			assert.True(t, body.closed)
			assert.Equal(t, expectedResponse.StatusCode, response.StatusCode)
			assertSingleRoundTrip()
		})

		Convey("should *Error for unparsable 400's", func() {
			response, err := doRequest(400, `{`)
			assert.Error(t, err)
			apiErr := err.(*Error)
			assert.Equal(t, 400, apiErr.StatusCode)
			assert.Equal(t, "Unable to read response body: unexpected EOF", apiErr.Message)
			assert.True(t, body.closed)
			assert.Equal(t, expectedResponse.StatusCode, response.StatusCode)
			assertSingleRoundTrip()
		})

		Convey("should error for unparsable 200's", func() {
			response, err := doRequest(200, `{`)
			assert.Error(t, err)
			assert.Equal(t, "Unable to read response body: unexpected EOF", err.Error())
			assert.True(t, body.closed)
			assert.Equal(t, expectedResponse.StatusCode, response.StatusCode)
			assertSingleRoundTrip()
		})
	})
}

// newClient wraps NewClient and returns the concrete *client, so tests can
// reach unexported fields such as transport. The error is checked before the
// type assertion. A failed construction is therefore returned as an error
// instead of panicking when NewClient yields a nil interface. An unexpected
// concrete type is also reported as an error.
func newClient(conf ClientConf) (*client, error) {
	c, err := NewClient(conf)
	if err != nil {
		return nil, err
	}
	concrete, ok := c.(*client)
	if !ok {
		return nil, fmt.Errorf("NewClient returned %T, want *client", c)
	}
	return concrete, nil
}

// newMockRoundTripper returns a mockRoundTripper with its bookkeeping maps
// allocated and no stubs registered.
func newMockRoundTripper() *mockRoundTripper {
	m := &mockRoundTripper{}
	m.stubs = make(map[string]func() <-chan responseErr)
	m.pending = make(map[string]chan struct{})
	m.cancels = make(map[string]struct{})
	return m
}

// mockRoundTripper is a fake http.RoundTripper that matches requests by
// requestKey (method and URL). A request must be registered with StubRequest,
// otherwise RoundTrip fails. It also implements CancelRequest and records
// which stubbed requests were cancelled. mu guards every other field.
type mockRoundTripper struct {
	mu sync.Mutex

	// stubs maps a request key to the factory that produces its result.
	stubs map[string]func() <-chan responseErr

	// pending maps a request key to a one-slot channel that CancelRequest
	// signals to abort an in-flight RoundTrip.
	pending map[string]chan struct{}

	// cancels is the set of request keys that CancelRequest has hit.
	cancels map[string]struct{}

	// roundtripCalls counts every RoundTrip call, stubbed or not.
	roundtripCalls int
}

var _ http.RoundTripper = (*mockRoundTripper)(nil)

// RoundTrip answers req from its registered stub. It counts every call and
// fails for unregistered requests. Otherwise it returns whichever arrives
// first: the stub's result, or a cancellation signalled by CancelRequest.
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	key := requestKey(req)

	m.mu.Lock()
	m.roundtripCalls++
	stub, registered := m.stubs[key]
	cancelled := m.pending[key]
	m.mu.Unlock()

	if !registered {
		return nil, fmt.Errorf("unexpected RoundTrip with request: %v", req)
	}

	// The stub runs outside the lock, so a slow stub cannot block
	// CancelRequest.
	results := stub()
	select {
	case res := <-results:
		return res.response, res.err
	case <-cancelled:
		return nil, errors.New("request cancelled")
	}
}

// CloseIdleConnections is a no-op; the mock holds no connections.
func (m *mockRoundTripper) CloseIdleConnections() {}

// CancelRequest aborts a stubbed request. It signals the request's pending
// channel, without blocking if a signal is already queued, and records the
// cancellation. Requests that were never stubbed are ignored.
func (m *mockRoundTripper) CancelRequest(req *http.Request) {
	key := requestKey(req)

	m.mu.Lock()
	defer m.mu.Unlock()

	signal, stubbed := m.pending[key]
	if !stubbed {
		return
	}
	select {
	case signal <- struct{}{}:
	default: // a cancellation is already queued
	}
	m.cancels[key] = struct{}{}
}

// RoundTripCount reports how many times RoundTrip has been called.
func (m *mockRoundTripper) RoundTripCount() int {
	m.mu.Lock()
	n := m.roundtripCalls
	m.mu.Unlock()
	return n
}

// RequestCancelled reports whether CancelRequest was called for a stubbed
// request with the same key as req.
func (m *mockRoundTripper) RequestCancelled(req *http.Request) bool {
	key := requestKey(req)
	m.mu.Lock()
	_, hit := m.cancels[key]
	m.mu.Unlock()
	return hit
}

// StubRequest registers fn as the result factory for requests matching req.
// It also allocates the request's cancellation channel. Registering the same
// key twice is an error.
func (m *mockRoundTripper) StubRequest(req *http.Request, fn func() <-chan responseErr) error {
	key := requestKey(req)

	m.mu.Lock()
	defer m.mu.Unlock()

	if _, dup := m.stubs[key]; dup {
		return fmt.Errorf("attempted to stub request twice: %v", req)
	}
	m.stubs[key] = fn
	m.pending[key] = make(chan struct{}, 1)
	return nil
}

// requestKey identifies a request by its method and URL, joined by "|".
func requestKey(req *http.Request) string {
	return fmt.Sprintf("%v|%v", req.Method, req.URL)
}

// responseErr is the canned outcome a stub delivers to RoundTrip.
// Field order matters: stubs build it with positional composite literals.
type responseErr struct {
	response *http.Response
	err      error
}

// newMockStatter returns a mockStatter with an empty timing tally. Its
// non-timing Statter methods delegate to a statsd no-op client.
func newMockStatter() *mockStatter {
	// NOTE(review): the NewNoopClient error is discarded. Presumably a no-op
	// client cannot fail to construct; confirm against the statsd package.
	noop, _ := statsd.NewNoopClient()
	m := &mockStatter{timings: make(map[string]int)}
	m.Statter = noop
	return m
}

// mockStatter is a statsd.Statter that counts TimingDuration calls per stat
// name. Every other Statter method goes to the embedded Statter. mu guards
// timings.
type mockStatter struct {
	statsd.Statter

	mu sync.Mutex

	// timings maps a stat name to the number of TimingDuration calls for it.
	timings map[string]int
}

var _ statsd.Statter = (*mockStatter)(nil)

// TimingDuration increments the counter for stat. The duration and sample
// rate are ignored, and it always returns nil.
func (m *mockStatter) TimingDuration(stat string, d time.Duration, sample float32) error {
	m.mu.Lock()
	m.timings[stat]++
	m.mu.Unlock()
	return nil
}

// TimingCounts reports how many timings were recorded under exactly stat.
func (m *mockStatter) TimingCounts(stat string) int {
	m.mu.Lock()
	defer m.mu.Unlock()
	n := m.timings[stat]
	return n
}

// AnyTimingCounts reports the total number of timings recorded across all
// stat names.
func (m *mockStatter) AnyTimingCounts() int {
	m.mu.Lock()
	defer m.mu.Unlock()
	total := 0
	for stat := range m.timings {
		total += m.timings[stat]
	}
	return total
}

// mockDial is a fake dial function. It records every address it is asked to
// dial and never opens a connection. mu guards addrs.
type mockDial struct {
	mu    sync.Mutex
	addrs []string // addresses passed to Dial, in call order
}

// Dial records addr and returns a nil net.Conn and a nil error. The network
// argument is ignored.
func (d *mockDial) Dial(network string, addr string) (net.Conn, error) {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.addrs = append(d.addrs, addr)
	return nil, nil
}

// DialedAddrs returns a snapshot of the addresses dialed so far. The result
// is copied while the lock is held, so callers may keep or modify it without
// touching the recorded history or sharing memory with later Dial calls.
// It returns nil when nothing has been dialed.
func (d *mockDial) DialedAddrs() []string {
	d.mu.Lock()
	defer d.mu.Unlock()
	return append([]string(nil), d.addrs...)
}

// mockRoundTripperWrapper is an http.RoundTripper decorator. It counts how
// many times RoundTrip is invoked, then delegates to base.
//
// Field order matters: positional composite literals build this type.
type mockRoundTripperWrapper struct {
	base  http.RoundTripper // transport that actually handles the request
	calls int               // number of RoundTrip invocations; not synchronized
}

var _ http.RoundTripper = (*mockRoundTripperWrapper)(nil)

// RoundTrip bumps the call counter, then forwards req to the wrapped
// transport and returns its result unchanged.
func (w *mockRoundTripperWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
	w.calls += 1
	next := w.base
	return next.RoundTrip(req)
}

// mockCloser turns an io.Reader into an io.ReadCloser and records whether
// Close has been called. Tests use it to check that response bodies get
// closed.
//
// Field order matters: positional composite literals build this type.
type mockCloser struct {
	io.Reader // supplies the body bytes

	closed bool // set by Close; never reset
}

var _ io.ReadCloser = (*mockCloser)(nil)

// Close marks the closer as closed. It always returns nil and may be called
// repeatedly.
func (m *mockCloser) Close() error {
	m.closed = true
	return nil
}
