//nolint:bodyclose
package oidc

import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"math/big"
	"net/http"
	"net/url"
	"testing"
	"time"

	"code.justin.tv/amzn/TwitchS2S2/c7s"
	"code.justin.tv/amzn/TwitchS2S2/internal/oidc/mocks"
	"code.justin.tv/amzn/TwitchS2SJWTAlgorithms/es256"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// TestOIDC exercises OIDC.init (discovery document fetch), OIDC.Configuration,
// and OIDC.ValidationKeys (JWKS fetch) against a mocked HTTP client. Every
// subtest builds a fresh fixture and verifies the mock's expectations on exit.
func TestOIDC(t *testing.T) {
	ctx := context.Background()

	t.Run("init", func(t *testing.T) {
		// onGet registers an expectation for a GET against the configured
		// discovery endpoint; callers attach the canned response.
		onGet := func(ot *oidcTest) *mock.Call {
			return ot.Client.
				On("Do", mock.MatchedBy(func(req *http.Request) bool {
					return req.Method == "GET" && req.URL.String() == ot.Config.DiscoveryEndpoint
				}))
		}

		t.Run("success", func(t *testing.T) {
			ot := newOIDCTest()
			defer ot.Teardown(t)

			cfg := Configuration{
				IssuerID:              "ISSUERID",
				AuthorizationEndpoint: "AUTHORIZATIONENDPOINT",
				TokenEndpoint:         "TOKENENDPOINT",
				JwksURI:               "JWKSURI",
			}

			onGet(ot).
				Return(ot.JSONResponse(t, cfg), nil).
				Once()

			require.NoError(t, ot.OIDC.init(ctx))
			// testify convention: expected value first, actual second.
			assert.Equal(t, &cfg, ot.OIDC.oidcConfig)
		})

		t.Run("get error", func(t *testing.T) {
			ot := newOIDCTest()
			defer ot.Teardown(t)

			getErr := errors.New("get error")

			onGet(ot).
				Return(nil, getErr).
				Once()

			assert.Equal(t, getErr, ot.OIDC.init(ctx))
		})

		t.Run("status code incorrect", func(t *testing.T) {
			ot := newOIDCTest()
			defer ot.Teardown(t)

			httpResponse := &http.Response{
				Body:       ioutil.NopCloser(bytes.NewBufferString("some error")),
				Header:     make(http.Header),
				StatusCode: http.StatusInternalServerError,
			}

			onGet(ot).Return(httpResponse, nil)

			assert.Equal(t, fmt.Errorf("error reading from discovery endpoint: some error"), ot.OIDC.init(ctx))
		})

		t.Run("invalid json", func(t *testing.T) {
			ot := newOIDCTest()
			defer ot.Teardown(t)

			// "issuer" is encoded as an array so decoding into Configuration
			// fails with a type error.
			onGet(ot).
				Return(ot.JSONResponse(t, struct {
					IssuerID []string `json:"issuer"`
				}{IssuerID: []string{"wrongtype"}}), nil).
				Once()

			assert.IsType(t, &json.UnmarshalTypeError{}, ot.OIDC.init(ctx))
		})
	})

	t.Run("Configuration", func(t *testing.T) {
		t.Run("success", func(t *testing.T) {
			ot := newOIDCTest()
			defer ot.Teardown(t)

			cfg := Configuration{
				IssuerID:              "ISSUERID",
				AuthorizationEndpoint: "AUTHORIZATIONENDPOINT",
				TokenEndpoint:         "TOKENENDPOINT",
				JwksURI:               "JWKSURI",
			}

			ot.OIDC.oidcConfig = &cfg

			assert.Equal(t, &cfg, ot.OIDC.Configuration())
		})
	})

	t.Run("ValidationKeys", func(t *testing.T) {
		const keyID = "KEYID"
		const jwksURI = "https://my.jwk/uri"

		key := &es256.PublicKey{PublicKey: &ecdsa.PublicKey{Curve: elliptic.P256(), X: big.NewInt(11), Y: big.NewInt(13)}, KeyID: keyID}

		// onGet seeds a discovered configuration pointing at jwksURI and
		// registers an expectation for a GET against that URI.
		onGet := func(ot *oidcTest) *mock.Call {
			ot.OIDC.oidcConfig = &Configuration{JwksURI: jwksURI}
			return ot.Client.
				On("Do", mock.MatchedBy(func(req *http.Request) bool {
					return req.Method == "GET" && req.URL.String() == jwksURI
				}))
		}

		t.Run("success", func(t *testing.T) {
			ot := newOIDCTest()
			defer ot.Teardown(t)

			httpResponse := ot.JSONResponse(t, struct {
				Keys []*es256.PublicKey `json:"keys"`
			}{Keys: []*es256.PublicKey{key}})
			httpResponse.Header.Set("Cache-Control", "max-age=3600")

			onGet(ot).Return(httpResponse, nil)

			res, dt, err := ot.OIDC.ValidationKeys(ctx)
			require.NoError(t, err)
			assert.Equal(t, map[string]ValidationKey{keyID: key}, res)
			// max-age=3600 seconds is expected to surface as one hour.
			assert.Equal(t, time.Hour, dt)
		})

		t.Run("get error", func(t *testing.T) {
			ot := newOIDCTest()
			defer ot.Teardown(t)

			httpErr := errors.New("HTTPERR")
			onGet(ot).Return(nil, httpErr)

			_, _, err := ot.OIDC.ValidationKeys(ctx)
			assert.Equal(t, httpErr, err)
		})

		t.Run("status code incorrect", func(t *testing.T) {
			ot := newOIDCTest()
			defer ot.Teardown(t)

			httpResponse := &http.Response{
				Body:       ioutil.NopCloser(bytes.NewBufferString("some error")),
				Header:     make(http.Header),
				StatusCode: http.StatusInternalServerError,
			}

			onGet(ot).Return(httpResponse, nil)

			_, _, err := ot.OIDC.ValidationKeys(ctx)
			assert.Equal(t, fmt.Errorf("error getting validation keys: some error"), err)
		})

		t.Run("cache control parse error", func(t *testing.T) {
			ot := newOIDCTest()
			defer ot.Teardown(t)

			httpResponse := ot.JSONResponse(t, struct {
				Keys []*es256.PublicKey `json:"keys"`
			}{Keys: []*es256.PublicKey{key}})
			httpResponse.Header.Set("Cache-Control", "max-age=infinity")

			onGet(ot).Return(httpResponse, nil)

			_, _, err := ot.OIDC.ValidationKeys(ctx)
			assert.Equal(t, errInvalidCacheControlHeader, err)
		})

		t.Run("json decode error", func(t *testing.T) {
			ot := newOIDCTest()
			defer ot.Teardown(t)

			// "keys" is a string rather than an array, forcing a decode type error.
			httpResponse := ot.JSONResponse(t, struct {
				Keys string `json:"keys"`
			}{Keys: "wrongformat"})
			httpResponse.Header.Set("Cache-Control", "max-age=3600")

			onGet(ot).Return(httpResponse, nil)

			_, _, err := ot.OIDC.ValidationKeys(ctx)
			assert.IsType(t, &json.UnmarshalTypeError{}, err)
		})
	})
}

