package twitchclient

import (
	"context"
	"crypto/tls"
	"errors"
	"net/http"
	"net/http/httptest"
	"net/http/httptrace"
	"net/url"
	"testing"
	"time"

	"code.justin.tv/foundation/twitchclient/mocks"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestNoStats checks that a client built without a statter serves requests
// normally even when the request options carry a StatName: the stat name is
// simply ignored and nothing panics.
func TestNoStats(t *testing.T) {
	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()

	httpClient := NewHTTPClient(ClientConf{
		Stats: nil,
		// The httptest TLS server uses a self-signed certificate.
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
	})

	req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
	require.NoError(t, err)
	req = req.WithContext(WithReqOpts(context.Background(), ReqOpts{
		StatName: "stat", // this will be ignored
	}))

	resp, err := httpClient.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode) // everything is good, no panics or issues
}

// TestRegularStats checks that HTTP tracing stats are emitted for a request,
// that they use the StatSampleRate from the request options (0.1 when no
// options are set), and that failure stats are reported with sample rate 1.
func TestRegularStats(t *testing.T) {
	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()

	mockStatter := mocks.NewStatter()
	httpClient := NewHTTPClient(ClientConf{
		Stats:          mockStatter,
		DNSStatsPrefix: "dns-prefix",
		// The httptest TLS server uses a self-signed certificate.
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
	})

	// Target "localhost" instead of the server's IP so that a DNS lookup happens.
	req, err := http.NewRequest(http.MethodGet, urlToLocalhost(ts.URL), nil)
	require.NoError(t, err)
	req = req.WithContext(WithReqOpts(context.Background(), ReqOpts{
		StatName:       "stat", // is ignored on http tracing stats
		StatSampleRate: 0.22,   // is also used for http tracing stats
	}))

	resp, err := httpClient.Do(req)
	require.NoError(t, err)
	// Deferred (not closed here) so the second request below sees the same
	// connection-pool state as before.
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode)

	// verify that TimingDuration was called for each stat
	assert.Equal(t, 1, mockStatter.TimingCounts("get_connection.localhost.success"))
	assert.Equal(t, 1, mockStatter.TimingCounts("dns-prefix.localhost.success"))
	assert.Equal(t, 1, mockStatter.TimingCounts("dial.localhost.success"))
	assert.Equal(t, 1, mockStatter.TimingCounts("tls_handshake.localhost.success"))

	// verify that the sampleRate used is the one from the reqOpts
	assert.Equal(t, float32(0.22), mockStatter.TimingSample("get_connection.localhost.success"))
	assert.Equal(t, float32(0.22), mockStatter.TimingSample("dns-prefix.localhost.success"))
	assert.Equal(t, float32(0.22), mockStatter.TimingSample("dial.localhost.success"))
	assert.Equal(t, float32(0.22), mockStatter.TimingSample("tls_handshake.localhost.success"))

	// Make another request that will fail when reading the request body.
	// With no ReqOpts, the default sample rate 0.1 should be used.
	req, err = http.NewRequest(http.MethodGet, urlToLocalhost(ts.URL), &failingReader{})
	require.NoError(t, err)
	_, err = httpClient.Do(req)
	require.Error(t, err) // the request failed when reading the body

	// verify that connection reuse and write-failure stats were reported
	assert.Equal(t, int64(1), mockStatter.IncSum("req_write.localhost.failure"))
	assert.Equal(t, int64(1), mockStatter.IncSum("get_connection.localhost.reused"))
	assert.Equal(t, int64(1), mockStatter.IncSum("get_connection.localhost.was_idle"))
	assert.Equal(t, 1, mockStatter.TimingCounts("get_connection.localhost.idle_time"))

	// verify the sample rate used for each of those stats
	assert.Equal(t, float32(1), mockStatter.IncSample("req_write.localhost.failure")) // not sampled
	assert.Equal(t, float32(0.1), mockStatter.IncSample("get_connection.localhost.reused"))
	assert.Equal(t, float32(0.1), mockStatter.IncSample("get_connection.localhost.was_idle"))
	assert.Equal(t, float32(0.1), mockStatter.TimingSample("get_connection.localhost.idle_time"))
}

func TestStatsOnFailingRequestToUnavailableServer(t *testing.T) {
	mockStatter := mocks.NewStatter()
	httpClient := NewHTTPClient(ClientConf{Stats: mockStatter})

	req, _ := http.NewRequest(http.MethodGet, "https://localhost:9988", nil)
	req = req.WithContext(WithReqOpts(context.Background(), ReqOpts{}))

	_, err := httpClient.Do(req)
	require.Error(t, err) // request should have failed because it can not connect to localhost service

	assert.Equal(t, 1, mockStatter.TimingCounts("dns.localhost.success"))   // localhost => 127.0.0.1
	assert.Equal(t, int64(2), mockStatter.IncSum("dial.localhost.failure")) // not able to connect to localhost

	assert.Equal(t, float32(0.1), mockStatter.TimingSample("dns.localhost.success"))
	assert.Equal(t, float32(1), mockStatter.IncSample("dial.localhost.failure")) // not sampled
}

