package publisher

import (
	"context"
	"net/http"
	"net/http/httptest"
	"strconv"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go/aws/client"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/sns"
	"github.com/aws/aws-sdk-go/service/sns/snsiface"
	"github.com/aws/aws-sdk-go/service/sts"
	"github.com/aws/aws-sdk-go/service/sts/stsiface"
	"github.com/golang/protobuf/ptypes/timestamp"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	eventbus "code.justin.tv/eventbus/client"
	"code.justin.tv/eventbus/client/internal/testschema/clock"
	"code.justin.tv/eventbus/client/internal/testschema/user_password"
	"code.justin.tv/eventbus/client/lowlevel/snsclient"
)

// mockSTSClient is an STS test double. GetCallerIdentity answers with a fixed
// identity and counts how many times it was invoked; the embedded STSAPI is
// left nil, so calling any other STS method on the mock panics.
type mockSTSClient struct {
	stsiface.STSAPI
	calls int
}

// GetCallerIdentity increments the call counter and returns a canned identity
// for account 123456789012. It ignores its input and never fails.
func (c *mockSTSClient) GetCallerIdentity(_ *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) {
	c.calls++
	out := &sts.GetCallerIdentityOutput{
		Account: aws.String("123456789012"),
		Arn:     aws.String("arn:aws:iam::123456789012:resource"),
		UserId:  aws.String("123abc"),
	}
	return out, nil
}

// mockCFNVersionChecker is a test double for the CloudFormation version check.
// Require ignores its arguments and reports whatever pair was last configured
// with shouldReturn; the zero value reports (false, nil).
type mockCFNVersionChecker struct {
	retBool bool
	retErr  error
}

// shouldReturn configures the result reported by subsequent Require calls.
func (m *mockCFNVersionChecker) shouldReturn(ok bool, err error) {
	m.retBool, m.retErr = ok, err
}

// Require returns the configured (bool, error) pair unconditionally.
func (m *mockCFNVersionChecker) Require(_ context.Context, _ client.ConfigProvider, _ string) (bool, error) {
	return m.retBool, m.retErr
}

// TestPublisherBasic exercises newWithAWSClients construction. It covers
// rejection of invalid event type names, failure when the noop publish to SNS
// fails, and whether the STS identity lookup runs based on the CFN version
// check result.
func TestPublisherBasic(t *testing.T) {
	mockSTS := &mockSTSClient{}
	mockVersion := &mockCFNVersionChecker{}
	t.Run("Constructor Validation", func(t *testing.T) {
		// returnStatus is read by the mock server on every request. Subtests
		// that change it must restore it so later subtests see a healthy server.
		returnStatus := http.StatusOK
		sess, server := newMockSession(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(returnStatus)
		})
		defer server.Close()

		t.Run("Fails on invalid message type", func(t *testing.T) {
			p, err := newWithAWSClients(Config{
				Environment: EnvLocal,
				Session:     sess,
				EventTypes:  []string{"!Unknown"},
			}, mockSTS, mockVersion)
			assert.Nil(t, p)
			require.Error(t, err)
			assert.Equal(t, "event type !Unknown is invalid, must be non-empty and alphanumeric", err.Error())
		})

		t.Run("Gets route but fails noop", func(t *testing.T) {
			returnStatus = http.StatusInternalServerError
			// Restore a healthy server so this subtest cannot leak into others.
			defer func() { returnStatus = http.StatusOK }()
			p, err := newWithAWSClients(Config{
				Environment: EnvLocal,
				Session:     sess,
				EventTypes:  []string{"ClockTick"},
			}, mockSTS, mockVersion)
			require.Error(t, err)
			require.Nil(t, p)
			require.Contains(t, err.Error(), "publish noop to SNS")
		})

		t.Run("HappyPath", func(t *testing.T) {
			handler := func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
			}
			okSess, okServer := newMockSession(handler)
			defer okServer.Close()
			t.Run("Performs assume role logic for low CFN versions", func(t *testing.T) {
				mockSTS := &mockSTSClient{}
				mockVersion := &mockCFNVersionChecker{} // zero value: Require reports (false, nil)
				p, err := newWithAWSClients(Config{
					Environment: EnvStaging,
					Session:     okSess,
					EventTypes:  []string{}, // don't declare any events so no noops are attempted!
				}, mockSTS, mockVersion)
				require.NoError(t, err)
				require.NotNil(t, p)
				assert.Equal(t, 1, mockSTS.calls)
			})
			t.Run("Does not perform assume role logic when using new CFN versions", func(t *testing.T) {
				mockSTS := &mockSTSClient{}
				mockVersion := &mockCFNVersionChecker{}
				mockVersion.shouldReturn(true, nil)
				p, err := newWithAWSClients(Config{
					Environment: EnvDevelopment,
					Session:     okSess,
					EventTypes:  []string{}, // don't declare any events so no noops are attempted!
				}, mockSTS, mockVersion)
				require.NoError(t, err)
				require.NotNil(t, p)
				assert.Equal(t, 0, mockSTS.calls)
			})
		})
	})
}

