package e2e_utils

import (
	"encoding/json"
	"fmt"
	"log"
	"os"
	"time"

	"github.com/fatih/color"
	"github.com/go-test/deep"

	. "code.justin.tv/devhub/twitch-e2-ingest/models"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"

	"github.com/aws/aws-sdk-go/service/sqs"
)

// SNSMessage holds the "Message" field decoded from the JSON body of an SQS
// message. The name suggests the SQS body is an SNS notification envelope;
// only the inner payload string is decoded here (see convertSNSMessage).
type SNSMessage struct {
	Message string `json:"Message"`
}

// openSQSSession creates a fresh AWS session with default settings and returns
// an SQS client bound to the us-west-2 region. If the session cannot be
// created, the error is printed to stdout and the process exits with status 1.
func openSQSSession() *sqs.SQS {
	regionCfg := &aws.Config{Region: aws.String("us-west-2")}

	awsSession, sessErr := session.NewSession()
	if sessErr == nil {
		return sqs.New(awsSession, regionCfg)
	}

	fmt.Println("Error connecting to sqs ", sessErr)
	os.Exit(1)
	// Unreachable: os.Exit does not return, but Go requires a terminating statement.
	return nil
}

// removeSQSMessages deletes the given messages from the queue at qURL using
// their receipt handles. Nil messages are skipped.
//
// DeleteMessageBatch accepts at most 10 entries per request, so larger inputs
// are sent in chunks. Deletion is best-effort: a failed request or a failed
// individual entry is printed, and the remaining chunks are still attempted.
func removeSQSMessages(messages []*sqs.Message, qURL string) {
	// SQS rejects DeleteMessageBatch requests with more than 10 entries.
	const maxBatchSize = 10

	entries := make([]*sqs.DeleteMessageBatchRequestEntry, 0, len(messages))
	for _, message := range messages {
		if message == nil {
			continue
		}
		entries = append(entries, &sqs.DeleteMessageBatchRequestEntry{
			// Entry IDs only need to be unique within a single request.
			Id:            message.MessageId,
			ReceiptHandle: message.ReceiptHandle,
		})
	}
	if len(entries) == 0 {
		// SQS rejects an empty batch, so there is nothing to send.
		return
	}

	svc := openSQSSession()
	for start := 0; start < len(entries); start += maxBatchSize {
		end := start + maxBatchSize
		if end > len(entries) {
			end = len(entries)
		}

		output, err := svc.DeleteMessageBatch(&sqs.DeleteMessageBatchInput{
			QueueUrl: aws.String(qURL),
			Entries:  entries[start:end],
		})
		if err != nil {
			fmt.Println("Delete Error", err)
			continue
		}

		// A successful request can still report per-entry failures; surface
		// them so undeleted messages do not leak into later test runs silently.
		for _, failed := range output.Failed {
			fmt.Println("Delete Error", aws.StringValue(failed.Id), aws.StringValue(failed.Code), aws.StringValue(failed.Message))
		}
	}
}

// receiveSQSMessages short-polls the queue at qURL a fixed number of times and
// returns every non-nil message received, in the order SQS returned them.
// Each non-empty batch is deleted from the queue right after it is received,
// so the messages are consumed by this call.
//
// A failed ReceiveMessage call is logged and the next attempt is made, so a
// transient error is visible instead of silently producing a short result.
func receiveSQSMessages(qURL string) []sqs.Message {
	// Short polling (WaitTimeSeconds 0) samples a weighted random subset of
	// SQS servers on each ReceiveMessage call and only returns messages held by
	// those servers, so poll several times to make it likely all are seen.
	const pollAttempts = 5

	svc := openSQSSession()

	var messageList []sqs.Message
	for attempt := 0; attempt < pollAttempts; attempt++ {
		result, err := svc.ReceiveMessage(&sqs.ReceiveMessageInput{
			AttributeNames: []*string{
				aws.String(sqs.MessageSystemAttributeNameSentTimestamp),
			},
			MessageAttributeNames: []*string{
				aws.String(sqs.QueueAttributeNameAll),
			},
			QueueUrl:            &qURL,
			MaxNumberOfMessages: aws.Int64(10),
			VisibilityTimeout:   aws.Int64(20),
			WaitTimeSeconds:     aws.Int64(0),
		})
		if err != nil {
			log.Printf("receiving sqs messages from %s (attempt %d/%d): %v", qURL, attempt+1, pollAttempts, err)
			continue
		}

		// Only non-nil messages are returned and only those are deleted.
		batch := make([]*sqs.Message, 0, len(result.Messages))
		for _, message := range result.Messages {
			if message != nil {
				messageList = append(messageList, *message)
				batch = append(batch, message)
			}
		}

		if len(batch) != 0 {
			// Remove received messages so later polls do not return them again.
			removeSQSMessages(batch, qURL)
		}
	}

	return messageList
}

