/* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. */

package awsjsonv1_test

import (
	"CoralGoCodec/codec"
	"CoralRPCGoSupport/awsjsonv1"
	"CoralRPCGoSupport/internal/roundtrip"
	"CoralRPCGoSupport/internal/test/fake"
	"CoralRPCGoSupport/internal/test/fakemodel"
	"aaa"
	"bufio"
	"bytes"
	"io"
	"net/http"
	"net/http/httptest"
	"reflect"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/pkg/errors"
)

// Canned HTTP/1.1 messages used as fixtures. Responses are fed to RoundTrip
// through fake.Pipe; requests are parsed with buildRequest. Lines are
// LF-terminated, which net/http's readers accept.
const (
	// validHttpResponse is a 200 response whose body decodes to a FakeOutput
	// with output "1.0".
	validHttpResponse = `HTTP/1.1 200 OK
x-amzn-RequestId: 954ec7e5-f4df-11e4-a33e-535b6bee04a1
Content-Type: application/x-amz-json-1.0
Content-Length: 45
X-Amz-Date: Thu, 01 Mar 2012 22:07:13 GMT

{
    "__type": "FakeOutput","output":"1.0"
}`

	// validHttpResponseExt is a 200 response whose body is a
	// FakeOutputExtension carrying a nested "ext" extension with output "ext".
	validHttpResponseExt = `HTTP/1.1 200 OK
x-amzn-RequestId: 954ec7e5-f4df-11e4-a33e-535b6bee04a1
Content-Type: application/x-amz-json-1.0
Content-Length: 115
X-Amz-Date: Thu, 01 Mar 2012 22:07:13 GMT

{
    "__type": "FakeOutputExtension","output":"1.0",
    "ext": {"__type": "FakeOutputExtension","output":"ext"}
}`

	// emptyHttpResponse ends right after the last header line, with no blank
	// line terminating the header section. TestRoundTrip expects reading it
	// to fail with "unexpected EOF".
	emptyHttpResponse = `HTTP/1.1 200 OK
x-amzn-RequestId: 954ec7e5-f4df-11e4-a33e-535b6bee04a1
Content-Type: application/x-amz-json-1.0
Content-Length: 0
X-Amz-Date: Thu, 01 Mar 2012 22:07:13 GMT`

	// invalidJsonHttpResponse is a 200 response whose one-byte body "{" is
	// truncated JSON.
	invalidJsonHttpResponse = `HTTP/1.1 200 OK
x-amzn-RequestId: 954ec7e5-f4df-11e4-a33e-535b6bee04a1
Content-Type: application/x-amz-json-1.0
Content-Length: 1
Date: Thu, 07 May 2015 17:36:24 GMT

{`

	// non200HttpResponse is a 500 response carrying an otherwise valid
	// FakeOutput body.
	non200HttpResponse = `HTTP/1.1 500 Internal Server Error
x-amzn-RequestId: 954ec7e5-f4df-11e4-a33e-535b6bee04a1
Content-Type: application/x-amz-json-1.0
Content-Length: 45
X-Amz-Date: Thu, 01 Mar 2012 22:07:13 GMT

{
    "__type": "FakeOutput","output":"1.0"
}`

	// validHttpRequest is a complete request with date, target and a
	// FakeInput body containing one FakeRule and a FakeOther.
	validHttpRequest = `POST https://example.com:8000 HTTP/1.1
Content-Type: application/x-amz-json-1.0; charset=UTF-8
Content-Length: 232
X-Amz-Date: Thu, 01 Mar 2012 22:07:13 GMT
X-Amz-Target: TopListsService.IsAuthorized

{
   "__type": "FakeInput",
   "user": "example",
   "rules": [ {
      "__type": "FakeRule",
      "source": "LDAP", "identifier": "foo"}
   ],
   "other": {"__type": "FakeOther", "source": "turn_down", "identifier": "for_what"}
}
`
	// httpRequestEmpty has valid headers but an empty (zero-length) body.
	httpRequestEmpty = `POST https://example.com:8000 HTTP/1.1
Content-Type: application/x-amz-json-1.0; charset=UTF-8
Content-Length: 0
X-Amz-Date: Thu, 01 Mar 2012 22:07:13 GMT
X-Amz-Target: TopListsService.IsAuthorized

`

	// httpRequestNoTarget omits the X-Amz-Target header.
	httpRequestNoTarget = `POST https://example.com:8000 HTTP/1.1
Content-Type: application/x-amz-json-1.0; charset=UTF-8
Content-Length: 152
X-Amz-Date: Thu, 01 Mar 2012 22:07:13 GMT

{
   "__type": "FakeInput",
   "user": "example",
   "rules": [ {
         "__type": "FakeRule",
         "source": "LDAP", "identifier": "foo"}
   ]
}
`
	// httpRequestInvalidJson is missing the commas after the "__type" and
	// "user" members, so its body is not valid JSON.
	httpRequestInvalidJson = `POST https://example.com:8000 HTTP/1.1
Content-Type: application/x-amz-json-1.0
Content-Length: 115
X-Amz-Date: Thu, 01 Mar 2012 22:07:13 GMT
X-Amz-Target: TestService.IsAuthorized

{
   "__type": "FakeInput"
   "user": "example"
   "rules": [ {
      "__type": "FakeRule",
      "source": "LDAP", "identifier": "foo"}
   ]
}
`
)

