package spade_test

import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"reflect"
	"testing"
	"time"

	"code.justin.tv/common/spade-client-go/spade"
)

// newClientTest is a single TestNewClient case: the options handed to
// spade.NewClient and whether construction is expected to fail.
type newClientTest struct {
	opts        []spade.InitFunc
	expectError bool
}

// TestNewClient checks that spade.NewClient accepts valid option sets and
// rejects invalid ones: non-positive max concurrency or batch size, an empty
// base URL, and supplying both an HTTP client and an HTTP client func.
func TestNewClient(t *testing.T) {
	tests := []newClientTest{
		{
			opts:        []spade.InitFunc{},
			expectError: false,
		},
		{
			opts:        []spade.InitFunc{spade.InitMaxConcurrency(0)},
			expectError: true,
		},
		{
			opts:        []spade.InitFunc{spade.InitMaxConcurrency(-10)},
			expectError: true,
		},
		{
			opts:        []spade.InitFunc{spade.InitMaxConcurrency(10)},
			expectError: false,
		},
		{
			opts:        []spade.InitFunc{spade.InitBatchSize(10)},
			expectError: false,
		},
		{
			opts:        []spade.InitFunc{spade.InitBatchSize(0)},
			expectError: true,
		},
		{
			opts:        []spade.InitFunc{spade.InitBatchSize(-10)},
			expectError: true,
		},
		{
			opts: []spade.InitFunc{
				spade.InitBaseURL(url.URL{
					Host: "valid_host",
				}),
			},
			expectError: false,
		},
		{
			opts: []spade.InitFunc{
				spade.InitBaseURL(url.URL{}),
			},
			expectError: true,
		},
		{
			opts: []spade.InitFunc{
				spade.InitHTTPClient(http.DefaultClient),
				spade.InitHTTPClientFunc(func(_ context.Context) *http.Client {
					return http.DefaultClient
				}),
			},
			expectError: true,
		},
	}
	for i, test := range tests {
		_, err := spade.NewClient(test.opts...)
		// Report the case index and the error so a failure is diagnosable;
		// a bare t.Fail() gives no indication of which case broke.
		switch {
		case test.expectError && err == nil:
			t.Errorf("case %d: expected error from NewClient but received none", i)
		case !test.expectError && err != nil:
			t.Errorf("case %d: unexpected error from NewClient: %s", i, err)
		}
	}
}

// urlValuesTest is a single TestURLValues case.
type urlValuesTest struct {
	// payload is the event name and properties to encode and then decode.
	payload payload
	// NOTE(review): success is not read by TestURLValues in this file;
	// confirm whether it is still needed.
	success bool
}

// testProperties is the property set attached to test events. Its JSON tags
// set the keys used when the properties are marshaled and when parsePayload
// unmarshals them back. KeySlice and KeyMap are used by the validation tests,
// which expect a client to reject slice/map property values unless it was
// created with spade.InitNoValidation.
type testProperties struct {
	KeyS     string            `json:"key_s"`
	KeyN     int               `json:"key_n"`
	KeySlice []string          `json:"key_slice"`
	KeyMap   map[string]string `json:"key_map"`
}

// stringPointer returns a pointer to a newly allocated copy of s; each call
// yields a distinct pointer.
// NOTE(review): not referenced elsewhere in this chunk.
func stringPointer(s string) *string {
	ptr := new(string)
	*ptr = s
	return ptr
}

// TestURLValues checks that spade.URLValues puts an event into the "data"
// value in a form that parsePayload can decode back into the original event
// name and properties.
func TestURLValues(t *testing.T) {
	tests := []urlValuesTest{
		{
			payload: payload{
				Event: "some_event",
				Props: testProperties{
					KeyS: "value",
					KeyN: 200,
				},
			},
		},
	}
	for i, test := range tests {
		v, err := spade.URLValues(test.payload.Event, test.payload.Props)
		if err != nil {
			t.Errorf("case %d: error creating url values: %s", i, err)
			continue
		}
		output, err := parsePayload(v.Get("data"))
		if err != nil {
			// output is nil here; comparing it would dereference nil and panic.
			t.Errorf("case %d: error parsing payload: %s", i, err)
			continue
		}
		if !payloadsMatch(output, &test.payload) {
			t.Errorf("case %d: url value payload mismatch: got %+v, want %+v", i, *output, test.payload)
		}
	}
}