// validateSQSSequence checks that actualMessages holds exactly one message per
// entry in expectedMessages and that each one matches its expectation.
//
// Messages may arrive in any order, so each is decoded and placed by its
// MessageID, which is treated as 1-based: MessageID i+1 is compared against
// expectedMessages[i]. The process exits via log.Fatalf on a count mismatch,
// an out-of-range or duplicated MessageID, or any field mismatch.
func validateSQSSequence(expectedMessages []GameFullData, actualMessages []sqs.Message) {
	if len(actualMessages) != len(expectedMessages) {
		log.Fatalf("expected message length %d not equal to actual message received %d", len(expectedMessages), len(actualMessages))
	}

	orderedMessages := make([]GameFullData, len(expectedMessages))
	seen := make([]bool, len(expectedMessages))
	for i := range actualMessages {
		decoded := convertFullDataMessageToStruct(&actualMessages[i])

		// Guard the index instead of panicking on a bad ID, and reject
		// duplicates, which would otherwise leave another slot zero-valued.
		idx := decoded.MessageID - 1
		if idx < 0 || idx >= int64(len(orderedMessages)) {
			log.Fatalf("message id %d out of range, expected 1..%d", decoded.MessageID, len(orderedMessages))
		}
		if seen[idx] {
			log.Fatalf("duplicate message id %d received", decoded.MessageID)
		}
		seen[idx] = true
		orderedMessages[idx] = decoded
	}

	for i, expectedMessage := range expectedMessages {
		compareFullData(i, expectedMessage, orderedMessages[i])
	}
}

// convertFullDataMessageToStruct decodes an SQS message whose JSON body carries
// a JSON-encoded GameFullData in its "Message" field (see convertSNSMessage).
// If the payload cannot be decoded, the error is printed in red and the process
// exits with status 1.
func convertFullDataMessageToStruct(message *sqs.Message) GameFullData {
	messageData := convertSNSMessage(message)

	var data GameFullData
	if err := json.Unmarshal([]byte(messageData.Message), &data); err != nil {
		// Report the decode error itself; without it the failure is opaque.
		color.Red("Could not parse sqs message payload: %v", err)
		os.Exit(1)
	}

	return data
}

// convertSNSMessage decodes the JSON body of an SQS message into an SNSMessage.
// The process exits if message or its body is nil (via log.Fatal), or if the
// body is not valid JSON for SNSMessage (error printed in red, exit status 1).
func convertSNSMessage(message *sqs.Message) SNSMessage {
	if message == nil {
		log.Fatal("message received is nil from sqs")
	}
	// Dereferencing a nil Body would panic; fail with a clear message instead.
	if message.Body == nil {
		log.Fatalf("message %s received from sqs has no body", aws.StringValue(message.MessageId))
	}

	messageData := SNSMessage{}
	if err := json.Unmarshal([]byte(*message.Body), &messageData); err != nil {
		color.Red("Could not parse sqs message: %v", err)
		os.Exit(1)
	}

	return messageData
}

// compareFullData checks that actual matches expected for the message at
// zero-based position i, exiting via log.Fatalf on the first mismatch.
//
// GameID, ClientID and Env must be equal. BroadcasterIDs and Data are compared
// deeply. ConnectionID and Time are interpreted as Unix seconds (via
// time.Unix) and must each be within one minute, in either direction, of the
// expected connection time. actual.MessageID must equal i+1.
func compareFullData(i int, expected, actual GameFullData) {
	// withinMinute reports whether a and b are at most one minute apart,
	// regardless of which one is later.
	withinMinute := func(a, b time.Time) bool {
		diff := a.Sub(b)
		if diff < 0 {
			diff = -diff
		}
		return diff <= time.Minute
	}

	if expected.GameID != actual.GameID {
		log.Fatalf("game id not equal, expected: %s, actual %s", expected.GameID, actual.GameID)
	}

	if deep.Equal(expected.BroadcasterIDs, actual.BroadcasterIDs) != nil {
		log.Fatalf("broadcaster id not equal, expected: %v, actual %v", expected.BroadcasterIDs, actual.BroadcasterIDs)
	}

	if expected.ClientID != actual.ClientID {
		log.Fatalf("client id not equal, expected: %s, actual %s", expected.ClientID, actual.ClientID)
	}

	// Connection ID: the time difference should not be more than one minute.
	expectedConnectionTime := time.Unix(expected.ConnectionID, 0)
	actualConnectionTime := time.Unix(actual.ConnectionID, 0)
	if !withinMinute(expectedConnectionTime, actualConnectionTime) {
		log.Fatalf("connection time not equal, expected: %s, actual %s", expectedConnectionTime.String(), actualConnectionTime.String())
	}

	// Message time: checked against the expected connection time, as before.
	// NOTE(review): confirm whether expected.Time was meant to be used here.
	actualTime := time.Unix(actual.Time, 0)
	if !withinMinute(expectedConnectionTime, actualTime) {
		log.Fatalf("time not within one minute of connection time, expected: %s, actual %s", expectedConnectionTime.String(), actualTime.String())
	}

	if expected.Env != actual.Env {
		log.Fatalf("env not equal, expected: %s, actual %s", expected.Env, actual.Env)
	}

	if int64(i+1) != actual.MessageID {
		log.Fatalf("message id not equal, expected: %d, actual %d", i+1, actual.MessageID)
	}

	// Validate actual data field
	if deep.Equal(expected.Data, actual.Data) != nil {
		log.Fatalf("data not equal, expected: %v, actual %v", expected.Data, actual.Data)
	}
}
