package main

import (
	"flag"
	"fmt"
	"log"
	"math"
	"strconv"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/autoscaling"
	"github.com/aws/aws-sdk-go/service/ecs"
	"github.com/pkg/errors"
)

// region is the AWS region the session in main is created for; every ECS and
// auto scaling client in this program is built from that session.
const region string = "us-west-2"

// Values accepted by the -command flag.
const (
	// startCmd spins up the mock service and the load test clients.
	startCmd = "start"
	// stopCmd spins down the mock service and the load test clients.
	stopCmd = "stop"
	// runCmd starts a load test, waits until its estimated duration has
	// elapsed, and then stops the load test.
	runCmd = "run"
)

// Sizing assumptions: how many simulated users one host of each kind is
// expected to handle. startLoadTest divides the requested user count by these
// to decide how many tasks/hosts to request.
const (
	// usersPerMockServiceHost is the number of users one mock_service host serves.
	usersPerMockServiceHost = 40000
	// usersPerLoadtestHost is the number of users one load test host drives.
	usersPerLoadtestHost = 30000
)

// config holds everything parsed from the command line for one invocation.
type config struct {
	// cmd selects the action to perform: startCmd, stopCmd, or runCmd
	// (i.e. "start", "stop", or "run").
	cmd string

	// testParams shapes the load test itself (users, messages, intervals).
	testParams *testParameters

	// mockService identifies the ECS cluster/service and ASG running mock_service.
	mockService *ecsParameters

	// loadTest identifies the ECS cluster/service running the load test
	// clients. Its asgName is parsed from flags but is not otherwise used.
	loadTest *ecsParameters
}

// testParameters are the knobs that shape a single load test run. They are
// pushed into the task definitions as environment variables and also used to
// estimate how long the test lasts.
type testParameters struct {
	// numUsers is the total number of simulated users/clients in the test.
	numUsers int
	// userIntervalMs is the delay, in milliseconds, between spinning up users.
	userIntervalMs int
	// numMessages is how many messages mock_service sends to (and each user expects).
	numMessages int
	// msgIntervalSeconds is the gap, in seconds, between messages mock_service sends to each user.
	msgIntervalSeconds int
	// subscriptionIntervalSeconds is how often, in seconds, each user sends a
	// "subscription" to the websocket edge.
	subscriptionIntervalSeconds int
}

// totalTime estimates the wall-clock length of the test: the time to spin up
// every user, plus the time for all messages to be sent, plus a fixed
// one-minute buffer.
func (t *testParameters) totalTime() time.Duration {
	const buffer = 60 * time.Second

	rampUp := time.Millisecond * time.Duration(t.userIntervalMs*t.numUsers)
	messaging := time.Second * time.Duration(t.msgIntervalSeconds*t.numMessages)

	return rampUp + messaging + buffer
}

// ecsParameters names the AWS resources behind one side of the test (either
// the mock service or the load test clients).
type ecsParameters struct {
	// asgName is the auto scaling group scaled so the service's tasks can be scheduled.
	asgName string
	// clusterName is the ECS cluster the service runs in.
	clusterName string
	// serviceName is the ECS service whose task definition and desired count are updated.
	serviceName string
}

// main parses and validates the flags, opens an AWS session in region, and
// dispatches to the handler for the requested command.
func main() {
	conf := &config{
		testParams:  new(testParameters),
		mockService: new(ecsParameters),
		loadTest:    new(ecsParameters),
	}
	parseFlags(conf)
	validateConfig(conf)

	sess := session.Must(session.NewSession(&aws.Config{Region: aws.String(region)}))

	// validateConfig has already rejected anything not in this table.
	handlers := map[string]func(*config, *session.Session){
		startCmd: startLoadTest,
		stopCmd:  stopLoadTest,
		runCmd:   runLoadTest,
	}
	if handle, ok := handlers[conf.cmd]; ok {
		handle(conf, sess)
	}
}