// TestIsSupported checks which header sets IsSupported accepts. Per the table
// below, a request is supported only when its Content-Type is
// application/x-amz-json-1.0 (with or without parameters) and both X-Amz-Date
// and X-Amz-Target are non-empty.
func TestIsSupported(t *testing.T) {
	ct1 := "application/x-amz-json-1.0; charset=UTF-8;"
	ct2 := "application/x-amz-json-1.0"
	ct3 := "application/xml"
	dt := "2016"
	target := "asm.service.op"

	v1, err := awsjsonv1.New()
	if err != nil {
		t.Fatalf("Unable to create instance of awsjsonv1: %+v", err)
	}

	tests := []struct {
		name      string
		headers   http.Header
		supported bool
	}{
		{"matching request", headers(ct1, dt, target), true},
		{"no charset", headers(ct2, dt, target), true},
		{"nil", nil, false},
		{"empty", http.Header{}, false},
		{"missing content type", headers("", dt, target), false},
		{"missing date", headers(ct1, "", target), false},
		{"missing target", headers(ct1, dt, ""), false},
		{"bad content type", headers(ct3, dt, target), false},
	}

	for _, test := range tests {
		if got := v1.IsSupported(test.headers); got != test.supported {
			t.Error(test.name, "- Expected", test.supported, "for headers", test.headers)
		}
	}
}

