package main

import (
	gridRebootMocks "code.justin.tv/qe/grid_reboot/mocks"
	"code.justin.tv/qe/grid_reboot/pkg/config"
	grEC2 "code.justin.tv/qe/grid_reboot/pkg/ec2"
	"code.justin.tv/qe/grid_reboot/pkg/ec2/mocks"
	"code.justin.tv/qe/grid_router/src/pkg/hub_registry"
	"errors"
	"fmt"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/ec2"
	"github.com/aws/aws-sdk-go/service/elasticbeanstalk"
	"github.com/jonboulle/clockwork"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"testing"
	"time"
)

// TestFetchClusterIDs runs FetchClusterIDs against a mocked EC2 service and
// checks the result for a successful DescribeInstances call and for a failing one.
func TestFetchClusterIDs(t *testing.T) {
	// Configuration shared by both subtests.
	hubTags := map[string]string{
		"GridClusterID": "test",
	}
	mockConfig := config.GridRebootConfig{
		EC2: config.EC2{
			Hub:    config.Hub{Tags: hubTags},
			Region: "us-test-1",
		},
	}

	// gridInstance builds an EC2 instance tagged as a grid node of the given cluster.
	gridInstance := func(instanceID, clusterID string) *ec2.Instance {
		return &ec2.Instance{
			InstanceId: aws.String(instanceID),
			Tags: []*ec2.Tag{
				{Key: aws.String("Service"), Value: aws.String("grid")},
				{Key: aws.String("GridClusterID"), Value: aws.String(clusterID)},
			},
		}
	}

	t.Run("returns the proper cluster ids", func(t *testing.T) {
		firstCluster, secondCluster := "cluster_1", "cluster_2"

		// A single reservation holding one instance per cluster is what the
		// mocked AWS SDK hands back.
		describeOutput := &ec2.DescribeInstancesOutput{
			Reservations: []*ec2.Reservation{
				{
					Instances: []*ec2.Instance{
						gridInstance("i-test1", firstCluster),
						gridInstance("i-test2", secondCluster),
					},
				},
			},
		}

		ec2Service := &mocks.Service{}
		ec2Service.On("DescribeInstances", mock.Anything).Return(describeOutput, nil).Once()

		handler := &grEC2.Handler{Service: ec2Service}

		clusterIDs, err := FetchClusterIDs(mockConfig, handler)
		assert.NoError(t, err)
		assert.Contains(t, clusterIDs, firstCluster)
		assert.Contains(t, clusterIDs, secondCluster)
	})

	t.Run("returns an error if there was a problem getting instances", func(t *testing.T) {
		describeErrorMsg := "test error describing"

		ec2Service := &mocks.Service{}
		ec2Service.On("DescribeInstances", mock.Anything).Return(nil, errors.New(describeErrorMsg)).Once()

		handler := &grEC2.Handler{Service: ec2Service}

		clusterIDs, err := FetchClusterIDs(mockConfig, handler)
		assert.EqualError(t, err, describeErrorMsg)
		assert.Empty(t, clusterIDs)
	})
}