// parseFlags registers all command line flags on the default flag set, binds
// them to the fields of c, and parses os.Args. Defaults for the test
// parameters are given inline below; the AWS resource names default to "" and
// are checked later by validateConfig.
func parseFlags(c *config) {
	// List every accepted command; previously "run" was missing from the help text.
	msg := fmt.Sprintf("Which action to perform. Acceptable values are '%s', '%s', or '%s'", startCmd, stopCmd, runCmd)
	flag.StringVar(&c.cmd, "command", "", msg)

	flag.IntVar(&c.testParams.numUsers, "numUsers", 30000, "How many users will be in this load test.")
	flag.IntVar(&c.testParams.userIntervalMs, "userIntervalMs", 20, "Duration, in ms, between each user being spun up in the load test.")
	flag.IntVar(&c.testParams.numMessages, "numMessages", 100, "Number of messages mock_service will send to each user, and how many each user will expect.")
	flag.IntVar(&c.testParams.msgIntervalSeconds, "msgIntervalSeconds", 15, "Duration, in seconds, between each message that mock_service sends to each user.")
	flag.IntVar(&c.testParams.subscriptionIntervalSeconds, "subscriptionIntervalSeconds", 60, "Duration, in seconds, between each subscription message that each user sends to the edge.")

	flag.StringVar(&c.mockService.asgName, "mockServiceAsgName", "", "The name of the mock service asg.")
	flag.StringVar(&c.mockService.clusterName, "mockServiceClusterName", "", "The name of the mock service ecs cluster.")
	flag.StringVar(&c.mockService.serviceName, "mockServiceServiceName", "", "The name of the mock service ecs service.")

	flag.StringVar(&c.loadTest.asgName, "loadTestAsgName", "", "The name of the load test asg.")
	flag.StringVar(&c.loadTest.clusterName, "loadTestClusterName", "", "The name of the load test ecs cluster.")
	flag.StringVar(&c.loadTest.serviceName, "loadTestServiceName", "", "The name of the load test ecs service.")

	flag.Parse()
}

// validateConfig exits the program via log.Fatal when c has an unknown
// command, a non-positive test parameter, or missing AWS resource names. It
// returns normally only when c is usable.
func validateConfig(c *config) {
	knownCmd := c.cmd == startCmd || c.cmd == stopCmd || c.cmd == runCmd
	if !knownCmd {
		log.Fatal("Acceptable commands are start, stop, and run.")
	}

	p := c.testParams
	if p.numUsers <= 0 || p.userIntervalMs <= 0 || p.numMessages <= 0 ||
		p.msgIntervalSeconds <= 0 || p.subscriptionIntervalSeconds <= 0 {
		log.Fatal("Integer parameters must be greater than 0.")
	}

	mock := c.mockService
	if mock.serviceName == "" || mock.clusterName == "" || mock.asgName == "" {
		log.Fatal("Must pass in service, cluster, and asg name for the mock service.")
	}

	// Unlike the mock service, the load test's asgName is optional.
	if c.loadTest.serviceName == "" || c.loadTest.clusterName == "" {
		log.Fatal("Must pass in service and cluster for the loadtest service.")
	}
}

// startLoadTest deploys the mock service first and waits for every one of its
// tasks to run the new task definition. It then deploys the load test
// clients. Host counts come from the per-host user constants.
// Any AWS error terminates the program via check.
func startLoadTest(c *config, awsSession *session.Session) {
	params := c.testParams
	log.Printf("Starting load test with %d users...", params.numUsers)

	client := ecs.New(awsSession)

	// TODO(scott): Consider messaging rates when determining the number of hosts needed.
	users := float64(params.numUsers)
	mockHosts := int64(math.Ceil(users / float64(usersPerMockServiceHost)))
	loadHosts := int64(math.Ceil(users / float64(usersPerLoadtestHost)))

	// Mock service: new task definition revision plus the desired task count.
	mockArn, err := updateTaskDefinition(client, params, c.mockService)
	check(err, "updating mock_service task definition")
	_, err = updateService(client, c.mockService, mockArn, mockHosts)
	check(err, "updating mock_service service")

	// Grow the ASG, with some headroom, so the new tasks can be scheduled.
	asgSize := mockHosts + extraHosts(mockHosts)
	err = setAsgCount(awsSession, c.mockService.asgName, asgSize)
	check(err, "setting asg count")

	log.Printf("Waiting for mock service to deploy...")
	waitForService(client, c.mockService, mockArn)
	log.Println("Mock service has been deployed!")

	// Load test clients: only deployed once the mock service is fully rolled out.
	loadArn, err := updateTaskDefinition(client, params, c.loadTest)
	check(err, "updating load_test task definition")
	_, err = updateService(client, c.loadTest, loadArn, loadHosts)
	check(err, "updating load_test service")
	log.Printf("Load test is now spinning up.")
}