// payload is the JSON object that parsePayload decodes from a "data" value and
// that parsePayloadArray decodes (as an array) from a batched "data" value:
// an event name plus its properties.
type payload struct {
	Event string         `json:"event"`
	Props testProperties `json:"properties"`
}

// payloadsToEvents converts test payloads into spade.Event values, one per
// payload and in the same order.
func payloadsToEvents(p []payload) []spade.Event {
	out := make([]spade.Event, len(p))
	for idx := range p {
		out[idx].Name = p[idx].Event
		out[idx].Properties = p[idx].Props
	}
	return out
}

// payloadsMatch reports whether two payloads have the same event name and
// deeply-equal properties. Two nil pointers match; a nil pointer never
// matches a non-nil one. Handling nil lets callers pass a failed parse result
// without panicking.
func payloadsMatch(a *payload, b *payload) bool {
	if a == nil || b == nil {
		return a == b
	}
	return a.Event == b.Event && reflect.DeepEqual(a.Props, b.Props)
}

// payloadsArrayMatch reports whether a and b have the same length and every
// pair of elements at the same index matches according to payloadsMatch.
func payloadsArrayMatch(a []payload, b []payload) bool {
	// Without this check, a shorter b panics on b[i] and a longer b has its
	// extra elements silently ignored.
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if !payloadsMatch(&a[i], &b[i]) {
			return false
		}
	}
	return true
}

// parsePayload decodes a single event from data, which must be padded
// URL-safe base64 (base64.URLEncoding) wrapping a JSON object shaped like
// payload. On failure it returns a nil payload and an error naming the stage
// (base64 or JSON) that failed.
func parsePayload(data string) (*payload, error) {
	raw, decodeErr := base64.URLEncoding.DecodeString(data)
	if decodeErr != nil {
		return nil, fmt.Errorf("base64 decoding: %v", decodeErr)
	}
	var out payload
	if jsonErr := json.Unmarshal(raw, &out); jsonErr != nil {
		return nil, fmt.Errorf("JSON unmarshaling: %v", jsonErr)
	}
	return &out, nil
}

// parsePayloadArray decodes a batch of events from data, which must be padded
// URL-safe base64 (base64.URLEncoding) wrapping a JSON array of payload
// objects. On failure it returns nil and an error naming the stage (base64 or
// JSON) that failed.
func parsePayloadArray(data string) (*[]payload, error) {
	raw, decodeErr := base64.URLEncoding.DecodeString(data)
	if decodeErr != nil {
		return nil, fmt.Errorf("base64 decoding: %v", decodeErr)
	}
	batch := []payload{}
	if jsonErr := json.Unmarshal(raw, &batch); jsonErr != nil {
		return nil, fmt.Errorf("JSON unmarshaling: %v", jsonErr)
	}
	return &batch, nil
}

// sharedURL returns the base URL used by the TrackEvent/TrackEvents tests,
// https://some_host/, as a fresh value on every call.
func sharedURL() url.URL {
	var u url.URL
	u.Scheme = "https"
	u.Host = "some_host"
	u.Path = "/"
	return u
}

// trackEventTest is a single TestTrackEvent case: the base URL and event to
// send, the canned response (status code or transport error) the mock
// transport returns, and whether TrackEvent is expected to fail.
type trackEventTest struct {
	url     url.URL
	payload payload

	responseStatusCode int
	responseError      error
	expectError        bool
}

