package session

import (
	"testing"

	"github.com/aws/aws-sdk-go/service/sts"

	"github.com/aws/aws-sdk-go/aws/endpoints"

	"code.justin.tv/eventbus/client/internal/stsregion"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/stretchr/testify/assert"
)

// TestSessionOverride checks that Corrected produces a session whose STS
// client configuration uses the regional STS endpoint while keeping the
// region of the source session. It also checks that the source session is
// reported by stsregion.ConfiguredEndpoint as using the legacy STS endpoint.
//
// NOTE(review): the "source session is legacy" assertion presumably depends on
// no STS endpoint override being present in the environment or shared AWS
// config (e.g. AWS_STS_REGIONAL_ENDPOINTS) — TODO confirm CI does not set one.
func TestSessionOverride(t *testing.T) {
	region := "us-west-2"
	sess := session.Must(session.NewSession(&aws.Config{
		Region: aws.String(region),
	}))

	overrideSess, err := Corrected(sess)
	// assert.NoError does not stop the test on failure; bail out explicitly so
	// the calls on overrideSess below cannot panic and hide the real error.
	if !assert.NoError(t, err) {
		return
	}

	// testify's assert.Equal takes (expected, actual); keep that order so
	// failure output labels the two values correctly.
	assert.Equal(t, endpoints.LegacySTSEndpoint, stsregion.ConfiguredEndpoint(sess))

	// Look up the STS client config once and check both overridden fields.
	stsConfig := overrideSess.ClientConfig(sts.ServiceName).Config
	assert.Equal(t, endpoints.RegionalSTSEndpoint, stsConfig.STSRegionalEndpoint)
	assert.Equal(t, region, aws.StringValue(stsConfig.Region))
}