// stopLoadTest scales the load test and mock service ECS services down to zero
// tasks, then shrinks the mock service ASG to zero instances. Any AWS error
// terminates the program via check.
func stopLoadTest(c *config, awsSession *session.Session) {
	log.Println("Stopping load test...")

	client := ecs.New(awsSession)

	// The load test clients are stopped first, then the mock service they target,
	// and finally the hosts the mock service ran on.
	check(updateServiceCount(client, c.loadTest.clusterName, c.loadTest.serviceName, 0), "spinning down loadtest")
	check(updateServiceCount(client, c.mockService.clusterName, c.mockService.serviceName, 0), "spinning down mock service")
	check(setAsgCount(awsSession, c.mockService.asgName, 0), "spinning down mock service ASG")

	log.Println("all done.")
}

// runLoadTest starts a load test and then blocks until the duration estimated
// by testParams.totalTime has elapsed, measured from when startLoadTest
// returns. It logs progress at most every 30 seconds and finally stops the
// load test.
func runLoadTest(c *config, awsSession *session.Session) {
	startLoadTest(c, awsSession)

	const pollInterval = 30 * time.Second
	// Track a wall-clock deadline rather than summing sleep durations, and
	// clamp each sleep to the remaining time so the test is not held open for
	// up to a full extra poll interval past its estimated end.
	deadline := time.Now().Add(c.testParams.totalTime())

	for {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			break
		}
		log.Printf("Test is running for another %d seconds.", int(remaining.Seconds()))

		wait := pollInterval
		if remaining < wait {
			wait = remaining
		}
		time.Sleep(wait)
	}

	log.Printf("Test is complete.")

	stopLoadTest(c, awsSession)
}

// setAsgCount pins the named auto scaling group to exactly count instances by
// setting its min size, max size, and desired capacity to the same value.
func setAsgCount(awsSession *session.Session, asgName string, count int64) error {
	input := autoscaling.UpdateAutoScalingGroupInput{
		AutoScalingGroupName: &asgName,
		MinSize:              &count,
		MaxSize:              &count,
		DesiredCapacity:      &count,
	}
	if _, err := autoscaling.New(awsSession).UpdateAutoScalingGroup(&input); err != nil {
		return err
	}
	return nil
}

// describeTaskDefinition fetches a task definition from ECS. id may be either
// a "family:revision" string or the task definition's full ARN.
func describeTaskDefinition(client *ecs.ECS, id string) (*ecs.TaskDefinition, error) {
	out, err := client.DescribeTaskDefinition(&ecs.DescribeTaskDefinitionInput{
		TaskDefinition: &id,
	})
	if err != nil {
		return nil, err
	}
	return out.TaskDefinition, nil
}

// getCurrentTaskDefinition returns the task definition to use as the template
// for a new revision.
//
// When ecsParams.serviceName is non-empty, the task definition currently
// attached to that service is used and arn is ignored. Otherwise arn (either a
// "family:revision" string or a full ARN) is described directly. An error is
// returned when neither is provided.
func getCurrentTaskDefinition(client *ecs.ECS, ecsParams *ecsParameters, arn string) (*ecs.TaskDefinition, error) {
	var taskDefARN string

	switch {
	case ecsParams.serviceName != "":
		service, err := describeService(client, ecsParams)
		if err != nil {
			return nil, err
		}

		if service == nil || service.TaskDefinition == nil {
			return nil, fmt.Errorf("service %q does not have an associated task definition", ecsParams.serviceName)
		}

		taskDefARN = *service.TaskDefinition

	case arn != "":
		taskDefARN = arn
	default:
		return nil, errors.New("a service name or task definition identifier is required")
	}

	return describeTaskDefinition(client, taskDefARN)
}

// describeService looks up ecsParams.serviceName in ecsParams.clusterName. It
// returns the first service in the ECS response, or an error if the request
// fails or the response contains no services.
func describeService(client *ecs.ECS, ecsParams *ecsParameters) (*ecs.Service, error) {
	out, err := client.DescribeServices(&ecs.DescribeServicesInput{
		Cluster:  &ecsParams.clusterName,
		Services: []*string{&ecsParams.serviceName},
	})

	switch {
	case err != nil:
		return nil, err
	case len(out.Services) == 0:
		return nil, fmt.Errorf("unable to find service %q in cluster %q", ecsParams.serviceName, ecsParams.clusterName)
	default:
		// Exactly one service was requested, so only the first entry matters.
		return out.Services[0], nil
	}
}