// TestTrackEvent sends one event through a client backed by a mock transport.
// It checks the returned error, the request URL, the form-encoded "data"
// value, and that the stat hook ran exactly once.
func TestTrackEvent(t *testing.T) {
	baseURL := sharedURL()
	sharedPayload := payload{
		Event: "some_event",
		Props: testProperties{
			KeyS: "some string",
			KeyN: 130,
		},
	}
	tests := []trackEventTest{
		{
			url:                baseURL,
			payload:            sharedPayload,
			responseStatusCode: http.StatusNoContent,
		},
		{
			url:                baseURL,
			payload:            sharedPayload,
			responseStatusCode: http.StatusNotFound,
			expectError:        true,
		},
		{
			url:                baseURL,
			payload:            sharedPayload,
			responseStatusCode: http.StatusNoContent,
			responseError:      errors.New("any error"),
			expectError:        true,
		},
	}

	for i, test := range tests {
		statHookExecutions := 0
		mrt := newMockRoundTripper(test.responseError, test.responseStatusCode)
		c, err := spade.NewClient(
			spade.InitHTTPClient(&http.Client{
				Transport: mrt,
			}),
			spade.InitBaseURL(test.url),
			spade.InitStatHook(func(string, int, time.Duration) {
				statHookExecutions++
			}),
		)
		if err != nil {
			// The client is unusable; calling TrackEvent on it would panic.
			t.Errorf("case %d: unexpected error creating spade client: %s", i, err)
			continue
		}
		err = c.TrackEvent(context.TODO(), test.payload.Event, test.payload.Props)
		switch {
		case test.expectError && err == nil:
			t.Errorf("case %d: expected error but received none", i)
		case !test.expectError && err != nil:
			t.Errorf("case %d: unexpected error tracking event: %s", i, err)
		}
		if mrt.Request == nil {
			t.Errorf("case %d: no http request attempted", i)
			continue
		}
		if !reflect.DeepEqual(mrt.Request.URL, &test.url) {
			t.Errorf("case %d: invalid request url %v vs %v", i, mrt.Request.URL, &test.url)
		}
		output, err := parsePayload(mrt.Request.PostFormValue("data"))
		if err != nil {
			// output is nil here, so skip the payload comparison.
			t.Errorf("case %d: error parsing payload (%v): %v", i, mrt.Request, err)
		} else if !payloadsMatch(output, &test.payload) {
			t.Errorf("case %d: url value payload mismatch", i)
		}
		if statHookExecutions != 1 {
			t.Errorf("case %d: expected 1 stat hook call, got %d", i, statHookExecutions)
		}
	}
}

// TestBadTrackEvent checks that TrackEvent returns an error when the
// properties argument is a slice rather than a single properties value. It
// also checks that no request is sent and the stat hook is not invoked.
func TestBadTrackEvent(t *testing.T) {
	statHookExecutions := 0
	// No canned response: the test expects RoundTrip never to be reached.
	mrt := &mockRoundTripper{Error: nil, Response: nil}
	c, err := spade.NewClient(
		spade.InitHTTPClient(&http.Client{
			Transport: mrt,
		}),
		spade.InitBaseURL(sharedURL()),
		spade.InitStatHook(func(string, int, time.Duration) {
			statHookExecutions++
		}),
	)
	if err != nil {
		// Continuing would call TrackEvent on an unusable client.
		t.Fatalf("unexpected error creating spade client: %s", err)
	}
	err = c.TrackEvent(context.TODO(), "some_event", []testProperties{{KeyS: "some string", KeyN: 200}})
	if err == nil {
		t.Errorf("expected error but received none")
	}
	if mrt.Request != nil {
		t.Errorf("http request attempted")
	}
	if statHookExecutions != 0 {
		t.Errorf("expected no stat hook calls, got %d", statHookExecutions)
	}
}

// TestTrackEventValidation checks that TrackEvent rejects events whose
// properties contain a slice or map value before any request is made. It then
// checks that the same events are sent when the client is created with
// spade.InitNoValidation.
func TestTrackEventValidation(t *testing.T) {
	payloads := []payload{
		{
			Event: "some_event",
			Props: testProperties{
				KeySlice: []string{"abcd"},
			},
		},
		{
			Event: "some_event",
			Props: testProperties{
				KeyMap: map[string]string{"abcd": "abcd"},
			},
		},
	}
	for i, p := range payloads {
		mrt := newMockRoundTripper(nil, 0)
		c, err := spade.NewClient(
			spade.InitHTTPClient(&http.Client{
				Transport: mrt,
			}),
		)
		if err != nil {
			// Continuing would call TrackEvent on an unusable client.
			t.Fatalf("case %d: unexpected error creating spade client: %s", i, err)
		}
		err = c.TrackEvent(context.TODO(), p.Event, p.Props)
		if err == nil {
			t.Errorf("case %d: no error tracking invalid event", i)
		}
		if mrt.Called {
			t.Errorf("case %d: unexpectedly called round tripper on validation error", i)
		}
	}
	for i, p := range payloads {
		mrt := newMockRoundTripper(nil, http.StatusNoContent)
		c, err := spade.NewClient(
			spade.InitHTTPClient(&http.Client{
				Transport: mrt,
			}),
			spade.InitNoValidation(),
		)
		if err != nil {
			t.Fatalf("case %d: unexpected error creating spade client: %s", i, err)
		}
		err = c.TrackEvent(context.TODO(), p.Event, p.Props)
		if err != nil {
			t.Errorf("case %d: error tracking with no validation: %s", i, err)
		}
		if !mrt.Called {
			t.Errorf("case %d: roundtripper not called with no validation", i)
		}
	}
}