// TestProcessCluster drives ProcessCluster with mocked grid-router, EC2 and
// Beanstalk dependencies plus a fake clock. It checks the returned error and
// which pause, unpause and reboot calls were issued, on the success path and
// on each failure path.
func TestProcessCluster(t *testing.T) {
	clusterName := "test-cluster"

	mockConfig := config.GridRebootConfig{}

	// Fake clock: subtests advance it from a goroutine while ProcessCluster runs.
	// NOTE(review): the clock is shared by all subtests and is advanced after
	// wall-clock sleeps, so a slow run could leak an Advance into a later
	// subtest — confirm whether per-subtest clocks are worth the churn.
	mockClock := clockwork.NewFakeClock()
	mockConfig.Clock = mockClock

	mockConfig.EC2.Hub.Tags = map[string]string{
		"Environment": "mock",
	}

	bsInterface := &gridRebootMocks.BeanstalkInterface{}
	bsInterface.On("RestartAppServer", mock.Anything).Return(&elasticbeanstalk.RestartAppServerOutput{}, nil)
	bsService := &BeanstalkService{
		Service: bsInterface,
	}

	// newHub returns a hub for clusterName with every slot free (1 of 1) and
	// the requested paused state.
	newHub := func(paused bool) *hub_registry.Hub {
		return &hub_registry.Hub{
			ClusterName: clusterName,
			SlotCounts: hub_registry.SlotCounts{
				Free:  1,
				Total: 1,
			},
			Paused: paused,
		}
	}

	// newRouterClient returns a grid-router mock whose pause succeeds once and
	// whose unpause returns unpauseErr once. The first GetHubByClusterName
	// call reports the hub unpaused; every later call reports it paused.
	newRouterClient := func(unpauseErr error) *gridRebootMocks.GridRouterClient {
		client := &gridRebootMocks.GridRouterClient{}
		client.On("PauseHubByClusterName", clusterName).Return(nil).Once()
		client.On("UnpauseHubByClusterName", clusterName).Return(unpauseErr).Once()
		client.On("GetHubByClusterName", clusterName).Return(newHub(false), nil).Once()
		client.On("GetHubByClusterName", clusterName).Return(newHub(true), nil)
		return client
	}

	// newEC2Service returns an EC2 mock that describes a single grid instance
	// once. RebootInstances succeeds once when rebootErr is nil and otherwise
	// fails once with rebootErr.
	newEC2Service := func(rebootErr error) *mocks.Service {
		describeOutput := &ec2.DescribeInstancesOutput{
			Reservations: []*ec2.Reservation{
				{
					Instances: []*ec2.Instance{
						{
							InstanceId: aws.String("i-test"),
							Tags: []*ec2.Tag{
								{Key: aws.String("Service"), Value: aws.String("grid")},
							},
						},
					},
				},
			},
		}

		service := &mocks.Service{}
		service.On("DescribeInstances", mock.Anything).Return(describeOutput, nil).Once()
		if rebootErr != nil {
			service.On("RebootInstances", mock.Anything).Return(nil, rebootErr).Once()
		} else {
			service.On("RebootInstances", mock.Anything).Return(&ec2.RebootInstancesOutput{}, nil).Once()
		}
		return service
	}

	// advanceThroughReboot simulates time passing while ProcessCluster runs:
	// in the background it advances the fake clock by hubRebootSleepTime and
	// then by rebootSleepTime, with a real pause before each advance.
	advanceThroughReboot := func() {
		go func() {
			time.Sleep(time.Millisecond * 100) // CI seems to be unreliable on its sleep times, so this is a little long...
			mockClock.Advance(hubRebootSleepTime)
			time.Sleep(time.Millisecond * 100)
			mockClock.Advance(rebootSleepTime)
		}()
	}

	t.Run("pauses, reboots and then un-pauses in normal conditions", func(t *testing.T) {
		mockGridRouterClient := newRouterClient(nil)
		mockEC2Service := newEC2Service(nil)
		mockEC2Handler := &grEC2.Handler{
			Service: mockEC2Service,
		}

		advanceThroughReboot()

		err := ProcessCluster(mockConfig, mockGridRouterClient, mockEC2Handler, clusterName, bsService)
		assert.NoError(t, err)

		// Should have paused once
		mockGridRouterClient.AssertNumberOfCalls(t, "PauseHubByClusterName", 1)
		// Should have unpaused once
		mockGridRouterClient.AssertNumberOfCalls(t, "UnpauseHubByClusterName", 1)
		// Should have rebooted once
		mockEC2Service.AssertNumberOfCalls(t, "RebootInstances", 1)
		// Should have rebooted the Hub (this is the first subtest to use the shared Beanstalk mock)
		bsInterface.AssertNumberOfCalls(t, "RestartAppServer", 1)
	})

	t.Run("returns error getting hub and does not pause/unpause", func(t *testing.T) {
		errMsg := "test error getting"
		mockGridRouterClient := &gridRebootMocks.GridRouterClient{}
		mockGridRouterClient.On("GetHubByClusterName", mock.Anything).Return(nil, errors.New(errMsg)).Once()

		err := ProcessCluster(mockConfig, mockGridRouterClient, &grEC2.Handler{}, clusterName, bsService)
		assert.EqualError(t, err, errMsg)
		mockGridRouterClient.AssertNotCalled(t, "PauseHubByClusterName", clusterName)
		mockGridRouterClient.AssertNotCalled(t, "UnpauseHubByClusterName", clusterName)
	})

	t.Run("issues unpause if there were errors", func(t *testing.T) {
		t.Run("pausing", func(t *testing.T) {
			errMsg := "test error pausing"
			mockGridRouterClient := &gridRebootMocks.GridRouterClient{}
			mockGridRouterClient.On("GetHubByClusterName", clusterName).Return(newHub(false), nil).Once()
			mockGridRouterClient.On("PauseHubByClusterName", clusterName).Return(errors.New(errMsg)).Once()
			mockGridRouterClient.On("UnpauseHubByClusterName", clusterName).Return(nil).Once()

			err := ProcessCluster(mockConfig, mockGridRouterClient, &grEC2.Handler{}, clusterName, bsService)
			assert.EqualError(t, err, errMsg)
			mockGridRouterClient.AssertCalled(t, "UnpauseHubByClusterName", clusterName)
		})

		t.Run("draining", func(t *testing.T) {
			// An unpaused cluster that never drains: no free slot out of one.
			undrainedHub := newHub(false)
			undrainedHub.SlotCounts.Free = 0

			mockGridRouterClient := &gridRebootMocks.GridRouterClient{}
			mockGridRouterClient.On("PauseHubByClusterName", clusterName).Return(nil).Once()
			mockGridRouterClient.On("UnpauseHubByClusterName", clusterName).Return(nil).Once()
			mockGridRouterClient.On("GetHubByClusterName", clusterName).Return(undrainedHub, nil) // not once, may call multiple times

			// Simulate time passing while checking for drain
			go func() {
				time.Sleep(time.Millisecond * 100) // CI seems to be unreliable on its sleep times, so this is a little long...
				mockClock.Advance(time.Minute * 15)
			}()

			err := ProcessCluster(mockConfig, mockGridRouterClient, &grEC2.Handler{}, clusterName, bsService)
			assert.EqualError(t, err, fmt.Sprintf("timeout reached waiting for cluster [%s] to drain", clusterName))
			mockGridRouterClient.AssertCalled(t, "UnpauseHubByClusterName", clusterName)
		})

		t.Run("rebooting", func(t *testing.T) {
			rebootErrMsg := "test error rebooting"

			mockGridRouterClient := newRouterClient(nil)
			mockEC2Service := newEC2Service(errors.New(rebootErrMsg))
			mockEC2Handler := &grEC2.Handler{
				Service: mockEC2Service,
			}

			advanceThroughReboot()

			err := ProcessCluster(mockConfig, mockGridRouterClient, mockEC2Handler, clusterName, bsService)
			assert.Error(t, err)

			// Should have paused once
			mockGridRouterClient.AssertNumberOfCalls(t, "PauseHubByClusterName", 1)
			// Should have unpaused once
			mockGridRouterClient.AssertNumberOfCalls(t, "UnpauseHubByClusterName", 1)
			// Should have attempted to reboot once
			mockEC2Service.AssertNumberOfCalls(t, "RebootInstances", 1)
		})
	})

	t.Run("returns error if final unpause fails", func(t *testing.T) {
		unpauseErrMsg := "unpause error test"

		mockGridRouterClient := newRouterClient(errors.New(unpauseErrMsg))
		mockEC2Handler := &grEC2.Handler{
			Service: newEC2Service(nil),
		}

		advanceThroughReboot()

		err := ProcessCluster(mockConfig, mockGridRouterClient, mockEC2Handler, clusterName, bsService)
		assert.EqualError(t, err, unpauseErrMsg)
	})

	t.Run("skips a cluster if it's already paused", func(t *testing.T) { // don't want to unpause an intentionally paused cluster
		mockGridRouterClient := &gridRebootMocks.GridRouterClient{}
		mockGridRouterClient.On("GetHubByClusterName", clusterName).Return(newHub(true), nil)
		mockEC2Service := &mocks.Service{}

		mockEC2Handler := &grEC2.Handler{
			Service: mockEC2Service,
		}

		err := ProcessCluster(mockConfig, mockGridRouterClient, mockEC2Handler, clusterName, bsService)
		assert.EqualError(t, err, "hub was already paused")

		// Count calls instead of using AssertNotCalled without arguments: the
		// latter only matches calls recorded with no arguments, so it would
		// pass even if these methods had been invoked with a cluster name.
		mockGridRouterClient.AssertNumberOfCalls(t, "PauseHubByClusterName", 0)
		mockGridRouterClient.AssertNumberOfCalls(t, "UnpauseHubByClusterName", 0)
		mockEC2Service.AssertNumberOfCalls(t, "RebootInstances", 0)
	})

	t.Run("returns an error if Hub Environment tag missing", func(t *testing.T) {
		mockGridRouterClient := &gridRebootMocks.GridRouterClient{}
		mockEC2Handler := &grEC2.Handler{
			Service: &mocks.Service{},
		}

		// Create a config with no tag
		configNoTag := config.GridRebootConfig{}

		err := ProcessCluster(configNoTag, mockGridRouterClient, mockEC2Handler, "mockCluster", bsService)
		assert.EqualError(t, err, "unknown hub environment within config tags")
	})

	t.Run("returns an error if Hub Reboot fails", func(t *testing.T) {
		mockGridRouterClient := newRouterClient(nil)
		mockEC2Handler := &grEC2.Handler{
			Service: newEC2Service(nil),
		}

		// Beanstalk mock dedicated to this subtest: restarting the app server fails.
		expectedErr := "mock error rebooting"
		bsInterfaceErr := &gridRebootMocks.BeanstalkInterface{}
		bsInterfaceErr.On("RestartAppServer", mock.Anything).Return(
			&elasticbeanstalk.RestartAppServerOutput{}, errors.New(expectedErr))
		bsServiceErr := &BeanstalkService{
			Service: bsInterfaceErr,
		}

		advanceThroughReboot()

		err := ProcessCluster(mockConfig, mockGridRouterClient, mockEC2Handler, clusterName, bsServiceErr)
		assert.EqualError(t, err, expectedErr)

		// Should have unpaused
		mockGridRouterClient.AssertNumberOfCalls(t, "UnpauseHubByClusterName", 1)
	})

	t.Run("returns error if the hub becomes unpaused prior to reboot", func(t *testing.T) {
		// Starts unpaused and not drained. Every GetHubByClusterName call
		// returns this same pointer, so the mutation below is visible to
		// ProcessCluster.
		mockGetHubResponse := newHub(false)
		mockGetHubResponse.SlotCounts.Free = 0

		mockGridRouterClient := &gridRebootMocks.GridRouterClient{}
		mockGridRouterClient.On("PauseHubByClusterName", clusterName).Return(nil).Once()
		mockGridRouterClient.On("UnpauseHubByClusterName", clusterName).Return(nil).Once()
		mockGridRouterClient.On("GetHubByClusterName", clusterName).Return(mockGetHubResponse, nil) // not once, may call multiple times

		mockEC2Service := newEC2Service(nil)
		mockEC2Handler := &grEC2.Handler{
			Service: mockEC2Service,
		}

		// Simulate time passing while waiting for drain
		go func() {
			time.Sleep(time.Millisecond * 100) // CI seems to be unreliable on its sleep times, so this is a little long...

			// The mock never reports the hub as paused (the assignment only
			// makes that explicit); what changes here is that the hub now
			// shows as completely drained.
			// NOTE(review): these writes are not synchronised with
			// ProcessCluster's reads of the same struct — confirm under -race.
			mockGetHubResponse.Paused = false
			mockGetHubResponse.SlotCounts.Free = mockGetHubResponse.SlotCounts.Total
			mockClock.Advance(time.Minute * 2)
		}()

		err := ProcessCluster(mockConfig, mockGridRouterClient, mockEC2Handler, clusterName, bsService)
		assert.EqualError(t, err, "hub was not paused when it should have been")

		// Should have paused once
		mockGridRouterClient.AssertNumberOfCalls(t, "PauseHubByClusterName", 1)
		// Should have still tried unpausing as a precaution
		mockGridRouterClient.AssertNumberOfCalls(t, "UnpauseHubByClusterName", 1)
		// Should not have rebooted
		mockEC2Service.AssertNumberOfCalls(t, "RebootInstances", 0)
	})
}