// TestRoundTrip drives the client-side RoundTrip against canned HTTP responses
// served through fake.Pipe. Each case configures the fake AAA encode/decode
// errors and the output value placed on the request. It then asserts either
// the exact error text plus the roundtrip.Context attached to it, or the
// decoded output via validateOutput. Request IDs are deterministic: the
// generator returns "req_id_" + <case name>.
func TestRoundTrip(t *testing.T) {
	tests := map[string]struct {
		aaaDecodeErr     error
		aaaEncodeErr     error
		out              interface{}
		rw               io.ReadWriter
		wantError        error
		wantErrorContext func(request *codec.Request) roundtrip.Context
		validateOutput   func(t *testing.T, out interface{})
	}{
		// Success Cases

		// A 500 status with a well-formed body is expected to round-trip
		// without an error.
		"non_200_response": {
			out: fakemodel.NewFakeOutput(),
			rw:  &fake.Pipe{ReadBuf: bytes.NewBufferString(non200HttpResponse)},
		},
		"basic_success": {
			out: fakemodel.NewFakeOutput(),
			rw:  &fake.Pipe{ReadBuf: bytes.NewBufferString(validHttpResponse)},
			validateOutput: func(t *testing.T, out interface{}) {
				output, ok := out.(fakemodel.FakeOutput)
				if !ok {
					t.Errorf("output type mismatch: got: %T, want: FakeOutput", out)
					return
				}
				want := "1.0"
				if got := output.Output(); got != want {
					t.Errorf("output mismatch: got: %s, want: %s", got, want)
				}
			},
		},
		// Validate what happens when we get an extension to the output type.
		// The output is passed by pointer so RoundTrip can replace it with
		// the extension type.
		"basic_success_with_extension": {
			out: func() interface{} { out := fakemodel.NewFakeOutput(); return &out }(),
			rw:  &fake.Pipe{ReadBuf: bytes.NewBufferString(validHttpResponseExt)},
			validateOutput: func(t *testing.T, out interface{}) {
				decoded := reflect.ValueOf(out).Elem().Interface()
				outputExt, ok := decoded.(fakemodel.FakeOutputExtension)
				if !ok {
					// Report the type actually decoded, not the zero value of
					// the wanted type.
					t.Errorf("output type mismatch: got: %T, want: FakeOutputExtension", decoded)
					return
				}
				want := "ext"
				if got := outputExt.Ext().Output(); got != want {
					t.Errorf("output mismatch: got: %s, want: %s", got, want)
				}
			},
		},
		// Not all Clients specify an output for each input
		"no_output": {
			out: nil,
			rw:  &fake.Pipe{ReadBuf: bytes.NewBufferString(validHttpResponse)},
			validateOutput: func(t *testing.T, out interface{}) {
				if out != nil {
					t.Errorf("output mismatch: got: %#v, want: nil", out)
				}
			},
		},

		// Error Cases
		"AAA_encode_failure": {
			aaaEncodeErr: errors.New("test error"),
			out:          fakemodel.NewFakeOutput(),
			rw:           &fake.Pipe{ReadBuf: bytes.NewBufferString(validHttpResponse)},
			wantError:    errors.New("failed to AAA encode request: test error"),
			wantErrorContext: func(request *codec.Request) roundtrip.Context {
				return roundtrip.Context{Request: request}
			},
		},
		"AAA_decode_failure": {
			aaaDecodeErr: errors.New("test error"),
			out:          fakemodel.NewFakeOutput(),
			rw:           &fake.Pipe{ReadBuf: bytes.NewBufferString(validHttpResponse)},
			wantError:    errors.New("failed to AAA decode response: test error"),
			wantErrorContext: func(request *codec.Request) roundtrip.Context {
				return roundtrip.Context{
					Request:            request,
					RequestID:          "req_id_AAA_decode_failure",
					ResponseStatusCode: 200,
				}
			},
		},
		"empty_response": {
			out:       fakemodel.NewFakeOutput(),
			rw:        &fake.Pipe{ReadBuf: bytes.NewBufferString("")},
			wantError: errors.New("failed to read http response: unexpected EOF"),
			wantErrorContext: func(request *codec.Request) roundtrip.Context {
				return roundtrip.Context{
					Request:   request,
					RequestID: "req_id_empty_response",
				}
			},
		},
		"empty_body": {
			out:       fakemodel.NewFakeOutput(),
			rw:        &fake.Pipe{ReadBuf: bytes.NewBufferString(emptyHttpResponse)},
			wantError: errors.New("failed to read http response: unexpected EOF"),
			wantErrorContext: func(request *codec.Request) roundtrip.Context {
				return roundtrip.Context{
					Request:   request,
					RequestID: "req_id_empty_body",
				}
			},
		},
		"invalid_body": {
			out:       fakemodel.NewFakeOutput(),
			rw:        &fake.Pipe{ReadBuf: bytes.NewBufferString(invalidJsonHttpResponse)},
			wantError: errors.New("failed to convert response to a map: Failed call to json.Unmarshal: unexpected end of JSON input"),
			wantErrorContext: func(request *codec.Request) roundtrip.Context {
				return roundtrip.Context{
					Request:            request,
					RequestID:          "req_id_invalid_body",
					ResponseStatusCode: 200,
				}
			},
		},
	}

	for name, test := range tests {
		// Re-bind for the parallel closure (required before Go 1.22).
		name, test := name, test
		t.Run(name, func(t *testing.T) {
			t.Parallel()
			fakeAAA := &fake.AAA{DecodeErr: test.aaaDecodeErr, EncodeErr: test.aaaEncodeErr}
			rpc, errNew := awsjsonv1.New(awsjsonv1.WithAAAForClient(fakeAAA), awsjsonv1.WithRequestIdGenerator(func() (string, error) {
				return "req_id_" + name, nil
			}))
			if errNew != nil {
				t.Fatalf("Unable to create instance of awsjsonv1: %+v", errNew)
			}

			req := &codec.Request{
				Service:   fakemodel.TestService,
				Operation: fakemodel.TestOperation,
				Input:     fakemodel.NewFakeInput(),
				Output:    test.out,
			}

			err := rpc.RoundTrip(req, test.rw)
			if gotErr, wantErr := errorValue(err), errorValue(test.wantError); gotErr != wantErr {
				t.Errorf("error mismatch: got/want: \n%s\n%s", gotErr, wantErr)
			}

			if test.wantError != nil {
				rte, ok := err.(*roundtrip.Error)
				if !ok {
					t.Fatalf("error type mismatch: got: %T, want: *roundtrip.Error", err)
				}

				wantCtx := test.wantErrorContext(req)
				if diff := cmp.Diff(wantCtx, rte.Context); diff != "" {
					t.Errorf("context mismatch: (-expected, +found)\n%s", diff)
				}
			}

			if test.validateOutput != nil {
				test.validateOutput(t, test.out)
			}
		})
	}
}

