package e2e_utils

import (
	"encoding/json"
	"fmt"
	"log"
	"os"
	"reflect"
	"strings"
	"time"

	"github.com/davecgh/go-spew/spew"

	"github.com/aws/aws-sdk-go/aws"

	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/fatih/color"
	"golang.org/x/net/websocket"
)

// DiscoveryMessage is the JSON body decoded from each SNS notification read off
// the tags SQS queue (see convertDiscoveryMessageToStruct). The e2e tests build
// expected values of this type and compare them to received ones with
// reflect.DeepEqual, so a nil slice and an empty slice are NOT equal here.
type DiscoveryMessage struct {
	// ChannelID identifies the channel the change applies to; the tests expect
	// it to equal tester.BroadcasterIDs.
	ChannelID string `json:"channel_id"`
	// Add lists tag UUIDs reported as added (tests fill it from TagUUIDsGearBox).
	// A missing or null "add" key decodes to nil.
	Add []string `json:"add"`
	// Remove lists tag UUIDs reported as removed. A missing or null "remove"
	// key decodes to nil.
	Remove []string `json:"remove"`
}

// TestTagsValidatorNormalPath sends the standard connect pack, a metadata delta
// and a refresh over the websocket, and waits for the three corresponding
// discovery messages on the tags queue. The expected messages are listed in
// send order, but the worker accepts them in any order.
func (tester *eceImpl) TestTagsValidatorNormalPath() {
	connectPack := replaceTokenAndBroadcasterIDs(tester.UserToken, tester.BroadcasterIDs, StandardConnectionData)

	expectedMessage := []DiscoveryMessage{
		{
			ChannelID: tester.BroadcasterIDs,
			Add:       []string{TagUUIDsGearBox["Amara"], TagUUIDsGearBox["FL4K"]},
			Remove:    []string(nil),
		},
		{
			ChannelID: tester.BroadcasterIDs,
			Add:       []string(nil),
			Remove:    []string{TagUUIDsGearBox["FL4K"]},
		},
		{
			ChannelID: tester.BroadcasterIDs,
			Add:       []string{TagUUIDsGearBox["Moze"]},
			Remove: []string{
				TagUUIDsGearBox["Amara"],
				TagUUIDsGearBox["FL4K"],
				TagUUIDsGearBox["Zane"],
				TagUUIDsGearBox["Hunter"],
				TagUUIDsGearBox["Mayhem"],
			},
		},
	}

	dataPacks := []string{
		connectPack,
		StandardDeltaPackChangeMetadata,
		StandardRefresh,
	}

	// The scenario name labels error output; this test is the normal path,
	// not the unwhitelisted-tags scenario it was copied from.
	tester.validateTagsValidatorSNSMessage(expectedMessage, dataPacks, "normal path")
}

// TestTagsValidatorNormalPathUnwhitelistedTags checks how a tag the validator
// does not know is handled. It swaps the first "Amara" tag UUID in the standard
// connect pack for "UnvalidatedTag". It then expects the initial discovery
// message to add only FL4K, with the metadata delta and refresh messages
// following. Expectations are listed in send order; the worker accepts them in
// any order.
func (tester *eceImpl) TestTagsValidatorNormalPathUnwhitelistedTags() {
	tags := TagUUIDsGearBox
	channel := tester.BroadcasterIDs

	basePack := replaceTokenAndBroadcasterIDs(tester.UserToken, channel, StandardConnectionData)
	tamperedPack := strings.Replace(basePack, tags["Amara"], "UnvalidatedTag", 1)

	// Omitted Add/Remove fields stay nil, i.e. []string(nil).
	afterConnect := DiscoveryMessage{ChannelID: channel, Add: []string{tags["FL4K"]}}
	afterDelta := DiscoveryMessage{
		ChannelID: channel,
		Add:       []string{tags["Amara"]},
		Remove:    []string{tags["FL4K"]},
	}
	afterRefresh := DiscoveryMessage{
		ChannelID: channel,
		Add:       []string{tags["Moze"]},
		Remove: []string{
			tags["Amara"], tags["FL4K"], tags["Zane"], tags["Hunter"], tags["Mayhem"],
		},
	}

	tester.validateTagsValidatorSNSMessage(
		[]DiscoveryMessage{afterConnect, afterDelta, afterRefresh},
		[]string{tamperedPack, StandardDeltaPackChangeMetadata, StandardRefresh},
		"unwhitelisted tags",
	)
}

// convertDiscoveryMessageToStruct extracts the SNS envelope from an SQS message
// (via convertSNSMessage) and decodes its Message field as JSON into a
// DiscoveryMessage. If decoding fails it prints the error in red and exits the
// process with status 1.
func convertDiscoveryMessageToStruct(message *sqs.Message) DiscoveryMessage {
	messageData := convertSNSMessage(message)

	var discoveryMessage DiscoveryMessage
	if err := json.Unmarshal([]byte(messageData.Message), &discoveryMessage); err != nil {
		// Include the decode error; without it a failure gives no hint of the cause.
		color.Red("Could not parse sqs message: %v", err)
		os.Exit(1)
	}

	return discoveryMessage
}