// TestProcessClusters drives ProcessClusters over two cluster IDs with mocked
// grid-router, EC2 and Beanstalk dependencies. It checks that every cluster is
// processed on success, that a processing error is returned, and that hubs
// reported as already paused do not cause an error.
func TestProcessClusters(t *testing.T) {
	clusterIDs := []string{"test-cluster-1", "test-cluster-2"}

	mockConfig := config.GridRebootConfig{}
	mockConfig.EC2.Hub.Tags = map[string]string{
		"Environment": "mock",
	}

	// Fake clock, advanced from a goroutine while ProcessClusters runs.
	mockClock := clockwork.NewFakeClock()
	mockConfig.Clock = mockClock

	bsInterface := &gridRebootMocks.BeanstalkInterface{}
	bsInterface.On("RestartAppServer", mock.Anything).Return(&elasticbeanstalk.RestartAppServerOutput{}, nil)
	bsService := &BeanstalkService{
		Service: bsInterface,
	}

	// newEC2Service returns a fresh EC2 mock that describes one grid instance
	// and accepts any number of reboots. Each subtest gets its own mock so
	// that call counts from one subtest cannot satisfy or mask the assertions
	// of another.
	newEC2Service := func() *mocks.Service {
		describeOutput := &ec2.DescribeInstancesOutput{
			Reservations: []*ec2.Reservation{
				{
					Instances: []*ec2.Instance{
						{
							InstanceId: aws.String("i-test"),
							Tags: []*ec2.Tag{
								{Key: aws.String("Service"), Value: aws.String("grid")},
							},
						},
					},
				},
			},
		}

		service := &mocks.Service{}
		service.On("DescribeInstances", mock.Anything).Return(describeOutput, nil)
		service.On("RebootInstances", mock.Anything).Return(&ec2.RebootInstancesOutput{}, nil)
		return service
	}

	t.Run("runs commands on all clusters", func(t *testing.T) {
		mockGridRouterClient := &gridRebootMocks.GridRouterClient{}

		for _, cluster := range clusterIDs {
			mockGetHubResponse := &hub_registry.Hub{
				ClusterName: cluster,
				SlotCounts: hub_registry.SlotCounts{
					Free:  1,
					Total: 1,
				},
			}

			// Same hub as above, but paused.
			mockGetHubResponsePaused := &hub_registry.Hub{
				ClusterName: cluster,
				SlotCounts: hub_registry.SlotCounts{
					Free:  1,
					Total: 1,
				},
				Paused: true,
			}

			mockGridRouterClient.On("PauseHubByClusterName", cluster).Return(nil).Once()
			mockGridRouterClient.On("UnpauseHubByClusterName", cluster).Return(nil).Once()

			// On first call, return the hub showing it unpaused
			// Then for future calls, show it as paused
			mockGridRouterClient.On("GetHubByClusterName", cluster).Return(mockGetHubResponse, nil).Once()
			mockGridRouterClient.On("GetHubByClusterName", cluster).Return(mockGetHubResponsePaused, nil)
		}

		mockEC2Service := newEC2Service()
		mockEC2Handler := &grEC2.Handler{
			Service: mockEC2Service,
		}

		// Simulate time passing while rebooting, once per cluster
		go func() {
			for range clusterIDs {
				time.Sleep(time.Millisecond * 100) // CI seems to be unreliable on its sleep times, so this is a little long...
				mockClock.Advance(hubRebootSleepTime)
				time.Sleep(time.Millisecond * 100)
				mockClock.Advance(rebootSleepTime)
			}
		}()

		err := ProcessClusters(mockConfig, mockGridRouterClient, mockEC2Handler, clusterIDs, bsService)
		assert.NoError(t, err)

		// Should have paused twice (two clusters)
		mockGridRouterClient.AssertNumberOfCalls(t, "PauseHubByClusterName", 2)
		// Should have unpaused twice (two clusters)
		mockGridRouterClient.AssertNumberOfCalls(t, "UnpauseHubByClusterName", 2)
		// Should have rebooted twice (two clusters)
		mockEC2Service.AssertNumberOfCalls(t, "RebootInstances", 2)
	})

	t.Run("returns an error from processing the cluster and does not proceed", func(t *testing.T) {
		mockGetHubResponse := &hub_registry.Hub{
			ClusterName: "test-cluster",
			Paused:      false,
		}

		mockGridRouterClient := &gridRebootMocks.GridRouterClient{}
		pauseErrMsg := "test error pause"

		mockGridRouterClient.On("GetHubByClusterName", mock.Anything).Return(mockGetHubResponse, nil)
		mockGridRouterClient.On("PauseHubByClusterName", mock.Anything).Return(errors.New(pauseErrMsg)).Once()
		mockGridRouterClient.On("UnpauseHubByClusterName", mock.Anything).Return(nil).Once()

		mockEC2Service := newEC2Service()
		mockEC2Handler := &grEC2.Handler{
			Service: mockEC2Service,
		}

		err := ProcessClusters(mockConfig, mockGridRouterClient, mockEC2Handler, clusterIDs, bsService)
		assert.EqualError(t, err, pauseErrMsg)

		// Should have paused once
		mockGridRouterClient.AssertNumberOfCalls(t, "PauseHubByClusterName", 1)
		// Should have unpaused once
		mockGridRouterClient.AssertNumberOfCalls(t, "UnpauseHubByClusterName", 1)
		// Should not have rebooted. Counting calls is deliberate:
		// AssertNotCalled without arguments only matches calls recorded with
		// no arguments, so it would pass even after a real reboot call.
		mockEC2Service.AssertNumberOfCalls(t, "RebootInstances", 0)
	})

	// test to ensure if the hub is already paused, we just skip over it
	// this is because a hub may be paused for maintenance or other reasons
	// we want to move forward rebooting other clusters
	t.Run("ignores hub already paused error", func(t *testing.T) {
		mockGetHubResponse := &hub_registry.Hub{
			ClusterName: "test-cluster",
			Paused:      true,
		}

		mockGridRouterClient := &gridRebootMocks.GridRouterClient{}
		mockGridRouterClient.On("GetHubByClusterName", mock.Anything).Return(mockGetHubResponse, nil)

		mockEC2Service := newEC2Service()
		mockEC2Handler := &grEC2.Handler{
			Service: mockEC2Service,
		}

		err := ProcessClusters(mockConfig, mockGridRouterClient, mockEC2Handler, clusterIDs, bsService)
		assert.NoError(t, err)

		// Pause/unpause are grid-router calls, so they are checked on the
		// grid-router mock; reboots are checked on the EC2 mock.
		mockGridRouterClient.AssertNumberOfCalls(t, "PauseHubByClusterName", 0)
		mockGridRouterClient.AssertNumberOfCalls(t, "UnpauseHubByClusterName", 0)
		mockEC2Service.AssertNumberOfCalls(t, "RebootInstances", 0)
	})
}