// applyEnvironmentVariables overwrites, in place, the values of the recognized
// load-test environment variables in containerDef using the values from t.
// Variables with other names are left untouched. Recognized names that are
// absent from the container definition are not added.
//
// MAX_USERS is always set to usersPerLoadtestHost rather than t.numUsers.
// Presumably this is the per-host client cap, since startLoadTest sizes the
// load test by that same constant. Confirm against the load test client.
func applyEnvironmentVariables(t *testParameters, containerDef *ecs.ContainerDefinition) {
	for _, e := range containerDef.Environment {
		// Skip malformed entries instead of panicking on a nil dereference.
		if e == nil || e.Name == nil {
			continue
		}

		switch *e.Name {
		case "NUM_MESSAGES":
			e.Value = strPtr(strconv.Itoa(t.numMessages))
		case "MESSAGE_INTERVAL_SECONDS":
			e.Value = strPtr(strconv.Itoa(t.msgIntervalSeconds))
		case "MAX_USERS":
			e.Value = strPtr(strconv.Itoa(usersPerLoadtestHost))
		case "SUBSCRIPTION_INTERVAL_SECONDS":
			e.Value = strPtr(strconv.Itoa(t.subscriptionIntervalSeconds))
		case "USER_INTERVAL_MILLISECONDS":
			e.Value = strPtr(strconv.Itoa(t.userIntervalMs))
		}
	}
}

// registerTaskDefinition registers a new task definition revision in ECS,
// built from next's settings, and returns the definition ECS created.
// Only the fields copied below are carried over; any other settings on next
// are not included in the registration request.
func registerTaskDefinition(client *ecs.ECS, next *ecs.TaskDefinition) (*ecs.TaskDefinition, error) {
	req := new(ecs.RegisterTaskDefinitionInput)
	req.Family = next.Family
	req.ContainerDefinitions = next.ContainerDefinitions
	req.Cpu = next.Cpu
	req.Memory = next.Memory
	req.NetworkMode = next.NetworkMode
	req.ExecutionRoleArn = next.ExecutionRoleArn
	req.TaskRoleArn = next.TaskRoleArn
	req.PlacementConstraints = next.PlacementConstraints
	req.RequiresCompatibilities = next.RequiresCompatibilities
	req.Volumes = next.Volumes

	resp, err := client.RegisterTaskDefinition(req)
	if err != nil {
		return nil, err
	}
	return resp.TaskDefinition, nil
}

// updateService points the service at taskDefARN and sets its desired task
// count, returning the service as reported by ECS after the update.
func updateService(client *ecs.ECS, ecsParams *ecsParameters, taskDefARN string, desiredCount int64) (*ecs.Service, error) {
	resp, err := client.UpdateService(&ecs.UpdateServiceInput{
		Cluster:        &ecsParams.clusterName,
		Service:        &ecsParams.serviceName,
		TaskDefinition: &taskDefARN,
		DesiredCount:   &desiredCount,
	})
	if err != nil {
		return nil, err
	}
	return resp.Service, nil
}

// updateServiceCount changes only the desired task count of the named service
// in the named cluster; its task definition is left as is.
func updateServiceCount(client *ecs.ECS, cluster string, service string, desiredCount int64) error {
	req := &ecs.UpdateServiceInput{
		Cluster:      &cluster,
		Service:      &service,
		DesiredCount: &desiredCount,
	}
	if _, err := client.UpdateService(req); err != nil {
		return err
	}
	return nil
}

// updateTaskDefinition registers a new revision of the task definition that
// the service in ecsParams currently uses. The revision has its load-test
// environment variables overwritten from testParams. It returns the new
// revision's ARN.
//
// The task definition must contain exactly one container definition.
// Errors are returned to the caller rather than terminating the program.
func updateTaskDefinition(client *ecs.ECS, testParams *testParameters, ecsParams *ecsParameters) (string, error) {
	taskDef, err := getCurrentTaskDefinition(client, ecsParams, "")
	if err != nil {
		// Previously this exited via check despite the function returning an
		// error; callers already wrap and handle this error themselves.
		return "", fmt.Errorf("getting task definition: %v", err)
	}

	if n := len(taskDef.ContainerDefinitions); n != 1 {
		return "", fmt.Errorf("expected exactly 1 container definition in the task definition, found %d", n)
	}
	applyEnvironmentVariables(testParams, taskDef.ContainerDefinitions[0])

	newTaskDef, err := registerTaskDefinition(client, taskDef)
	if err != nil {
		return "", err
	}
	if newTaskDef == nil || newTaskDef.TaskDefinitionArn == nil {
		return "", fmt.Errorf("registered task definition has no ARN")
	}

	return *newTaskDef.TaskDefinitionArn, nil
}