// tagsValidatorSNSWorker starts a goroutine that polls the dev tags SQS queue
// until every message in expectedMessages has been received, then sends true on
// finished. After the first poll it retries up to len(expectedMessages)*5 more
// times, sleeping 3 seconds between polls. If expectations are still missing,
// it prints the received-but-unmatched messages and the still-missing expected
// messages, then exits the process with status 1.
//
// Each received message satisfies at most one outstanding expectation, and each
// expectation is satisfied at most once, so duplicate expectations require
// duplicate deliveries. Matching uses reflect.DeepEqual.
func (tester *eceImpl) tagsValidatorSNSWorker(finished chan bool, expectedMessages []DiscoveryMessage) {
	maxRetries := len(expectedMessages) * 5

	go func() {
		// found marks which expected messages (by index) have been matched.
		found := make(map[int]bool, len(expectedMessages))
		// unmatched accumulates, across all polls, every received message that
		// did not satisfy an outstanding expectation.
		var unmatched []DiscoveryMessage

		for retries := 0; ; retries++ {
			for _, message := range tester.mustGetSQSMessages(devTagsSQS) {
				matched := false
				for i, expected := range expectedMessages {
					if found[i] || !reflect.DeepEqual(message, expected) {
						continue
					}
					found[i] = true
					matched = true
					// One received message satisfies at most one expectation.
					break
				}
				if !matched {
					unmatched = append(unmatched, message)
				}
			}

			if len(found) == len(expectedMessages) {
				finished <- true
				return
			}

			if retries >= maxRetries {
				// Derive the missing set from found so it is correct even if no
				// message was ever received.
				var missing []DiscoveryMessage
				for i, expected := range expectedMessages {
					if !found[i] {
						missing = append(missing, expected)
					}
				}
				fmt.Println("Error: worker took too many attempts to find expected messages")
				// Pass data as arguments, not as the format string, so '%' in
				// message contents cannot corrupt the output.
				color.White("Received but not expected: %+v\n", unmatched)
				color.Red("Still expecting: %+v\n", missing)
				os.Exit(1)
			}

			time.Sleep(3 * time.Second)
		}
	}()
}

// validateTagsValidatorSNSMessage connects a websocket, starts the connection
// validation worker and the SQS polling worker, then sends each entry of
// dataPacks in order, with a one-second pause after each successful send. It
// blocks until the SQS worker reports that every message in expectedMessages
// has been received; the worker exits the process if that never happens.
// name identifies the scenario in log output and when closing the connection.
func (tester *eceImpl) validateTagsValidatorSNSMessage(expectedMessages []DiscoveryMessage, dataPacks []string, name string) {
	finished := make(chan bool)

	ws, err := tester.connect()
	if err != nil {
		log.Fatal(err)
	}
	// Label the connection with this scenario rather than a hard-coded test name.
	defer closeConnection()(name, ws)

	spinUpValidationWorker(validateOnlyConnect, ws)
	tester.tagsValidatorSNSWorker(finished, expectedMessages)

	for _, dataPack := range dataPacks {
		var data map[string]interface{}
		if err := json.Unmarshal([]byte(dataPack), &data); err != nil {
			// Skip the pack: sending the nil map would transmit JSON `null`.
			fmt.Println("Error parsing datapack: ", name, err.Error())
			continue
		}

		if err := websocket.JSON.Send(ws, data); err != nil {
			fmt.Println("Error sending message: ", name, err.Error())
		}

		time.Sleep(1 * time.Second)
	}

	<-finished
}

// mustGetSQSMessages makes one ReceiveMessage call against the queue at qURL
// (up to 10 messages, WaitTimeSeconds 0, VisibilityTimeout 20s). It decodes
// each message into a DiscoveryMessage, dumps the raw messages with spew, and
// deletes them from the queue. It returns nil when no messages were returned.
// If the receive call fails it prints the error and exits with status 1.
func (tester *eceImpl) mustGetSQSMessages(qURL string) []DiscoveryMessage {
	svc := openSQSSession()

	result, err := svc.ReceiveMessage(&sqs.ReceiveMessageInput{
		AttributeNames: []*string{
			aws.String(sqs.MessageSystemAttributeNameSentTimestamp),
		},
		MessageAttributeNames: []*string{
			aws.String(sqs.QueueAttributeNameAll),
		},
		QueueUrl:            &qURL,
		MaxNumberOfMessages: aws.Int64(10),
		VisibilityTimeout:   aws.Int64(20),
		WaitTimeSeconds:     aws.Int64(0),
	})
	if err != nil {
		fmt.Println("Error", err)
		os.Exit(1)
	}

	if len(result.Messages) == 0 {
		return nil
	}

	discoveryMessageList := make([]DiscoveryMessage, 0, len(result.Messages))
	for _, message := range result.Messages {
		discoveryMessageList = append(discoveryMessageList, convertDiscoveryMessageToStruct(message))
	}

	spew.Dump(result.Messages)
	// Delete only after every message decoded successfully (a decode failure
	// exits above), so the same messages are not read again on the next poll.
	removeSQSMessages(result.Messages, qURL)

	return discoveryMessageList
}