// TestPublisherMockedServer checks the HTTP traffic a publisher sends to a mock
// SNS endpoint. It expects one POST during construction, no request for an
// undeclared event type, and one POST for a declared publish.
func TestPublisherMockedServer(t *testing.T) {
	mockSTS := &mockSTSClient{}
	mockVersion := &mockCFNVersionChecker{}
	// calls records every request the mock server receives, in order.
	var calls []call
	handler := func(w http.ResponseWriter, r *http.Request) {
		calls = append(calls, call{
			Method: r.Method,
			Path:   r.URL.Path,
		})
		w.WriteHeader(http.StatusOK)
	}
	sess, server := newMockSession(handler)
	defer server.Close()

	p, err := newWithAWSClients(Config{
		Environment: EnvLocal,
		Session:     sess,
		EventTypes:  []string{"ClockUpdate"},
	}, mockSTS, mockVersion)
	require.NoError(t, err)
	require.NotNil(t, p)

	reset := func() {
		calls = nil
	}

	t.Run("GotNoop", func(t *testing.T) {
		// require (not assert) so an empty slice fails the subtest instead of
		// panicking on calls[0].
		require.Len(t, calls, 1)
		assert.Equal(t, "POST", calls[0].Method)
		reset()
	})

	t.Run("PublishUndeclaredType", func(t *testing.T) {
		reset()
		err := p.Publish(context.Background(), &user_password.UserPasswordUpdate{})
		assert.Equal(t, errUnregistered, err)
		assert.Len(t, calls, 0)
	})

	t.Run("PublishClockUpdate", func(t *testing.T) {
		reset()
		err := p.Publish(context.Background(), &clock.ClockUpdate{Time: &timestamp.Timestamp{Seconds: 1234}})
		assert.NoError(t, err)
		require.Len(t, calls, 1)
		assert.Equal(t, "POST", calls[0].Method)
	})
}

// TestPublisherContextTimeouts verifies that Publish honors the caller's
// context. A deadline shorter than the mock server's response delay aborts
// the publish; a longer deadline lets it complete.
func TestPublisherContextTimeouts(t *testing.T) {
	// base is canceled when the test finishes so that any handler still
	// waiting returns promptly before the server is closed.
	base, cancel := context.WithCancel(context.Background())

	// sleepFor is how long the mock server waits before responding.
	var sleepFor time.Duration
	handler := func(w http.ResponseWriter, r *http.Request) {
		select {
		case <-time.After(sleepFor):
		case <-base.Done():
		case <-r.Context().Done():
		}
		w.WriteHeader(http.StatusOK)
	}
	sess, server := newMockSession(handler)
	defer func() {
		cancel()
		server.Close()
	}()

	mockSTS := &mockSTSClient{}
	mockVersion := &mockCFNVersionChecker{}
	p, err := newWithAWSClients(Config{
		Environment: EnvLocal,
		Session:     sess,
		EventTypes:  []string{"ClockUpdate"},
	}, mockSTS, mockVersion)
	require.NoError(t, err)
	require.NotNil(t, p)

	t.Run("Context timeout works", func(t *testing.T) {
		sleepFor = 1 * time.Second
		ctx, cancel := context.WithTimeout(base, 50*time.Millisecond)
		defer cancel()
		begin := time.Now()
		err := p.Publish(ctx, &clock.ClockUpdate{Time: &timestamp.Timestamp{Seconds: 1234}})
		// require so an unexpected nil error fails cleanly instead of
		// panicking on err.Error() below.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "canceled")
		// Leave a lot of leeway for slow/busy machines, but still less than the 1sec timeout.
		elapsed := time.Since(begin)
		assert.True(t, elapsed < 700*time.Millisecond, "publish took %v", elapsed)
	})

	t.Run("Successful sleep", func(t *testing.T) {
		sleepFor = 50 * time.Millisecond
		ctx, cancel := context.WithTimeout(base, 250*time.Millisecond)
		defer cancel()
		begin := time.Now()
		err := p.Publish(ctx, &clock.ClockUpdate{Time: &timestamp.Timestamp{Seconds: 1234}})
		assert.NoError(t, err)
		elapsed := time.Since(begin)
		assert.True(t, elapsed >= 50*time.Millisecond, "publish took %v", elapsed)
	})
}