// newOIDCTest returns a fresh fixture: a new mock HTTP client, a config with
// issuer "https://my.issuer" and discovery endpoint
// "https://my.issuer/my/discovery", and an OIDC built from exactly those two
// pointers, so expectations set on fixture.Client apply to fixture.OIDC.
func newOIDCTest() *oidcTest {
	fixture := &oidcTest{
		Config: &c7s.Config{
			Issuer:            "https://my.issuer",
			DiscoveryEndpoint: "https://my.issuer/my/discovery",
		},
		Client: new(mocks.HTTPClient),
	}

	// Wire the instance under test to the very same config and client the
	// fixture exposes.
	fixture.OIDC = &OIDC{
		Config: fixture.Config,
		Client: fixture.Client,
	}
	return fixture
}

// oidcTest is a per-subtest fixture bundling the OIDC under test with the
// collaborators it was constructed from (see newOIDCTest).
type oidcTest struct {
	OIDC   *OIDC             // instance under test
	Config *c7s.Config       // config handed to OIDC by newOIDCTest
	Client *mocks.HTTPClient // mock HTTP client handed to OIDC; register expectations here
}

// Teardown fails the test if any expectation registered on the mock HTTP
// client was not met. Subtests call it via defer right after building the
// fixture.
func (ot *oidcTest) Teardown(t *testing.T) {
	// Mark as a helper so failures are attributed to the calling test.
	t.Helper()
	ot.Client.AssertExpectations(t)
}

// JSONResponse builds a 200 OK *http.Response whose body is the JSON encoding
// of body. The Header map is allocated but empty so callers can add headers
// such as Cache-Control. The test fails immediately if body cannot be encoded.
func (ot *oidcTest) JSONResponse(t *testing.T, body interface{}) *http.Response {
	// Mark as a helper so an encode failure is attributed to the calling test.
	t.Helper()

	var buf bytes.Buffer
	require.NoError(t, json.NewEncoder(&buf).Encode(body))
	return &http.Response{
		Body:       ioutil.NopCloser(&buf),
		Header:     make(http.Header),
		StatusCode: http.StatusOK,
	}
}

// IssuerURL parses the fixture's configured issuer (ot.Config.Issuer) into a
// *url.URL, failing the test immediately if it does not parse.
func (ot *oidcTest) IssuerURL(t *testing.T) *url.URL {
	// Mark as a helper so a parse failure is attributed to the calling test.
	t.Helper()

	// Named u rather than url so the net/url package is not shadowed.
	u, err := url.Parse(ot.Config.Issuer)
	require.NoError(t, err)
	return u
}
