package s2s2dicallee

import (
	"bytes"
	"encoding/json"
	"errors"
	"io/ioutil"
	"net/http"
	"net/url"
	"testing"
	"time"

	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCallee/internal/cachedvalidation"
	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCallee/internal/cert"
	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCallee/internal/logutil"
	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCallee/internal/validation"
	"code.justin.tv/amzn/TwitchS2S2DistributedIdentitiesCallee/s2s2dicallee/mocks"
	"code.justin.tv/video/metrics-middleware/v2/operation"
	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestNewCallee(t *testing.T) {
	assert.NotEmpty(t, newCallee(&cachedvalidation.CachedValidations{}, &cert.Certificates{}, time.NewTicker(time.Minute), &logutil.Logger{}))
}

// TestNewValidations checks that newValidations succeeds when at least one web
// origin is configured, and returns an error mentioning the origin requirement
// when none are.
func TestNewValidations(t *testing.T) {
	t.Run("success", func(t *testing.T) {
		res, err := newValidations(&Options{
			WebOrigins: []string{"https://derp.herp"},
		}, &cert.Certificates{}, &operation.Starter{})
		require.NoError(t, err)
		assert.NotEmpty(t, res)
	})

	t.Run("no origins", func(t *testing.T) {
		_, err := newValidations(&Options{}, nil, &operation.Starter{})
		// Assert the error exists first so a regression fails the test cleanly
		// rather than panicking on a nil err.Error() call.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "origins defined must be")
	})
}

func TestNewCachedValidations(t *testing.T) {
	assert.NotEmpty(t, newCachedValidations(&validation.Validations{}))
}

// TestNewCertificates exercises newCertificates against a mocked HTTP
// transport. The subtests expect one round trip for the base certificate pool
// and one additional round trip when an authorized service is configured; a
// transport error on either fetch must be returned to the caller.
func TestNewCertificates(t *testing.T) {
	// jsonResponse wraps the JSON encoding of in as a 200 OK response for the
	// mocked transport to hand back.
	jsonResponse := func(t *testing.T, in interface{}) *http.Response {
		var bs bytes.Buffer
		require.NoError(t, json.NewEncoder(&bs).Encode(in))
		return &http.Response{
			Body:       ioutil.NopCloser(&bs),
			StatusCode: http.StatusOK,
		}
	}

	rateLimiterTestValue := time.Millisecond

	// oneService returns options with a single zero-valued authorized service,
	// which triggers the second (per-service) certificate fetch.
	oneService := func() *Options {
		return &Options{AuthorizedServices: []Service{{}}}
	}

	// build invokes newCertificates with the fixed test dependencies, routing
	// all HTTP traffic through rt. Only the options and transport vary between
	// subtests.
	build := func(options *Options, rt http.RoundTripper) (interface{}, error) {
		return newCertificates(
			options,
			&http.Client{Transport: rt},
			servicesDomain("twitch"),
			certificateStoreOrigin{Scheme: "https", Host: "certs"},
			&operation.Starter{},
			logutil.NoopLogger,
			rateLimiterTestValue,
		)
	}

	t.Run("success", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		roundTripper := mocks.NewMockRoundTripper(ctrl)
		roundTripper.EXPECT().RoundTrip(gomock.Any()).
			Return(jsonResponse(t, struct{}{}), nil)
		roundTripper.EXPECT().RoundTrip(gomock.Any()).
			Return(jsonResponse(t, struct{}{}), nil)

		res, err := build(oneService(), roundTripper)
		require.NoError(t, err)
		assert.NotEmpty(t, res)
	})

	t.Run("less than one authorization is fine", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		roundTripper := mocks.NewMockRoundTripper(ctrl)
		roundTripper.EXPECT().RoundTrip(gomock.Any()).
			Return(jsonResponse(t, struct{}{}), nil)

		res, err := build(&Options{}, roundTripper)
		require.NoError(t, err)
		assert.NotEmpty(t, res)
	})

	t.Run("load cert pool error", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		myErr := errors.New("myerr")
		roundTripper := mocks.NewMockRoundTripper(ctrl)
		// The first fetch fails, so the per-service fetch is never attempted.
		roundTripper.EXPECT().RoundTrip(gomock.Any()).
			Return(nil, myErr)

		_, err := build(oneService(), roundTripper)
		// Guard against a nil error so a regression fails instead of panicking.
		require.Error(t, err)
		assert.Contains(t, err.Error(), myErr.Error())
	})

	t.Run("load service cert pool error", func(t *testing.T) {
		ctrl := gomock.NewController(t)
		defer ctrl.Finish()

		myErr := errors.New("myerr")
		roundTripper := mocks.NewMockRoundTripper(ctrl)
		// gomock consumes these in declaration order: the base pool succeeds,
		// then the per-service fetch fails.
		roundTripper.EXPECT().RoundTrip(gomock.Any()).
			Return(jsonResponse(t, struct{}{}), nil)
		roundTripper.EXPECT().RoundTrip(gomock.Any()).
			Return(nil, myErr)

		_, err := build(oneService(), roundTripper)
		require.Error(t, err)
		assert.Contains(t, err.Error(), myErr.Error())
	})
}

func TestNewHTTPClient(t *testing.T) {
	assert.Equal(t, &http.Client{}, newHTTPClient())
}

func TestNewServicesDomain(t *testing.T) {
	t.Run("default", func(t *testing.T) {
		assert.Equal(t, servicesDomain("twitch"), newServicesDomain(&Options{}))
	})

	t.Run("set", func(t *testing.T) {
		assert.Equal(t, servicesDomain("derp"), newServicesDomain(&Options{ServicesDomain: "derp"}))
	})
}

func TestNewCertificateStoreOrigin(t *testing.T) {
	t.Run("default", func(t *testing.T) {
		assert.Equal(t, certificateStoreOrigin{
			Scheme: "https",
			Host:   "prod.s2s2identities.twitch.a2z.com",
		}, newCertificateStoreOrigin(&Options{}))
	})

	t.Run("set", func(t *testing.T) {
		assert.Equal(t, certificateStoreOrigin{
			Scheme: "https",
			Host:   "something",
		}, newCertificateStoreOrigin(&Options{
			CertificateStoreOrigin: &url.URL{
				Scheme: "https",
				Host:   "something",
			},
		}))
	})
}

func TestNewRefreshRateLimiter(t *testing.T) {
	assert.NotEmpty(t, newRefreshRateLimiter())
}

func TestNewUnknownCallerLoggerRateLimit(t *testing.T) {
	t.Run("default", func(t *testing.T) {
		options := &Options{
			UnknownCallerLoggerRateLimit: 0,
		}
		duration := newUnknownCallerLoggerRateLimit(options)
		assert.Equal(t, 15*time.Minute, duration)
	})

	t.Run("Custom duration", func(t *testing.T) {
		options := &Options{
			UnknownCallerLoggerRateLimit: 2 * time.Minute,
		}
		duration := newUnknownCallerLoggerRateLimit(options)
		assert.Equal(t, 2*time.Minute, duration)
	})

	t.Run("Disabled", func(t *testing.T) {
		options := &Options{
			UnknownCallerLoggerRateLimit: -1,
		}
		duration := newUnknownCallerLoggerRateLimit(options)
		assert.Equal(t, time.Duration(-1), duration)
	})
}