// TestDNSStatsHook drives the DNS trace hooks directly. It checks that a failed
// lookup is counted as "dns.localhost.failure" with sample rate 1, and that a
// coalesced lookup is counted as "dns.localhost.coalesced" with the given rate.
func TestDNSStatsHook(t *testing.T) {
	mockStatter := mocks.NewStatter()

	// Start a DNS trace for host "localhost".
	dns := &dnsData{}
	hookDNSStart(dns)(httptrace.DNSStartInfo{Host: "localhost"})

	// A lookup that did not succeed sends an (unsampled) failure stat.
	failed := httptrace.DNSDoneInfo{Err: errors.New("lul")}
	hookDNSDone(mockStatter, "dns", dns, float32(0.1))(failed)

	assert.Equal(t, int64(1), mockStatter.IncSum("dns.localhost.failure"))
	assert.Equal(t, float32(1), mockStatter.IncSample("dns.localhost.failure"))

	// A coalesced lookup emits a "coalesced" stat at the given sample rate.
	coalesced := httptrace.DNSDoneInfo{Coalesced: true}
	hookDNSDone(mockStatter, "dns", dns, float32(0.1))(coalesced)

	assert.Equal(t, int64(1), mockStatter.IncSum("dns.localhost.coalesced"))
	assert.Equal(t, float32(0.1), mockStatter.IncSample("dns.localhost.coalesced"))
}

// TestTLSStatsHook checks that a TLS handshake ending in an error is reported
// as one "tls_handshake.localhost.failure" increment with sample rate 1.
func TestTLSStatsHook(t *testing.T) {
	statter := mocks.NewStatter()
	dns := &dnsData{"localhost", time.Now()}
	handshake := &tlsData{}

	hookTLSHandshakeStart(handshake)()

	// Finish the handshake with an error so the failure stat is emitted.
	onDone := hookTLSHandshakeDone(statter, dns, handshake, float32(.1))
	onDone(tls.ConnectionState{}, errors.New("lulz"))

	const failureStat = "tls_handshake.localhost.failure"
	assert.Equal(t, int64(1), statter.IncSum(failureStat))
	assert.Equal(t, float32(1), statter.IncSample(failureStat))
}

// TestSanitizeStat pins how sanitizeStat rewrites host names and IP addresses:
// in the cases below, every '.' and ':' becomes '_'.
func TestSanitizeStat(t *testing.T) {
	cases := []struct{ in, want string }{
		{in: "example.com", want: "example_com"},
		{in: "123.145.120", want: "123_145_120"},
		{in: "::1", want: "__1"},
	}
	for _, tc := range cases {
		assert.Equal(t, tc.want, sanitizeStat(tc.in))
	}
}

// _benchmarkSanitizeStat sinks each result at package level so the compiler
// cannot discard the sanitizeStat call as dead code.
var _benchmarkSanitizeStat string

// BenchmarkSanitizeState measures sanitizeStat on a dotted IPv4 address and
// checks the last result once the timed loop is over.
// NOTE(review): "State" looks like a typo for "Stat"; the name is kept so
// existing -bench patterns keep matching.
func BenchmarkSanitizeState(b *testing.B) {
	const (
		input = "192.186.0.1"
		want  = "192_186_0_1"
	)

	for n := b.N; n > 0; n-- {
		_benchmarkSanitizeStat = sanitizeStat(input)
	}

	if got := _benchmarkSanitizeStat; got != want {
		b.Errorf("wanted %q, got %q", want, got)
	}
}

// _benchmarkStatKey lives at package level so the compiler cannot optimize the
// statKey call in the benchmark loop away.
var _benchmarkStatKey string

// BenchmarkStatKey measures statKey joining three already-sanitized parts and
// checks the final key after the timed loop.
func BenchmarkStatKey(b *testing.B) {
	const want = "dns.192_168_0_1.POST__graphql"
	host := "dns"
	ip := "192_168_0_1"
	route := "POST__graphql"

	for n := b.N; n > 0; n-- {
		_benchmarkStatKey = statKey(host, ip, route)
	}

	if got := _benchmarkStatKey; got != want {
		b.Errorf("wanted %q, got %q", want, got)
	}
}

// urlToLocalhost returns "https://localhost:<port>", where <port> is taken from
// hostURL. Tests use it to reach a local test server by host name instead of by
// IP (e.g. 127.0.0.1), so that the client performs a DNS lookup and emits DNS
// stats. The scheme is always https, regardless of hostURL's scheme.
//
// It panics if hostURL cannot be parsed.
func urlToLocalhost(hostURL string) string {
	parsedURL, err := url.Parse(hostURL)
	if err != nil {
		// Fail with context instead of the nil-pointer dereference an ignored
		// error would cause. err already quotes the offending URL.
		panic("urlToLocalhost: " + err.Error())
	}
	return "https://localhost:" + parsedURL.Port()
}

// failingReader is a reader that never produces data. Every Read returns zero
// bytes and a newly created "whoops" error, which makes a request using it as
// its body fail while the body is being written.
type failingReader struct{}

// Read always fails without touching the buffer.
func (*failingReader) Read([]byte) (int, error) {
	return 0, errors.New("whoops")
}