// TestGetServiceAnOp checks how GetServiceAndOp splits a target string. The
// operation is the part after the last period and the service is everything
// before it. Values with no period, or ending in one, yield two empty strings.
func TestGetServiceAnOp(t *testing.T) {
	cases := []struct {
		name    string
		val     string
		service string
		op      string
	}{
		{name: "empty string"},
		{name: "no period", val: "service_op"},
		{name: "period at end", val: "service.op."},
		{name: "simple case", val: "service.op", service: "service", op: "op"},
		{name: "multiple periods", val: "s.e.r.v.i.c.e.op", service: "s.e.r.v.i.c.e", op: "op"},
	}

	for _, tc := range cases {
		gotService, gotOp := awsjsonv1.GetServiceAndOp(tc.val)
		if gotService == tc.service && gotOp == tc.op {
			continue
		}
		t.Error(tc.name, "- Expected (", tc.service, ",", tc.op,
			") found (", gotService, ",", gotOp, ")")
	}
}

// TestUnmarshalRequest parses validHttpRequest with a default awsjsonv1
// instance. It checks the resolved service/operation, that both Input and
// Output are populated, and that Input deep-equals the FakeInput the fixture
// encodes.
func TestUnmarshalRequest(t *testing.T) {
	rpc, err := awsjsonv1.New()
	if err != nil {
		t.Fatalf("Unable to create instance of awsjsonv1: %+v", err)
	}

	other := fakemodel.NewFakeOther()
	other.SetSource("turn_down")
	other.SetIdentifier("for_what")
	rule := fakemodel.NewFakeRule()
	rule.SetIdentifier("foo")
	rule.SetSource("LDAP")
	rules := []fakemodel.FakeRule{
		rule,
	}
	expectedInput := fakemodel.NewFakeInput()
	expectedInput.SetUser("example")
	expectedInput.SetRules(rules)
	expectedInput.SetOther(other)

	req := buildRequest(validHttpRequest, t)
	unmarshalled, err := rpc.UnmarshalRequest(req)
	if err != nil {
		t.Fatalf("Unexpected error %+v\nFor request %v", err, req)
	}
	// Checked independently so a service/operation mismatch does not hide a
	// missing Input/Output.
	if unmarshalled.Service != fakemodel.TestService || unmarshalled.Operation != fakemodel.TestOperation {
		t.Error("Expected", fakemodel.TestService, fakemodel.TestOperation, "but found", unmarshalled.Service, unmarshalled.Operation)
	}
	if unmarshalled.Input == nil || unmarshalled.Output == nil {
		t.Fatal("Expected both input and output to be present but found", unmarshalled.Input, unmarshalled.Output)
	}

	actualInput, ok := unmarshalled.Input.(fakemodel.FakeInput)
	if !ok {
		t.Fatalf("Expected input to be of type FakeInput, but found %T", unmarshalled.Input)
	}
	if diff := cmp.Diff(expectedInput, actualInput); diff != "" {
		t.Errorf("(-expected +found)\n%s", diff)
	}
}

