package session

import (
	"code.justin.tv/qe/grid_router/src/pkg/config"
	"code.justin.tv/qe/grid_router/src/pkg/hub_registry"
	"code.justin.tv/qe/grid_router/src/pkg/instrumentor"
	"code.justin.tv/qe/grid_router/src/pkg/instrumentor/mocks"
	"errors"
	"fmt"
	"github.com/alicebob/miniredis/v2"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/cloudwatch"
	"github.com/go-redis/redis"
	"github.com/jonboulle/clockwork"
	"github.com/sirupsen/logrus"
	"github.com/sirupsen/logrus/hooks/test"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"
	"time"
)

// TestHandler_Handle exercises Handler.Handle end to end: a registry backed by
// an in-memory miniredis holds a single hub whose address points at a mock hub
// HTTP server, so forwarded requests can be observed.
func TestHandler_Handle(t *testing.T) {
	hubServerRequests := 0
	sessionPath := "/wd/hub/session"
	var header string
	internalSessionID := "testSID" // used as the session id that is mocked back as created

	// Set up a mock hub server that pretends to create a session.
	// NOTE(review): createMockHubServer is defined outside this view; the
	// subtests below assume it increments hubServerRequests per request and
	// stores the forwarded Accept-Encoding header in header — confirm.
	hubServerNewSession := createMockHubServer(&hubServerRequests, nil, &header, internalSessionID)
	defer hubServerNewSession.Close()

	// Run an in-memory redis server for the registry. Check the error before
	// registering Close so cleanup is only deferred for a server we can rely on.
	s, err := miniredis.Run()
	require.NoError(t, err)
	defer s.Close()

	redisClient := redis.NewClient(&redis.Options{
		Addr: s.Addr(),
	})
	defer redisClient.Close()

	// Create a registry.
	appConfig := config.NewMock()
	reg := hub_registry.NewRegistry(nil, redisClient, appConfig, time.Hour)

	// The only API key the handler will accept.
	apiKeys := []string{"test1"}

	hubServerURL, err := url.Parse(hubServerNewSession.URL)
	require.NoError(t, err)

	// Register a hub whose host and port are those of the mock server above,
	// with exactly one free slot.
	mockHub := hub_registry.Hub{
		ID:      "i-1234",
		IP:      hubServerURL.Hostname(),
		Port:    hubServerURL.Port(),
		Healthy: true,
		Paused:  false,
		SlotCounts: hub_registry.SlotCounts{
			Free:  1,
			Total: 1,
		},
	}
	require.NoError(t, reg.SaveHub(&mockHub))

	// Set up the handler.
	handler := NewHandler(reg, apiKeys, appConfig)

	t.Run("Does not forward the request if unauthorized", func(t *testing.T) {
		// Test multiple methods as they may have different flows.
		authTests := []string{
			http.MethodGet,
			http.MethodPost,
			http.MethodDelete,
			http.MethodPut,
		}

		for _, method := range authTests {
			t.Run(method, func(t *testing.T) {
				hubServerRequests = 0 // reset, clear out old data

				w := httptest.NewRecorder()
				req, err := http.NewRequest(method, sessionPath, nil)
				require.NoError(t, err)

				// Make the request to the handler.
				assert.Equal(t, 0, hubServerRequests, "should have 0 requests before the request")
				handler.Handle(w, req)
				assert.Equal(t, 0, hubServerRequests, "should not have been forwarded to the server")

				resp := w.Result()
				defer resp.Body.Close()
				assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)

				body, err := ioutil.ReadAll(resp.Body)
				require.NoError(t, err)
				assert.Equal(t, "Not authorized\n", string(body))
			})
		}
	})

	t.Run("forwards request if authorized", func(t *testing.T) {
		// Re-save the hub to reset its free slot counts before forwarding.
		require.NoError(t, reg.SaveHub(&mockHub))

		hubServerRequests = 0
		req, w := createMockAllowedRequest(sessionPath, http.MethodPost, apiKeys)

		// Make the request to the handler.
		assert.Equal(t, 0, hubServerRequests, "should have 0 requests before the request")
		handler.Handle(w, req)
		assert.Equal(t, 1, hubServerRequests, "should have been forwarded to the server")
	})

	t.Run("Accept-Encoding 'identity' is passed to hub server", func(t *testing.T) {
		// Re-save the hub to reset its free slot counts before forwarding.
		require.NoError(t, reg.SaveHub(&mockHub))

		header = ""
		req, w := createMockAllowedRequest(sessionPath, http.MethodPost, apiKeys)

		assert.Empty(t, header)
		handler.Handle(w, req)
		assert.Equal(t, "identity", header)
	})
}

