package StarfruitSECProducer

import (
	"context"
	"errors"
	"sync"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
	"github.com/stretchr/testify/assert"
)

// delayInNanoSec is the base pause, in nanoseconds (80 µs), that the tests
// use when giving the producer's background sender time to work.
const delayInNanoSec = 80 * 1000

// inducedErrorMsg is the error text MockSQSClient returns when induceError is set.
const inducedErrorMsg = "induced error"

// MockSQSClient is a test double for sqsiface.SQSAPI that implements only
// SendMessageWithContext. Unless the embedded SQSAPI is set (the tests in this
// file leave it nil), calling any other SQSAPI method panics.
//
// mu guards msgSent, delay and induceError. The producer's sending goroutine
// updates msgSent and reads the configuration, while tests read the counter
// and change the configuration concurrently.
type MockSQSClient struct {
	sqsiface.SQSAPI
	msgSent     int           // successful SendMessageWithContext calls so far
	delay       time.Duration // simulated latency applied to every send
	induceError bool          // when true, sends fail with inducedErrorMsg
	mu          sync.Mutex
}

// messageSent records one successful send.
func (m *MockSQSClient) messageSent() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.msgSent++
}

// howManyMsgSent reports how many sends have succeeded so far. Use it instead
// of reading msgSent directly, which would race with the sending goroutine.
func (m *MockSQSClient) howManyMsgSent() int {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.msgSent
}

// SendMessageWithContext simulates an SQS send. It sleeps for the configured
// delay and then checks induceError. If set, it fails with inducedErrorMsg;
// otherwise it counts the message as sent. ctx, msgInput and opts are ignored,
// and the returned *sqs.SendMessageOutput is always nil.
func (m *MockSQSClient) SendMessageWithContext(ctx context.Context, msgInput *sqs.SendMessageInput, opts ...request.Option) (*sqs.SendMessageOutput, error) {
	// Configuration is read under mu because tests change it while the
	// producer goroutine may be inside this method. induceError is still
	// sampled after the sleep, as before.
	m.mu.Lock()
	delay := m.delay
	m.mu.Unlock()

	time.Sleep(delay)

	m.mu.Lock()
	fail := m.induceError
	m.mu.Unlock()

	if fail {
		return nil, errors.New(inducedErrorMsg)
	}
	m.messageSent()
	return nil, nil
}

// TestBasicMessageBuffer exercises producers with a two-slot buffer in three
// phases:
//  1. two messages are delivered;
//  2. with a small per-send delay, four back-to-back messages lose at most two;
//  3. with a fresh producer and a failing client, the send error is reported
//     through the LoggingCallBack.
func TestBasicMessageBuffer(t *testing.T) {
	// Setup Test
	mockSvc := &MockSQSClient{}
	losDuration := 10 * time.Second // DefaultMsgLossDuration
	// callbackTimeout bounds every wait on the producer, so a regression fails
	// the test instead of hanging it until the global `go test` timeout.
	const callbackTimeout = 5 * time.Second
	cfg1 := Config{
		SqsApi:          mockSvc,
		QueueURL:        "QueueURL",
		BufferSize:      2,
		MsgLossDuration: losDuration,
		LoggingCallBack: func(s string) {
			// NOTE(review): this producer's goroutine outlives the test. If it
			// ever logs after the test returns, t.Log panics; confirm it cannot.
			t.Log("CBF1: " + s) // adding prefix so as we can readily recognize the output among all other unit test outputs
		},
	}
	psp := createProtoSQSProducer(cfg1)
	// now start sending messages
	psp.Send(nil, NewTestPayload())
	psp.Send(nil, NewTestPayload())
	// Actual sending is done by another goroutine. Poll, rather than relying on
	// a single fixed sleep, until both messages are delivered or we time out.
	deadline := time.Now().Add(callbackTimeout)
	for mockSvc.howManyMsgSent() < 2 && time.Now().Before(deadline) {
		time.Sleep(time.Duration(delayInNanoSec))
	}

	// Read through howManyMsgSent: msgSent is written by the producer goroutine
	// under mu, so a bare field read would be a data race.
	if assert.Equal(t, 2, mockSvc.howManyMsgSent()) {
		t.Log("asserted that 2 messages were sent")
	}

	// We assume that all messages are flushed out by now.
	// Give each send an 8 µs delay (8000 ns). The mock is shared with the
	// producer's goroutine, so it is reconfigured under its mutex.
	mockSvc.mu.Lock()
	mockSvc.delay = 8 * time.Microsecond
	mockSvc.mu.Unlock()
	firstMsg := NewTestPayload()
	secondMsg := NewTestPayload()
	thirdMsg := NewTestPayload()
	fourthMsg := NewTestPayload()

	// now send all 4
	psp.Send(nil, firstMsg)
	psp.Send(nil, secondMsg)
	psp.Send(nil, thirdMsg)
	psp.Send(nil, fourthMsg)

	// so by now we should have lost at the most 2 messages
	ml := psp.chnBuf.msgLoss()
	// assert that at the most 2 messages are lost
	if assert.LessOrEqual(t, ml, 2) {
		t.Logf("asserted number of messages lost: %d\n", ml)
	}

	// Let the backlog drain: 128 base delays (about 10 ms).
	time.Sleep(128 * time.Duration(delayInNanoSec))

	// 2 messages from the first phase plus up to 4 from the second, minus losses.
	if assert.LessOrEqual(t, mockSvc.howManyMsgSent(), 6-ml) {
		t.Logf("asserted that at most %d messages are sent\n", 6-ml)
	}

	// induce error
	mockSvc.mu.Lock()
	mockSvc.induceError = true
	mockSvc.mu.Unlock()
	prefix := "CBF2: "
	// The callback runs on the producer's goroutine. A one-slot channel with a
	// non-blocking send hands the first message to the test and drops any
	// later ones. It never blocks the producer, and unlike a WaitGroup it
	// cannot panic if the callback fires more than once.
	logged := make(chan string, 1)

	cfg2 := Config{
		SqsApi:          mockSvc,
		QueueURL:        "QueueURL",
		BufferSize:      2,
		MsgLossDuration: losDuration,
		LoggingCallBack: func(s string) {
			select {
			case logged <- prefix + s:
			default:
			}
		},
	}
	psp = createProtoSQSProducer(cfg2)

	fifthMsg := NewTestPayload()
	psp.Send(nil, fifthMsg)
	var errorMsg string
	select {
	case errorMsg = <-logged:
	case <-time.After(callbackTimeout):
		t.Fatal("timed out waiting for the send error to be logged")
	}
	if assert.Equal(t, prefix+"[SEC-Event-Producer] error sending event to SEC: "+inducedErrorMsg, errorMsg) {
		t.Logf("asserted that error message received as expected.\n")
	}
}