// call captures the HTTP method and URL path of one request seen by a mock
// server.
type call struct {
	Method, Path string
}

// newMockSession starts an httptest server backed by handler. It returns an
// AWS session that sends every request to that server over plain HTTP, signed
// with fake static credentials. Callers are responsible for closing the
// returned server. It panics if the session cannot be created.
func newMockSession(handler http.HandlerFunc) (*session.Session, *httptest.Server) {
	srv := httptest.NewServer(handler)
	cfg := &aws.Config{
		Endpoint:    aws.String(srv.URL),
		DisableSSL:  aws.Bool(true),
		Region:      aws.String("us-east-1"), // required for request signing to work
		Credentials: credentials.NewStaticCredentials("AKIAFAKE", "abcfake", "token"),
	}
	return session.Must(session.NewSession(cfg)), srv
}

// TestPublisherInterfaceConversion guards that *Publisher still satisfies the
// eventbus.Publisher interface; the assignment fails to compile otherwise.
func TestPublisherInterfaceConversion(t *testing.T) {
	var converted eventbus.Publisher = &Publisher{}
	require.NotNil(t, converted)
}

// TestPublisherMockedSNSAPI publishes through a Publisher wired directly to an
// in-memory SNS mock. It checks the message attributes attached to the
// resulting PublishInput.
func TestPublisherMockedSNSAPI(t *testing.T) {
	snsAPI := &mockSNS{}

	p := &Publisher{
		client: snsclient.New(snsAPI),
		routes: map[string]*Route{
			"ClockUpdate": {Arn: "foo"},
		},
	}

	t.Run("Attributes check", func(t *testing.T) {
		err := p.Publish(context.Background(), &clock.ClockUpdate{Time: nil})
		require.NoError(t, err)
		// require before indexing so a missing call fails instead of panicking.
		require.Len(t, snsAPI.publishCalls, 1)
		c := snsAPI.publishCalls[0]
		assert.Len(t, c.MessageAttributes, 2)

		// A missing map key yields a nil *MessageAttributeValue; check it
		// before dereferencing.
		eventType := c.MessageAttributes["eventbus-event-type"]
		require.NotNil(t, eventType, "missing eventbus-event-type attribute")
		assert.Equal(t, "ClockUpdate", aws.StringValue(eventType.StringValue))

		randAttr := c.MessageAttributes["eventbus-rand-e6"]
		require.NotNil(t, randAttr, "missing eventbus-rand-e6 attribute")
		n, err := strconv.Atoi(aws.StringValue(randAttr.StringValue))
		require.NoError(t, err)
		assert.True(t, n >= 0 && n < 1e6, "eventbus-rand-e6 out of range: %d", n)
	})
}

// mockSNS is an in-memory SNS stand-in that records every PublishWithContext
// input. Only PublishWithContext is implemented; the embedded SNSAPI is left
// nil, so calling any other SNS method on the mock panics.
type mockSNS struct {
	snsiface.SNSAPI

	// publishCalls holds a copy of each PublishInput received, in call order.
	publishCalls []sns.PublishInput
}

// PublishWithContext appends a copy of input to publishCalls and reports
// success with an empty output. The context and request options are ignored.
func (m *mockSNS) PublishWithContext(_ aws.Context, input *sns.PublishInput, _ ...request.Option) (*sns.PublishOutput, error) {
	recorded := *input
	m.publishCalls = append(m.publishCalls, recorded)
	return new(sns.PublishOutput), nil
}