// TestHandler_IsAuthorized checks that IsAuthorized accepts a configured API
// key as the basic-auth password, rejects a wrong or missing credential, and
// emits the expected security log line in each case.
func TestHandler_IsAuthorized(t *testing.T) {
	allowedAPIKeys := []string{"testPassword1", "testPassword2"}

	// Set up App Config with a null logger whose hook captures entries, so the
	// security log lines can be asserted.
	logger, hook := test.NewNullLogger()
	appConfig := config.NewMock()
	appConfig.Logger = logger

	handler := NewHandler(nil, allowedAPIKeys, appConfig)

	// assertLogged fails t unless an entry with exactly this level and message
	// was captured. Entries are read through the hook's locked accessor, and
	// the failure message lists each captured entry as "level: message".
	assertLogged := func(t *testing.T, level logrus.Level, message string) {
		t.Helper()
		entries := hook.AllEntries()
		logged := make([]string, 0, len(entries))
		for _, e := range entries {
			logged = append(logged, fmt.Sprintf("%s: %s", e.Level, e.Message))
		}
		assert.True(t, testContainsLogEntry(level, message, entries),
			"Expected to find log %s: %s.\nEntries: %v", level, message, logged)
	}

	t.Run("When using a good password, it should return true", func(t *testing.T) {
		hook.Reset() // isolate from entries logged by earlier cases

		mockUser := "testUser"
		mockRequest := &http.Request{Header: http.Header{}}
		mockRequest.SetBasicAuth(mockUser, allowedAPIKeys[1])

		assert.True(t, handler.IsAuthorized(mockRequest, allowedAPIKeys))

		t.Run("and contains a log was sent for security", func(t *testing.T) {
			assertLogged(t, logrus.InfoLevel, fmt.Sprintf("Authentication Succeeded: username %s", mockUser))
		})
	})

	t.Run("When using a bad password, it should return false", func(t *testing.T) {
		hook.Reset() // isolate from entries logged by earlier cases

		mockUser := "testUser"
		mockRequest := &http.Request{Header: http.Header{}}
		mockRequest.SetBasicAuth(mockUser, "notthepassword")

		assert.False(t, handler.IsAuthorized(mockRequest, allowedAPIKeys))

		t.Run("and contains a log was sent for security", func(t *testing.T) {
			assertLogged(t, logrus.WarnLevel,
				fmt.Sprintf("Authentication Failed: username %s did not provide a valid auth key", mockUser))
		})
	})

	t.Run("When providing no basic auth header, should return false", func(t *testing.T) {
		hook.Reset() // isolate from entries logged by earlier cases

		mockRequest := &http.Request{Header: http.Header{}}

		assert.False(t, handler.IsAuthorized(mockRequest, allowedAPIKeys))

		t.Run("and contains a log was sent for security", func(t *testing.T) {
			assertLogged(t, logrus.InfoLevel, "Authentication Failed: basic auth method returned false for 'ok'")
		})
	})
}

// testContainsLogEntry reports whether at least one of entries was logged at
// expectedLevel with a message exactly equal to expectedMessage. An empty or
// nil slice yields false.
func testContainsLogEntry(expectedLevel logrus.Level, expectedMessage string, entries []*logrus.Entry) bool {
	found := false
	for i := 0; i < len(entries) && !found; i++ {
		candidate := entries[i]
		found = candidate.Message == expectedMessage && candidate.Level == expectedLevel
	}
	return found
}

