package StarfruitTSMClient_test

import (
	"context"
	"errors"
	"sync/atomic"
	"testing"
	"time"

	"code.justin.tv/amzn/StarfruitTSMClient"
	"code.justin.tv/amzn/StarfruitTSMClient/internal/fakes"
	rpc "code.justin.tv/amzn/StarfruitTranscodeStateMgrTwirp"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestScan(t *testing.T) {
	t.Run("scan with no errors should return successfully", func(t *testing.T) {
		fakeScanner := &fakes.FakeTSMScanner{}

		calls := uint64(0)
		want := uint64(10000)
		fakeScanner.ScanTranscodesStub = func(ctx context.Context, req *rpc.ScanTranscodesRequest) (*rpc.ScanTranscodesResponse, error) {
			c := atomic.AddUint64(&calls, 1)
			if c > want {
				// stop paging
				return &rpc.ScanTranscodesResponse{
					LastTranscodeId: "",
				}, nil
			}

			transcode := &rpc.Transcode{}
			return &rpc.ScanTranscodesResponse{
				Transcodes:      []*rpc.Transcode{transcode},
				LastTranscodeId: "fakeid",
			}, nil
		}

		ps := StarfruitTSMClient.NewParallelScanner(fakeScanner, StarfruitTSMClient.DefaultConcurrency, "Magnus", 2*time.Second)
		transcodes, err := ps.Scan(context.Background())
		require.NoError(t, err)
		assert.Equal(t, int(want), len(transcodes))
	})

	t.Run("scan with a single error should still return successfully", func(t *testing.T) {
		fakeScanner := &fakes.FakeTSMScanner{}

		calls := uint64(0)
		want := uint64(10000)
		fakeScanner.ScanTranscodesStub = func(ctx context.Context, req *rpc.ScanTranscodesRequest) (*rpc.ScanTranscodesResponse, error) {
			c := atomic.AddUint64(&calls, 1)

			if c == 42 {
				return nil, errors.New("timeout")
			}

			// add 1 call for the single failed call that needed to be retried
			if c > want+1 {
				// stop paging
				return &rpc.ScanTranscodesResponse{
					LastTranscodeId: "",
				}, nil
			}

			transcode := &rpc.Transcode{}
			return &rpc.ScanTranscodesResponse{
				Transcodes:      []*rpc.Transcode{transcode},
				LastTranscodeId: "fakeid",
			}, nil
		}

		ps := StarfruitTSMClient.NewParallelScanner(fakeScanner, StarfruitTSMClient.DefaultConcurrency, "Magnus", 2*time.Second)
		transcodes, err := ps.Scan(context.Background())
		require.NoError(t, err)
		assert.Equal(t, int(want), len(transcodes))
	})

	t.Run("scan with errors should terminate", func(t *testing.T) {
		fakeScanner := &fakes.FakeTSMScanner{}
		fakeScanner.ScanTranscodesStub = func(ctx context.Context, req *rpc.ScanTranscodesRequest) (*rpc.ScanTranscodesResponse, error) {
			return nil, errors.New("timeout")
		}
		ps := StarfruitTSMClient.NewParallelScanner(fakeScanner, StarfruitTSMClient.DefaultConcurrency, "Magnus", 2*time.Second)
		_, err := ps.Scan(context.Background())
		require.Error(t, err)
	})
}

// TestGetByCustomerID exercises ParallelCustomerIDGetter.GetByCustomerID
// against a fake TSM customer-ID getter. It covers successful multi-page
// retrieval, an always-failing backend, and an empty (but non-nil) result.
func TestGetByCustomerID(t *testing.T) {
	t.Run("get with no errors should return successfully", func(t *testing.T) {
		fakeGetter := &fakes.FakeTSMCustomerIDGetter{}

		calls := uint64(0)
		want := uint64(10000)
		fakeGetter.GetTranscodesByCustomerIDParallelStub = func(ctx context.Context, req *rpc.GetTranscodesByCustomerIDParallelRequest) (*rpc.GetTranscodesByCustomerIDParallelResponse, error) {
			// Atomic because the getter may invoke the stub concurrently.
			c := atomic.AddUint64(&calls, 1)
			if c > want {
				// stop paging
				return &rpc.GetTranscodesByCustomerIDParallelResponse{
					LastTranscodeId: "",
				}, nil
			}

			transcode := &rpc.Transcode{}
			return &rpc.GetTranscodesByCustomerIDParallelResponse{
				Transcodes:      []*rpc.Transcode{transcode},
				LastTranscodeId: "fakeid",
			}, nil
		}

		ps := StarfruitTSMClient.NewParallelCustomerIDGetter(fakeGetter, 2*time.Second)
		transcodes, err := ps.GetByCustomerID(context.Background(), "twitch")
		require.NoError(t, err)
		assert.Len(t, transcodes, int(want))
	})

	t.Run("get with errors should terminate", func(t *testing.T) {
		fakeGetter := &fakes.FakeTSMCustomerIDGetter{}
		fakeGetter.GetTranscodesByCustomerIDParallelStub = func(ctx context.Context, req *rpc.GetTranscodesByCustomerIDParallelRequest) (*rpc.GetTranscodesByCustomerIDParallelResponse, error) {
			return nil, errors.New("timeout")
		}

		ps := StarfruitTSMClient.NewParallelCustomerIDGetter(fakeGetter, 2*time.Second)
		_, err := ps.GetByCustomerID(context.Background(), "twitch")
		require.Error(t, err)
	})

	t.Run("get with empty result should be empty list, non-nil", func(t *testing.T) {
		fakeGetter := &fakes.FakeTSMCustomerIDGetter{}
		fakeGetter.GetTranscodesByCustomerIDParallelStub = func(ctx context.Context, req *rpc.GetTranscodesByCustomerIDParallelRequest) (*rpc.GetTranscodesByCustomerIDParallelResponse, error) {
			// Zero-value response: no transcodes and an empty LastTranscodeId.
			return &rpc.GetTranscodesByCustomerIDParallelResponse{}, nil
		}

		ps := StarfruitTSMClient.NewParallelCustomerIDGetter(fakeGetter, 2*time.Second)
		result, err := ps.GetByCustomerID(context.Background(), "twitch")
		require.NoError(t, err)
		require.NotNil(t, result)
		// assert.Len keeps expected/actual correctly labeled on failure
		// (the previous assert.Equal call had them swapped).
		assert.Len(t, result, 0)
	})
}
