package grid_reboot

import (
	"code.justin.tv/qe/grid_reboot/pkg/config"
	grEC2 "code.justin.tv/qe/grid_reboot/pkg/ec2"
	"code.justin.tv/qe/grid_reboot/pkg/ec2/mocks"
	"errors"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/ec2"
	"github.com/bxcodec/faker"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
	"testing"
)

// TestGetHubs verifies that GetHubs returns the instances produced by the EC2
// DescribeInstances call, and propagates any error that call returns.
func TestGetHubs(t *testing.T) {
	// Config selecting hubs by tag. The mocked EC2 service accepts any input
	// (mock.Anything), so these values only need to be well-formed.
	mockConfig := config.GridRebootConfig{
		EC2: config.EC2{
			Hub: config.Hub{
				Tags: map[string]string{"GridClusterID": "test"},
			},
			Region: "us-test-1",
		},
	}

	t.Run("returns the instances when no error", func(t *testing.T) {
		expectedInstances := createMockInstances(true, 1)

		// Mock the AWS SDK response.
		mockEC2Service := &mocks.Service{}
		mockEC2Service.On("DescribeInstances", mock.Anything).
			Return(createMockDescribeInstanceOutput(expectedInstances), nil).
			Once()
		ec2Handler := &grEC2.Handler{Service: mockEC2Service}

		resp, err := GetHubs(mockConfig, ec2Handler)
		require.NoError(t, err)
		require.NotNil(t, resp)
		assert.Equal(t, expectedInstances, resp)

		// The .Once() expectation is only enforced if we ask for it.
		mockEC2Service.AssertExpectations(t)
	})

	t.Run("returns the error if one is returned", func(t *testing.T) {
		expectedErrorMsg := "testing error"

		mockEC2Service := &mocks.Service{}
		mockEC2Service.On("DescribeInstances", mock.Anything).
			Return(nil, errors.New(expectedErrorMsg)).
			Once()
		ec2Handler := &grEC2.Handler{Service: mockEC2Service}

		resp, err := GetHubs(mockConfig, ec2Handler)
		// EqualError also fails when err is nil, so no separate Error check.
		assert.EqualError(t, err, expectedErrorMsg)
		assert.Empty(t, resp)

		mockEC2Service.AssertExpectations(t)
	})
}

// TestRebootNodesByCluster exercises RebootNodesByCluster against a mocked EC2
// service. It covers four paths: no instances, an instance failing validation,
// a successful reboot, and errors from the describe or reboot calls.
func TestRebootNodesByCluster(t *testing.T) {
	mockConfig := &config.GridRebootConfig{}

	// newService returns a mock EC2 service whose DescribeInstances call
	// returns the given output and error exactly once.
	newService := func(out *ec2.DescribeInstancesOutput, err error) *mocks.Service {
		svc := &mocks.Service{}
		svc.On("DescribeInstances", mock.Anything).Return(out, err).Once()
		return svc
	}

	t.Run("does not reboot anything if no instances found", func(t *testing.T) {
		svc := newService(createMockDescribeInstanceOutput(createMockInstances(false, 0)), nil)

		err := RebootNodesByCluster("testing", &grEC2.Handler{Service: svc}, mockConfig)
		assert.NoError(t, err)

		svc.AssertNotCalled(t, "RebootInstances", mock.Anything)
		svc.AssertExpectations(t)
	})

	t.Run("does not reboot instances that fail validation", func(t *testing.T) {
		svc := newService(createMockDescribeInstanceOutput(createMockInstances(false, 1)), nil)

		err := RebootNodesByCluster("testing", &grEC2.Handler{Service: svc}, mockConfig)
		// EqualError also fails when err is nil.
		assert.EqualError(t, err, "invalid instance found, skipping the reboot")

		svc.AssertNotCalled(t, "RebootInstances", mock.Anything)
		svc.AssertExpectations(t)
	})

	t.Run("reboots instances that pass validation", func(t *testing.T) {
		svc := newService(createMockDescribeInstanceOutput(createMockInstances(true, 1)), nil)
		svc.On("RebootInstances", mock.Anything).Return(&ec2.RebootInstancesOutput{}, nil)

		err := RebootNodesByCluster("testing", &grEC2.Handler{Service: svc}, mockConfig)
		assert.NoError(t, err)

		svc.AssertCalled(t, "RebootInstances", mock.Anything)
		svc.AssertExpectations(t)
	})

	t.Run("returns error passed back by fetching instances", func(t *testing.T) {
		describeErrorMsg := "test error"
		svc := newService(nil, errors.New(describeErrorMsg))

		err := RebootNodesByCluster("testing", &grEC2.Handler{Service: svc}, mockConfig)
		assert.EqualError(t, err, describeErrorMsg)

		svc.AssertNotCalled(t, "RebootInstances", mock.Anything)
		svc.AssertExpectations(t)
	})

	t.Run("returns error passed back by rebooting instances", func(t *testing.T) {
		rebootErrMsg := "test error"
		svc := newService(createMockDescribeInstanceOutput(createMockInstances(true, 1)), nil)
		svc.On("RebootInstances", mock.Anything).
			Return(&ec2.RebootInstancesOutput{}, errors.New(rebootErrMsg))

		err := RebootNodesByCluster("testing", &grEC2.Handler{Service: svc}, mockConfig)
		assert.EqualError(t, err, rebootErrMsg)

		svc.AssertCalled(t, "RebootInstances", mock.Anything)
		svc.AssertExpectations(t)
	})
}

// createMockDescribeInstanceOutput builds a fake EC2 DescribeInstances
// response for use with the mocked EC2 service. All of the given instances
// are placed in one reservation, which is the only reservation in the output.
func createMockDescribeInstanceOutput(instances []*ec2.Instance) *ec2.DescribeInstancesOutput {
	reservation := &ec2.Reservation{Instances: instances}

	return &ec2.DescribeInstancesOutput{
		Reservations: []*ec2.Reservation{reservation},
	}
}

// createMockInstances builds numberOfInstances EC2 instances whose fields are
// filled with random data by faker. It returns nil when numberOfInstances <= 0.
//
// When passValidation is true, each instance's Tags are replaced with the
// single tag Service=grid. This is presumably what the Grid Reboot validation
// accepts; TODO confirm against the validator. When passValidation is false,
// the faker-generated random tags are kept. Those are not expected to contain
// that tag, so such instances should fail validation.
//
// It panics if faker cannot populate an instance. This helper is intended for
// tests, where the panic fails the run.
func createMockInstances(passValidation bool, numberOfInstances int) []*ec2.Instance {
	var instances []*ec2.Instance

	for i := 0; i < numberOfInstances; i++ {
		instance := &ec2.Instance{}
		// Pass the pointer itself so faker fills this struct in place,
		// rather than handing it a **ec2.Instance to reallocate.
		if err := faker.FakeData(instance); err != nil {
			panic("faker could not populate ec2.Instance: " + err.Error())
		}

		if passValidation {
			instance.Tags = []*ec2.Tag{
				{Key: aws.String("Service"), Value: aws.String("grid")},
			}
		}

		instances = append(instances, instance)
	}

	return instances
}