// createMockAllowedRequest builds a request to "http://test.com"+path with the
// given method, authenticated via HTTP basic auth as user "testUser" with the
// first entry of apiKeys as the password. It also returns a fresh
// ResponseRecorder to hand to the handler under test.
//
// It panics if apiKeys is empty or the request cannot be constructed; both
// indicate a broken test setup rather than a condition under test.
func createMockAllowedRequest(path string, method string, apiKeys []string) (*http.Request, *httptest.ResponseRecorder) {
	if len(apiKeys) == 0 {
		panic("createMockAllowedRequest: at least one api key is required")
	}

	// An empty path yields the bare host URL.
	rawURL := "http://test.com" + path

	req, err := http.NewRequest(method, rawURL, nil)
	if err != nil {
		panic(err)
	}

	req.SetBasicAuth("testUser", apiKeys[0])
	return req, httptest.NewRecorder()
}

// TestWriteMetricNewSession checks that WriteMetricNewSession rejects requests
// missing the app config, instrumentor, or clock, and otherwise sends a single
// "NewSession" count datum (namespace "CBG", dimensioned by the ASG name) to
// the metric client, propagating any client error.
func TestWriteMetricNewSession(t *testing.T) {
	t.Run("when a required piece is missing, returns an error", func(t *testing.T) {
		errMsg := "required request instrumentor or clock is missing"
		t.Run("req", func(t *testing.T) {
			res, err := WriteMetricNewSession(nil)
			assert.Nil(t, res)
			assert.EqualError(t, err, errMsg)
		})

		t.Run("appConfig", func(t *testing.T) {
			req := &Request{}
			res, err := WriteMetricNewSession(req)
			assert.Nil(t, res)
			assert.EqualError(t, err, errMsg)
		})

		t.Run("Instrumentor", func(t *testing.T) {
			req := &Request{
				AppConfig: &config.Config{
					Clock: clockwork.NewFakeClock(),
				},
			}
			res, err := WriteMetricNewSession(req)
			assert.Nil(t, res)
			assert.EqualError(t, err, errMsg)
		})

		t.Run("Clock", func(t *testing.T) {
			req := &Request{
				AppConfig: &config.Config{
					Instrumentor: &instrumentor.Instrumentor{},
				},
			}
			res, err := WriteMetricNewSession(req)
			assert.Nil(t, res)
			assert.EqualError(t, err, errMsg)
		})
	})

	t.Run("when all data provided", func(t *testing.T) {
		mockMetricWriter := &mocks.MetricWriter{}
		logger, _ := test.NewNullLogger()
		// A fake clock pins the timestamp so the expected input matches exactly.
		clock := clockwork.NewFakeClock()
		mockASGName := "1234"

		expectedInput := &cloudwatch.PutMetricDataInput{
			Namespace: aws.String("CBG"),
			MetricData: []*cloudwatch.MetricDatum{
				{
					MetricName: aws.String("NewSession"),
					Timestamp:  aws.Time(clock.Now()),
					Value:      aws.Float64(1.0),
					Unit:       aws.String(cloudwatch.StandardUnitCount),
					Dimensions: []*cloudwatch.Dimension{
						{Name: aws.String("AutoScalingGroupName"), Value: aws.String(mockASGName)},
					},
				},
			},
		}

		req := &Request{
			AppConfig: &config.Config{
				Instrumentor: &instrumentor.Instrumentor{
					AutoScalingGroupName: mockASGName,
					MetricClient:         mockMetricWriter,
				},
				Clock:  clock,
				Logger: logger,
			},
		}

		t.Run("writes the metric when no error", func(t *testing.T) {
			mockMetricWriter.On("PutMetricData", expectedInput).Return(&cloudwatch.PutMetricDataOutput{}, nil).Once()
			res, err := WriteMetricNewSession(req)
			assert.NoError(t, err)
			assert.NotNil(t, res)
			// Verify PutMetricData was actually invoked with expectedInput.
			mockMetricWriter.AssertExpectations(t)
		})

		t.Run("returns error if cloudwatch error", func(t *testing.T) {
			mockErrMsg := "cwl mock error"
			mockMetricWriter.On("PutMetricData", expectedInput).Return(&cloudwatch.PutMetricDataOutput{}, errors.New(mockErrMsg)).Once()
			res, err := WriteMetricNewSession(req)
			assert.Empty(t, res)
			assert.EqualError(t, err, mockErrMsg)
			// Verify PutMetricData was actually invoked with expectedInput.
			mockMetricWriter.AssertExpectations(t)
		})
	})
}