// TestUnmarshalRequest_ARPS signs validHttpRequest with fake.Signer and
// checks that an instance configured with fake.ARPSAuthorizer unmarshals it
// without error.
func TestUnmarshalRequest_ARPS(t *testing.T) {
	rpc, err := awsjsonv1.New(awsjsonv1.WithARPSAuthorizer(&fake.ARPSAuthorizer))
	if err != nil {
		t.Fatalf("Unable to create instance of awsjsonv1: %+v", err)
	}

	req := buildRequest(validHttpRequest, t)
	// Stop here on a signing failure: unmarshalling an unsigned request would
	// only produce a misleading second error.
	if err = fake.Signer.Sign(req); err != nil {
		t.Fatal("Unexpected error", err)
	}
	if _, err = rpc.UnmarshalRequest(req); err != nil {
		t.Error("Unexpected error", err)
	}
}

// TestUnmarshalRequest_AAA parses validHttpRequest with a server-side fake AAA
// that authorizes every request. It checks the resolved service/operation,
// that Input and Output are populated, and that Input matches the fixture.
func TestUnmarshalRequest_AAA(t *testing.T) {
	fakeAAA := &fake.AAA{AuthResult: &aaa.AuthorizationResult{Authorized: true}}
	rpc, err := awsjsonv1.New(awsjsonv1.WithAAAForServer(fakeAAA))
	if err != nil {
		t.Fatalf("Unable to create instance of awsjsonv1: %+v", err)
	}

	rule := fakemodel.NewFakeRule()
	rule.SetIdentifier("foo")
	rule.SetSource("LDAP")
	rules := []fakemodel.FakeRule{
		rule,
	}
	other := fakemodel.NewFakeOther()
	other.SetSource("turn_down")
	other.SetIdentifier("for_what")
	expectedInput := fakemodel.NewFakeInput()
	expectedInput.SetUser("example")
	expectedInput.SetRules(rules)
	expectedInput.SetOther(other)
	req := buildRequest(validHttpRequest, t)

	request, err := rpc.UnmarshalRequest(req)
	if err != nil {
		t.Fatalf("Unexpected error: %+v", err)
	}
	if request.Service != fakemodel.TestService || request.Operation != fakemodel.TestOperation {
		t.Error("Expected", fakemodel.TestService, fakemodel.TestOperation, "but found", request.Service, request.Operation)
	}
	if request.Input == nil || request.Output == nil {
		t.Fatal("Expected both input and output to be present but found", request.Input, request.Output)
	}
	if in, ok := request.Input.(fakemodel.FakeInput); !ok {
		t.Errorf("Expected input to be of type FakeInput, but found %T", request.Input)
	} else if diff := cmp.Diff(expectedInput, in); diff != "" {
		t.Errorf("(-expected +found)\n%s", diff)
	}
}