// TestMessageLossRecording checks that the LoggingCallBack is invoked within
// two MsgLossDuration periods in one scenario. A slow client (1 s per send)
// is sent four back-to-back messages through a two-slot buffer. The test
// expects the resulting loss to be reported.
func TestMessageLossRecording(t *testing.T) {
	// Setup Test. The mock is fully configured before the producer exists, so
	// its goroutine can never observe a concurrent configuration write.
	mockSvc := &MockSQSClient{delay: time.Second, induceError: false}
	// Short MsgLossDuration keeps the test quick (presumably loss is reported
	// once per period; confirm against the producer).
	losDuration := 2 * time.Second
	prefix := "CBF3: "
	// The callback fires on the producer's goroutine. The first message is
	// handed over a channel instead of a shared string, so the read below is
	// not a data race. The non-blocking send drops later messages.
	logged := make(chan string, 1)

	cfg3 := Config{
		SqsApi:          mockSvc,
		QueueURL:        "QueueURL",
		BufferSize:      2,
		MsgLossDuration: losDuration,
		LoggingCallBack: func(s string) {
			select {
			case logged <- prefix + s:
			default:
			}
		},
	}
	psp := createProtoSQSProducer(cfg3)

	sixthMsg := NewTestPayload()
	seventhMsg := NewTestPayload()
	eighthMsg := NewTestPayload()
	ninthMsg := NewTestPayload()

	psp.Send(nil, sixthMsg)
	psp.Send(nil, seventhMsg)
	psp.Send(nil, eighthMsg)
	psp.Send(nil, ninthMsg)

	// Wait at most two periods, but stop as soon as something is logged.
	var errorMsg string
	select {
	case errorMsg = <-logged:
	case <-time.After(2 * losDuration):
	}
	if assert.NotEmpty(t, errorMsg) {
		t.Logf("asserted that message loss was recorded.\n")
	}
}

// TestNoLoiggingCallBack (sic: "Logging"; the name is kept so existing -run
// filters still match) checks two things for a producer built with no
// LoggingCallBack:
//   - NewProtoSQSProducerByConfig returns a ProtobufSQSProducer;
//   - Send calls return without panicking.
func TestNoLoiggingCallBack(t *testing.T) {
	// Setup Test
	mockSvc := &MockSQSClient{}
	losDuration := 10 * time.Second // DefaultMsgLossDuration
	cfg := Config{
		SqsApi:          mockSvc,
		QueueURL:        "QueueURL",
		BufferSize:      2,
		MsgLossDuration: losDuration,
		// no logging call back set
	}
	// Box the result first so the type assertion compiles whether the
	// constructor returns an interface or a concrete type.
	var pi interface{} = NewProtoSQSProducerByConfig(cfg)

	// Comma-ok form: a wrong type fails the test cleanly instead of panicking.
	psp, ok := pi.(ProtobufSQSProducer)
	if !assert.True(t, ok, "NewProtoSQSProducerByConfig returned %T, want ProtobufSQSProducer", pi) {
		return
	}
	// now start sending messages; reaching the end of the test means Send did
	// not panic despite the missing callback
	psp.Send(nil, NewTestPayload())
	psp.Send(nil, NewTestPayload())
}