// trackEventsTest is a single TestTrackEvents case: the base URL and events to
// send, the canned response (status code or transport error) the mock
// transport returns for every request, and whether TrackEvents is expected to
// fail.
type trackEventsTest struct {
	url     url.URL
	payload []payload

	responseStatusCode int
	responseError      error
	expectError        bool
}

// TestTrackEvents sends three events through a client with a batch size of
// two. It checks the returned error, the request URL, the contents of each
// batch, and the number of stat hook calls. On success it expects two batches
// (the first batchSize events, then the rest). When the mock fails requests,
// it expects only one stat hook call.
func TestTrackEvents(t *testing.T) {
	// batchSize is passed to spade.InitBatchSize and used to split the
	// expected batches below, so the two cannot drift apart.
	const batchSize = 2

	baseURL := sharedURL()
	sharedPayload := []payload{
		{
			Event: "some_event",
			Props: testProperties{
				KeyS: "some string",
				KeyN: 130,
			},
		},
		{
			Event: "event_numero_two",
			Props: testProperties{
				KeyS: "some other string",
				KeyN: 420,
			},
		},
		{
			Event: "event_number_tree",
			Props: testProperties{
				KeyS: "another string",
				KeyN: 69,
			},
		},
	}
	tests := []trackEventsTest{
		{
			url:                baseURL,
			payload:            sharedPayload,
			responseStatusCode: http.StatusNoContent,
			responseError:      nil,
			expectError:        false,
		},
		{
			url:                baseURL,
			payload:            sharedPayload,
			responseStatusCode: http.StatusNotFound,
			responseError:      nil,
			expectError:        true,
		},
		{
			url:                baseURL,
			payload:            sharedPayload,
			responseStatusCode: http.StatusNoContent,
			responseError:      errors.New("any error"),
			expectError:        true,
		},
	}

	for i, test := range tests {
		statHookExecutions := 0
		mrt := newMockRoundTripper(test.responseError, test.responseStatusCode)
		c, err := spade.NewClient(
			spade.InitHTTPClient(&http.Client{
				Transport: mrt,
			}),
			spade.InitBaseURL(test.url),
			spade.InitStatHook(func(string, int, time.Duration) {
				statHookExecutions++
			}),
			spade.InitBatchSize(batchSize),
		)
		if err != nil {
			// The client is unusable; calling TrackEvents on it would panic.
			t.Errorf("case %d: unexpected error creating spade client: %s", i, err)
			continue
		}
		events := payloadsToEvents(test.payload)
		err = c.TrackEvents(context.TODO(), events...)
		switch {
		case test.expectError && err == nil:
			t.Errorf("case %d: expected error but received none", i)
		case !test.expectError && err != nil:
			t.Errorf("case %d: unexpected error tracking event: %s", i, err)
		}
		if mrt.Request == nil {
			t.Errorf("case %d: no http request attempted", i)
			continue
		}
		// mrt.Request is the most recent request the mock received.
		if !reflect.DeepEqual(mrt.Request.URL, &test.url) {
			t.Errorf("case %d: invalid request url %v vs %v", i, mrt.Request.URL, &test.url)
		}

		// mrt.Request != nil implies mrt.Requests has at least one entry.
		firstBatch, err := parsePayloadArray(mrt.Requests[0].PostFormValue("data"))
		if err != nil {
			t.Errorf("case %d: error parsing first batch: %s", i, err)
		} else if !payloadsArrayMatch(test.payload[:batchSize], *firstBatch) {
			t.Errorf("case %d: first batch payload mismatch", i)
		}

		if test.expectError {
			// The mock fails every request; the test expects the client to stop
			// after the first batch, i.e. a single stat hook call.
			if statHookExecutions != 1 {
				t.Errorf("case %d: expected 1 stat hook call, got %d", i, statHookExecutions)
			}
			continue
		}

		if statHookExecutions != 2 {
			t.Errorf("case %d: expected 2 stat hook calls, got %d", i, statHookExecutions)
		}
		// Guard the index so a missing second request is reported as a test
		// failure instead of panicking the test binary.
		if len(mrt.Requests) < 2 {
			t.Errorf("case %d: expected 2 http requests, got %d", i, len(mrt.Requests))
			continue
		}
		// With three events and batchSize 2, the remaining events fit in one batch.
		secondBatch, err := parsePayloadArray(mrt.Requests[1].PostFormValue("data"))
		if err != nil {
			t.Errorf("case %d: error parsing second batch: %s", i, err)
		} else if !payloadsArrayMatch(test.payload[batchSize:], *secondBatch) {
			t.Errorf("case %d: second batch payload mismatch", i)
		}
	}
}