// TestUnmarshalRequest_Error checks that UnmarshalRequest fails for each
// malformed request fixture. It also fails when the fake AAA reports an
// authorization error, a decode error, or an unauthorized result.
func TestUnmarshalRequest_Error(t *testing.T) {
	fakeAAA := &fake.AAA{AuthResult: &aaa.AuthorizationResult{Authorized: true}}
	rpc, err := awsjsonv1.New(awsjsonv1.WithAAAForServer(fakeAAA))
	if err != nil {
		t.Fatalf("Unable to create instance of awsjsonv1: %+v", err)
	}

	// Request error cases. Short names keep failure output readable instead
	// of echoing the whole raw request.
	requestCases := []struct {
		name string
		body string
	}{
		{"empty body", httpRequestEmpty},
		{"missing target", httpRequestNoTarget},
		{"invalid json", httpRequestInvalidJson},
	}
	for _, rc := range requestCases {
		assertUnmarshalRequestError(rpc, rc.name, rc.body, t)
	}

	// AAA error cases. rpc holds fakeAAA by pointer, so mutating it between
	// calls changes the behavior the next call observes.
	boom := errors.New("Boom")
	fakeAAA.AuthErr = boom
	assertUnmarshalRequestError(rpc, "authError", validHttpRequest, t)
	fakeAAA.AuthErr, fakeAAA.DecodeErr = nil, boom
	assertUnmarshalRequestError(rpc, "decodeError", validHttpRequest, t)
	fakeAAA.DecodeErr, fakeAAA.AuthResult = nil, &aaa.AuthorizationResult{Authorized: false}
	assertUnmarshalRequestError(rpc, "unauthorized", validHttpRequest, t)
}

// TestMarshalResponse checks the status code and body length MarshalResponse
// writes for several inputs: a nil request, requests missing the service or
// operation, a request with no output, an output of the wrong type, and a
// valid FakeOutput.
func TestMarshalResponse(t *testing.T) {
	rpc, err := awsjsonv1.New()
	if err != nil {
		t.Fatalf("Unable to create instance of awsjsonv1: %+v", err)
	}

	out := fakemodel.NewFakeOutput()
	out.SetOutput("I out all the puts")
	// The 26-byte bodies are the generic failure message: "Unable to process
	// request" is 25 characters, presumably followed by a trailing newline —
	// confirm against the error-writing code in awsjsonv1.
	tests := []struct {
		name       string
		req        *codec.Request
		statusCode int
		bodyLen    int
	}{
		{"nil request", nil, http.StatusInternalServerError, 26},
		{"empty operation", &codec.Request{Service: fakemodel.TestService}, http.StatusInternalServerError, 26},
		{"empty service", &codec.Request{Operation: fakemodel.TestOperation}, http.StatusInternalServerError, 26},
		{"empty output", &codec.Request{Service: fakemodel.TestService, Operation: fakemodel.TestOperation}, http.StatusNoContent, 0},
		{"invalid output type", &codec.Request{Service: fakemodel.TestService, Operation: fakemodel.TestOperation, Output: 0}, http.StatusInternalServerError, 27},
		{"valid", &codec.Request{Service: fakemodel.TestService, Operation: fakemodel.TestOperation, Output: out}, http.StatusOK, 64},
	}

	for _, test := range tests {
		w := httptest.NewRecorder()
		rpc.MarshalResponse(w, test.req)
		if w.Code != test.statusCode {
			t.Error(test.name, "- Expected status code", test.statusCode, "but found", w.Code)
		}
		if w.Body.Len() != test.bodyLen {
			t.Error(test.name, "- Expected body of length", test.bodyLen, "but found", w.Body.Len(), ":\n", w.Body.String())
		}
	}
}