// TestWaitForDrain exercises WaitForDrain against a mocked grid-router client
// and a fake clock. Each subtest controls the Free/Total slot counts returned
// by GetHubByClusterName and, where needed, advances the fake clock from a
// background goroutine while WaitForDrain is running.
func TestWaitForDrain(t *testing.T) {
	mockConfig := config.GridRebootConfig{}

	// Add a mock clock (shared by all subtests below)
	mockClock := clockwork.NewFakeClock()
	mockConfig.Clock = mockClock

	t.Run("returns immediately when cluster is already drained", func (t *testing.T) {
		mockGridRouterClient := &gridRebootMocks.GridRouterClient{}
		// Free == Total on the only response the mock will give.
		mockGetHubResponse := &hub_registry.Hub{
			ClusterName: "test-cluster",
			SlotCounts: hub_registry.SlotCounts{
				Free:  1,
				Total: 1,
			},
		}
		mockGridRouterClient.On("GetHubByClusterName", mockGetHubResponse.ClusterName).Return(mockGetHubResponse, nil).Once()

		err := WaitForDrain(mockConfig, mockGridRouterClient, mockGetHubResponse.ClusterName)
		assert.NoError(t, err)
		mockGridRouterClient.AssertNumberOfCalls(t, "GetHubByClusterName", 1) // should not have retried
	})

	t.Run("returns if a hub is drained within the time limit", func (t *testing.T) {
		mockGridRouterClient := &gridRebootMocks.GridRouterClient{}
		// Starts with no free slots; the goroutine below flips it to fully free.
		mockGetHubResponse := &hub_registry.Hub{
			ClusterName: "test-cluster",
			SlotCounts: hub_registry.SlotCounts{
				Free:  0,
				Total: 1,
			},
		}
		// Not limited with Once(): the same pointer is returned on every call,
		// so the mutation made by the goroutine is what later calls observe.
		mockGridRouterClient.On("GetHubByClusterName", mockGetHubResponse.ClusterName).Return(mockGetHubResponse, nil)

		// NOTE(review): the write to SlotCounts.Free happens after the clock
		// advance and is not synchronised with WaitForDrain's reads — confirm
		// this is acceptable under -race.
		go func() {
			time.Sleep(time.Millisecond * 100) // CI seems to be unreliable on its sleep times, so this is a little long...
			mockClock.Advance(time.Minute)
			mockGetHubResponse.SlotCounts.Free = mockGetHubResponse.SlotCounts.Total
		}()

		err := WaitForDrain(mockConfig, mockGridRouterClient, mockGetHubResponse.ClusterName)
		assert.NoError(t, err)
		assert.True(t, len(mockGridRouterClient.Calls) > 1) // should have retried
	})

	t.Run("returns error if hub does not drain within time limit", func (t *testing.T) {
		mockGridRouterClient := &gridRebootMocks.GridRouterClient{}
		// Never changes: Free stays below Total for the whole subtest.
		mockGetHubResponse := &hub_registry.Hub{
			ClusterName: "test-cluster",
			SlotCounts: hub_registry.SlotCounts{
				Free:  0,
				Total: 1,
			},
		}
		mockGridRouterClient.On("GetHubByClusterName", mockGetHubResponse.ClusterName).Return(mockGetHubResponse, nil)

		go func() {
			// Simulate time passing, give retries enough time to happen
			// (two advances, 15 fake minutes in total)
			time.Sleep(time.Millisecond * 100) // CI seems to be unreliable on its sleep times, so this is a little long...
			mockClock.Advance(time.Minute * 10)
			time.Sleep(time.Millisecond * 100) // CI seems to be unreliable on its sleep times, so this is a little long...
			mockClock.Advance(time.Minute * 5)
		}()

		err := WaitForDrain(mockConfig, mockGridRouterClient, mockGetHubResponse.ClusterName)
		assert.EqualError(t, err, fmt.Sprintf("timeout reached waiting for cluster [%s] to drain", mockGetHubResponse.ClusterName))
		assert.True(t, len(mockGridRouterClient.Calls) > 1) // should have retried
	})

	t.Run("retries if an error is thrown getting hubs", func (t *testing.T) {
		getHubErrorMsg := "test error getting hub"
		mockGridRouterClient := &gridRebootMocks.GridRouterClient{}
		mockGetHubResponse := &hub_registry.Hub{
			ClusterName: "test-cluster",
			SlotCounts: hub_registry.SlotCounts{
				Free:  0,
				Total: 1,
			},
		}

		// first try = error
		mockGridRouterClient.On("GetHubByClusterName", mockGetHubResponse.ClusterName).Return(nil, errors.New(getHubErrorMsg)).Once()
		go func() {
			// Simulate time passing - second try = response, but not drained
			time.Sleep(time.Millisecond * 100) // CI seems to be unreliable on its sleep times, so this is a little long...
			mockGridRouterClient.On("GetHubByClusterName", mockGetHubResponse.ClusterName).Return(mockGetHubResponse, nil)
			mockClock.Advance(time.Second * 5)

			// Third try = response, drained
			// Both registrations in this goroutine return the same hub pointer,
			// so it is the field update below that changes what the mock reports.
			time.Sleep(time.Millisecond * 100) // CI seems to be unreliable on its sleep times, so this is a little long...
			mockGetHubResponse.SlotCounts.Free  = 1
			mockGetHubResponse.SlotCounts.Total = 1
			mockGridRouterClient.On("GetHubByClusterName", mockGetHubResponse.ClusterName).Return(mockGetHubResponse, nil)
			mockClock.Advance(time.Second * 5)
		}()

		err := WaitForDrain(mockConfig, mockGridRouterClient, mockGetHubResponse.ClusterName)
		assert.NoError(t, err)
		// Exactly three lookups: the error, the undrained response, the drained response.
		mockGridRouterClient.AssertNumberOfCalls(t, "GetHubByClusterName", 3)
	})
}