// TestTrackEventsValidation is the TrackEvents counterpart of
// TestTrackEventValidation. Batches containing an event with a slice or map
// property value must be rejected before any request is made. They must be
// sent when the client is created with spade.InitNoValidation.
func TestTrackEventsValidation(t *testing.T) {
	payloads := [][]payload{
		{
			{
				Event: "some_event",
				Props: testProperties{
					KeySlice: []string{"abcd"},
				},
			},
		}, {
			{
				Event: "some_event",
				Props: testProperties{
					KeyMap: map[string]string{"abcd": "abcd"},
				},
			},
		},
	}
	for i, p := range payloads {
		mrt := newMockRoundTripper(nil, 0)
		c, err := spade.NewClient(
			spade.InitHTTPClient(&http.Client{
				Transport: mrt,
			}),
		)
		if err != nil {
			// Continuing would call TrackEvents on an unusable client.
			t.Fatalf("case %d: unexpected error creating spade client: %s", i, err)
		}
		events := payloadsToEvents(p)
		err = c.TrackEvents(context.TODO(), events...)
		if err == nil {
			t.Errorf("case %d: no error tracking invalid event", i)
		}
		if mrt.Called {
			t.Errorf("case %d: unexpectedly called round tripper on validation error", i)
		}
	}
	for i, p := range payloads {
		mrt := newMockRoundTripper(nil, http.StatusNoContent)
		c, err := spade.NewClient(
			spade.InitHTTPClient(&http.Client{
				Transport: mrt,
			}),
			spade.InitNoValidation(),
		)
		if err != nil {
			t.Fatalf("case %d: unexpected error creating spade client: %s", i, err)
		}
		events := payloadsToEvents(p)
		err = c.TrackEvents(context.TODO(), events...)
		if err != nil {
			t.Errorf("case %d: error tracking with no validation: %s", i, err)
		}
		if !mrt.Called {
			t.Errorf("case %d: roundtripper not called with no validation", i)
		}
	}
}

// mockRoundTripper is an http.RoundTripper test double. It records every
// request it receives and answers each one with the same canned Response and
// Error.
type mockRoundTripper struct {
	// Request is the most recent request received (nil before the first call).
	Request *http.Request
	// Requests holds every request received, in call order.
	Requests []*http.Request
	// Response is returned from every RoundTrip call.
	Response *http.Response
	// Error is returned from every RoundTrip call.
	Error error
	// Called is set once RoundTrip has been invoked at least once.
	Called bool
}

var _ http.RoundTripper = (*mockRoundTripper)(nil)

// newMockRoundTripper builds a mock that fails every request with err when
// err is non-nil (and then carries no Response). Otherwise it answers every
// request with a response whose status is statusCode and whose body is empty.
func newMockRoundTripper(err error, statusCode int) *mockRoundTripper {
	if err != nil {
		return &mockRoundTripper{Error: err}
	}
	return &mockRoundTripper{
		Response: &http.Response{
			StatusCode: statusCode,
			Body:       ioutil.NopCloser(bytes.NewBufferString("")),
		},
	}
}

// RoundTrip records req and returns the configured Response and Error.
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	m.Called = true
	m.Requests = append(m.Requests, req)
	m.Request = req
	return m.Response, m.Error
}