// TestMarshalResponse_AAA repeats the TestMarshalResponse cases with a
// server-side fake AAA and an AuthCtx on each request. It adds one case where
// the AAA encode step fails, which is expected to yield a 500.
func TestMarshalResponse_AAA(t *testing.T) {
	fakeAAA := &fake.AAA{}
	rpc, err := awsjsonv1.New(awsjsonv1.WithAAAForServer(fakeAAA))
	if err != nil {
		t.Fatalf("Unable to create instance of awsjsonv1: %+v", err)
	}
	sctx := &aaa.ServiceContext{Service: fakemodel.TestService.ShapeName, Operation: fakemodel.TestOperation.ShapeName}

	out := fakemodel.NewFakeOutput()
	out.SetOutput("I out all the puts")
	// The 26-byte bodies are the generic failure message: "Unable to process
	// request" is 25 characters, presumably followed by a trailing newline —
	// confirm against the error-writing code in awsjsonv1.
	tests := []struct {
		name       string
		req        *codec.Request
		aaaErr     error
		statusCode int
		bodyLen    int
	}{
		{"nil request", nil, nil, http.StatusInternalServerError, 26},
		{"empty operation", &codec.Request{Service: fakemodel.TestService, AuthCtx: sctx}, nil, http.StatusInternalServerError, 26},
		{"empty service", &codec.Request{Operation: fakemodel.TestOperation, AuthCtx: sctx}, nil, http.StatusInternalServerError, 26},
		{"empty output", &codec.Request{Service: fakemodel.TestService, Operation: fakemodel.TestOperation, AuthCtx: sctx},
			nil, http.StatusNoContent, 0,
		},
		{"invalid output type", &codec.Request{Service: fakemodel.TestService, Operation: fakemodel.TestOperation, AuthCtx: sctx, Output: 0},
			nil, http.StatusInternalServerError, 27,
		},
		{"aaa encode error", &codec.Request{Service: fakemodel.TestService, Operation: fakemodel.TestOperation, AuthCtx: sctx, Output: out},
			errors.New("Boom"), http.StatusInternalServerError, 26,
		},
		{"valid", &codec.Request{Service: fakemodel.TestService, Operation: fakemodel.TestOperation, AuthCtx: sctx, Output: out},
			nil, http.StatusOK, 64,
		},
	}

	for _, test := range tests {
		// rpc holds fakeAAA by pointer; set the per-case encode error before
		// marshalling. Cases run sequentially, so this is not racy.
		fakeAAA.EncodeErr = test.aaaErr
		w := httptest.NewRecorder()
		rpc.MarshalResponse(w, test.req)
		if w.Code != test.statusCode {
			t.Error(test.name, "- Expected status code", test.statusCode, "but found", w.Code)
		}
		if w.Body.Len() != test.bodyLen {
			t.Error(test.name, "- Expected body of length", test.bodyLen, "but found", w.Body.Len(), ":\n", w.Body.String())
		}
	}
}

// Utilities

// assertUnmarshalRequestError parses reqBody into an *http.Request and
// reports a test error when rpc.UnmarshalRequest succeeds on it. name
// identifies the case in the failure message.
func assertUnmarshalRequestError(rpc awsjsonv1.AWSJSONv1, name, reqBody string, t *testing.T) {
	t.Helper()
	req := buildRequest(reqBody, t)
	if request, err := rpc.UnmarshalRequest(req); err == nil {
		// Errorf rather than Error: Error's operand spacing produced a double
		// space after "for".
		t.Errorf("Expected error for %s but found %v", name, request)
	}
}

// headers returns an http.Header holding exactly one value each for
// Content-Type, X-Amz-Date and X-Amz-Target. An empty argument still creates
// the key, with an empty value. The keys are already in canonical form, so
// the literal matches what Header.Set would store.
func headers(ct, dt, tar string) http.Header {
	return http.Header{
		"Content-Type": []string{ct},
		"X-Amz-Date":   []string{dt},
		"X-Amz-Target": []string{tar},
	}
}

// errorValue returns "<nil>" if the error is nil and output of the error's Error() if non-nil.
// errorValue renders err for comparison in failure messages. A nil error
// becomes the literal "<nil>"; any other error is rendered by its Error method.
func errorValue(err error) string {
	if err != nil {
		return err.Error()
	}
	return "<nil>"
}

// buildRequest uses http.ReadRequest to build an http.Request.  If ReadRequest reports an
// error then t.Error is called and nil is returned.
// buildRequest parses reqBody, a raw HTTP/1.x request, with http.ReadRequest.
// If parsing fails, the test is stopped immediately via t.Fatalf, so callers
// never receive a nil *http.Request. The body is not consumed; it is read
// lazily from reqBody by whoever reads req.Body.
func buildRequest(reqBody string, t *testing.T) *http.Request {
	t.Helper()
	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(reqBody)))
	if err != nil {
		t.Fatalf("Unable to parse %q\n%v", reqBody, err)
	}
	return req
}