// TestIsHubPaused checks that IsHubPaused reports the Paused flag of the hub
// returned by the grid-router client, and that it returns the client's error
// together with false when the lookup fails.
func TestIsHubPaused(t *testing.T) {
	// One client mock and one hub are shared by every subtest; each subtest
	// registers its own single-use GetHubByClusterName expectation.
	routerClient := &gridRebootMocks.GridRouterClient{}
	hub := &hub_registry.Hub{
		ClusterName: "test-cluster",
		SlotCounts: hub_registry.SlotCounts{
			Free:  0,
			Total: 1,
		},
	}

	t.Run("on successful response", func(t *testing.T) {
		type pausedCase struct {
			in  bool
			out bool
		}
		cases := []pausedCase{
			{in: false, out: false},
			{in: true, out: true},
		}

		for _, tc := range cases {
			t.Run(fmt.Sprintf("%v", tc.in), func(t *testing.T) {
				hub.Paused = tc.in
				routerClient.On("GetHubByClusterName", hub.ClusterName).Return(hub, nil).Once()

				paused, err := IsHubPaused(routerClient, hub.ClusterName)
				assert.NoError(t, err)
				assert.Equal(t, tc.out, paused)
			})
		}
	})

	t.Run("on error, returns the error and false response", func(t *testing.T) {
		mockErr := "test error"
		routerClient.On("GetHubByClusterName", hub.ClusterName).Return(hub, errors.New(mockErr)).Once()

		paused, err := IsHubPaused(routerClient, hub.ClusterName)
		assert.EqualError(t, err, mockErr)
		assert.Equal(t, false, paused)
	})
}