// allTasksUseARN reports whether every task in tt runs taskDefArn. Nil tasks
// and tasks without a task definition ARN are ignored, so an empty (or
// all-nil) slice yields true.
func allTasksUseARN(tt []*ecs.Task, taskDefArn string) bool {
	for i := range tt {
		task := tt[i]
		mismatch := task != nil && task.TaskDefinitionArn != nil && *task.TaskDefinitionArn != taskDefArn
		if mismatch {
			return false
		}
	}
	return true
}

// getRunningTasks returns the full descriptions of every task that ListTasks
// reports for the service, following ListTasks pagination and describing each
// page of task ARNs. ListTasks is called without a DesiredStatus filter, so
// ECS's default applies. When no tasks are found it returns an empty,
// non-nil slice.
func getRunningTasks(client *ecs.ECS, ecsParams *ecsParameters) ([]*ecs.Task, error) {
	tasks := make([]*ecs.Task, 0)
	var nextToken *string
	var didMultiFetch bool

	for {
		list, err := client.ListTasks(&ecs.ListTasksInput{
			Cluster:     &ecsParams.clusterName,
			ServiceName: &ecsParams.serviceName,
			NextToken:   nextToken,
		})
		if err != nil {
			return nil, err
		}

		// An empty page only means there is nothing to describe on this page.
		// Returning here (as before) would discard tasks from earlier pages.
		if len(list.TaskArns) > 0 {
			log.Printf("Fetching %d running tasks...", len(list.TaskArns))
			out, err := client.DescribeTasks(&ecs.DescribeTasksInput{
				Cluster: &ecsParams.clusterName,
				Tasks:   list.TaskArns,
			})
			if err != nil {
				return nil, err
			}
			tasks = append(tasks, out.Tasks...)
		}

		if list.NextToken == nil {
			break
		}

		log.Print("Fetching more tasks...")
		nextToken = list.NextToken
		didMultiFetch = true
	}

	if didMultiFetch {
		log.Printf("Found %d total running tasks.", len(tasks))
	}

	return tasks, nil
}

// waitForService polls the service every 10 seconds. It returns once the
// service has at least one task and every listed task uses newTaskDefArn.
// An error while listing tasks terminates the program via check. There is no
// overall timeout.
func waitForService(client *ecs.ECS, ecsParams *ecsParameters, newTaskDefArn string) {
	const pollInterval = 10 * time.Second

	for {
		tasks, err := getRunningTasks(client, ecsParams)
		// Check the error first. On failure tasks is nil, which was previously
		// mistaken for "no tasks yet" and retried forever without reporting it.
		check(err, "getting running tasks")

		switch {
		case len(tasks) == 0:
			log.Printf("No tasks running yet...")
		case allTasksUseARN(tasks, newTaskDefArn):
			log.Printf("service is updated & running.")
			return
		default:
			log.Printf("Still waiting for new task to finish running...")
		}

		time.Sleep(pollInterval)
	}
}

// strPtr returns a pointer to a fresh copy of s, for populating the *string
// fields the AWS SDK structs use.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}

// extraHosts returns how many spare hosts to add to an ASG on top of
// hostsNeeded, so that service deploys have room to proceed quickly. The
// headroom grows in tiers with the size of the deployment.
func extraHosts(hostsNeeded int64) int64 {
	tiers := []struct{ below, extra int64 }{
		{below: 5, extra: 2},
		{below: 10, extra: 3},
		{below: 20, extra: 5},
	}
	for _, tier := range tiers {
		if hostsNeeded < tier.below {
			return tier.extra
		}
	}
	return 7
}

// check terminates the program via log.Fatal when err is non-nil, wrapping err
// with msg to describe what was being attempted. A nil err is a no-op.
func check(err error, msg string) {
	if err != nil {
		log.Fatal(errors.Wrap(err, msg))
	}
}
