package main

import (
	"context"
	"errors"
	"fmt"
	"log"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/ec2"
	"github.com/aws/aws-sdk-go/service/ecs"
)

// updateStatsite rolls the statsite service onto a freshly registered task definition.
//
// It works in four steps:
//   - Copy the task definition at conf.taskDefArn, swapping in conf.imageTag when it is non-empty.
//   - Register the copy (no env var overrides).
//   - Point the service conf.serviceName in cluster conf.clusterName at the new definition.
//   - Block until waitUntilDeployCompletes reports the deploy is done.
//
// On success it returns the ARN of the newly registered task definition.
func updateStatsite(ecsClient *ecs.ECS, conf *ServiceConfig) (string, error) {
	log.Println("Creating new statsite task definition...")
	srcDef, err := describeTaskDefinition(ecsClient, conf.taskDefArn)
	if err != nil {
		return "", err
	}

	// Statsite takes no env var overrides; only the image may change.
	arn, err := registerTaskDefinition(ecsClient, conf, srcDef, conf.imageTag, nil)
	if err != nil {
		return "", err
	}

	log.Println("Updating statsite service...")
	if _, err := ecsClient.UpdateService(&ecs.UpdateServiceInput{
		Cluster:        &conf.clusterName,
		Service:        &conf.serviceName,
		TaskDefinition: aws.String(arn),
	}); err != nil {
		return "", err
	}

	log.Println("Waiting for statsite deploy to complete...")
	if err := waitUntilDeployCompletes(ecsClient, conf); err != nil {
		return "", err
	}

	log.Println("Statsite deploy complete.")
	return arn, nil
}

// updateStatsdProxy rolls the statsd proxy service onto a freshly registered task
// definition.
//
// The new definition is a copy of the one at conf.taskDefArn with two changes:
//   - The ADDRESSES env var is set to statsdHosts joined with ",".
//   - The image tag is replaced by conf.imageTag when that is non-empty.
//
// ADDRESSES must already exist on the app container; registerTaskDefinition rejects
// unknown variables. After registering, it points conf.serviceName at the new
// definition and blocks until waitUntilDeployCompletes returns.
// An empty statsdHosts is rejected before any AWS call is made.
func updateStatsdProxy(ecsClient *ecs.ECS, conf *ServiceConfig, statsdHosts []string) error {
	if len(statsdHosts) == 0 {
		return errors.New("statsd proxy did not receive any statsdHosts")
	}

	srcDef, err := describeTaskDefinition(ecsClient, conf.taskDefArn)
	if err != nil {
		return err
	}

	log.Println("Creating new statsd proxy task definition...")
	overrides := map[string]string{"ADDRESSES": strings.Join(statsdHosts, ",")}
	arn, err := registerTaskDefinition(ecsClient, conf, srcDef, conf.imageTag, overrides)
	if err != nil {
		return err
	}

	log.Println("Updating statsd proxy service...")
	if _, err = ecsClient.UpdateService(&ecs.UpdateServiceInput{
		Cluster:        &conf.clusterName,
		Service:        &conf.serviceName,
		TaskDefinition: &arn,
	}); err != nil {
		return err
	}

	log.Println("Waiting for statsd proxy deploy to complete...")
	if err = waitUntilDeployCompletes(ecsClient, conf); err != nil {
		return err
	}

	log.Println("statsd proxy deploy complete.")
	return nil
}

// describeTaskDefinition fetches the task definition named by taskDefArn via the
// ECS DescribeTaskDefinition API. Any API error is returned unchanged.
func describeTaskDefinition(ecsClient *ecs.ECS, taskDefArn string) (*ecs.TaskDefinition, error) {
	out, err := ecsClient.DescribeTaskDefinition(&ecs.DescribeTaskDefinitionInput{
		TaskDefinition: aws.String(taskDefArn),
	})
	if err != nil {
		return nil, err
	}
	return out.TaskDefinition, nil
}

// updateContainerEnvVars overwrites the values of environment variables that are
// already defined on containerDef, using envVars (variable name -> new value).
//
// It never adds variables. If any key in envVars is not currently defined on the
// container, an error is returned and containerDef is left unmodified. A nil or
// empty envVars is a no-op.
func updateContainerEnvVars(containerDef *ecs.ContainerDefinition, envVars map[string]string) error {
	if len(envVars) == 0 {
		return nil
	}

	// Index the container's current env vars by name so they can be updated in place.
	containerVars := make(map[string]*ecs.KeyValuePair, len(containerDef.Environment))
	for _, kv := range containerDef.Environment {
		containerVars[aws.StringValue(kv.Name)] = kv
	}

	// Validate every key before mutating anything. Map iteration order is random, so
	// mixing validation with updates could leave the definition half-updated on error.
	for key := range envVars {
		if _, ok := containerVars[key]; !ok {
			return fmt.Errorf("%s is not a current environment variable on the container definition with image %s", key, aws.StringValue(containerDef.Image))
		}
	}

	for key, val := range envVars {
		containerVars[key].SetValue(val)
	}
	return nil
}

// registerTaskDefinition registers a new revision copied from srcTaskDef and returns
// its ARN.
//
// The fields listed in the RegisterTaskDefinitionInput below are carried over. The
// app container is the container definition whose name equals the task family, and
// it is modified as follows:
//   - If imageTag is non-empty, the image's tag is replaced (see replaceImageTag).
//   - Env vars in envVars overwrite existing values via updateContainerEnvVars. A
//     name that is not already defined on the container is an error.
//
// NOTE: the container definitions are shared with srcTaskDef, so these edits are also
// visible through srcTaskDef after the call. conf is currently unused.
func registerTaskDefinition(ecsClient *ecs.ECS, conf *ServiceConfig, srcTaskDef *ecs.TaskDefinition, imageTag string, envVars map[string]string) (string, error) {
	input := &ecs.RegisterTaskDefinitionInput{
		ContainerDefinitions:    srcTaskDef.ContainerDefinitions,
		Cpu:                     srcTaskDef.Cpu,
		ExecutionRoleArn:        srcTaskDef.ExecutionRoleArn,
		Family:                  srcTaskDef.Family,
		IpcMode:                 srcTaskDef.IpcMode,
		Memory:                  srcTaskDef.Memory,
		NetworkMode:             srcTaskDef.NetworkMode,
		PidMode:                 srcTaskDef.PidMode,
		PlacementConstraints:    srcTaskDef.PlacementConstraints,
		RequiresCompatibilities: srcTaskDef.RequiresCompatibilities,
		TaskRoleArn:             srcTaskDef.TaskRoleArn,
		Volumes:                 srcTaskDef.Volumes,
	}

	// The app container is the one named after the task family.
	family := aws.StringValue(srcTaskDef.Family)
	var appContainerDef *ecs.ContainerDefinition
	for _, containerDefinition := range input.ContainerDefinitions {
		if aws.StringValue(containerDefinition.Name) == family {
			appContainerDef = containerDefinition
			break
		}
	}
	if appContainerDef == nil {
		return "", fmt.Errorf("could not find a container definition with name %s in the source task definition", family)
	}

	if imageTag != "" {
		appContainerDef.SetImage(replaceImageTag(aws.StringValue(appContainerDef.Image), imageTag))
	}

	if err := updateContainerEnvVars(appContainerDef, envVars); err != nil {
		return "", err
	}

	output, err := ecsClient.RegisterTaskDefinition(input)
	if err != nil {
		return "", err
	}

	return *output.TaskDefinition.TaskDefinitionArn, nil
}

// replaceImageTag returns image with its tag set to tag, dropping any existing tag
// or digest.
//
// Only a colon after the final "/" separates a tag. Earlier colons belong to a
// registry host:port, so "host:5000/repo:old" becomes "host:5000/repo:<tag>"
// rather than "host:<tag>".
func replaceImageTag(image, tag string) string {
	repo := image
	// Strip a digest reference such as "repo@sha256:...".
	if at := strings.Index(repo, "@"); at >= 0 {
		repo = repo[:at]
	}
	if colon := strings.LastIndex(repo, ":"); colon > strings.LastIndex(repo, "/") {
		repo = repo[:colon]
	}
	return repo + ":" + tag
}

// waitUntilDeployCompletes polls the service every 10 seconds. It returns nil once
// the service's PRIMARY deployment has RunningCount == DesiredCount and
// PendingCount == 0. The first poll happens 10 seconds after the call.
//
// It returns an error in these cases:
//   - DescribeServices fails.
//   - The call does not return exactly one service.
//   - The service has no PRIMARY deployment.
//   - The deploy has not completed within 5 minutes.
func waitUntilDeployCompletes(ecsClient *ecs.ECS, conf *ServiceConfig) error {
	isDeployComplete := func() (bool, error) {
		out, err := ecsClient.DescribeServices(&ecs.DescribeServicesInput{
			Cluster:  &conf.clusterName,
			Services: []*string{&conf.serviceName},
		})
		if err != nil {
			return false, err
		}
		if len(out.Services) != 1 {
			return false, fmt.Errorf("expected one service in cluster %s, received %d", conf.clusterName, len(out.Services))
		}

		// Always inspect the PRIMARY deployment, even when it is the only one. A lone
		// deployment may still be short of DesiredCount, so it is not a shortcut to done.
		var primary *ecs.Deployment
		for _, deploy := range out.Services[0].Deployments {
			if aws.StringValue(deploy.Status) == "PRIMARY" {
				primary = deploy
				break
			}
		}
		if primary == nil {
			return false, fmt.Errorf("service %s in cluster %s has no PRIMARY deployment", conf.serviceName, conf.clusterName)
		}

		complete := aws.Int64Value(primary.DesiredCount) == aws.Int64Value(primary.RunningCount) &&
			aws.Int64Value(primary.PendingCount) == 0
		return complete, nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
	defer cancel()

	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			complete, err := isDeployComplete()
			if err != nil {
				return err
			}
			if complete {
				return nil
			}
		case <-ctx.Done():
			return fmt.Errorf("the deploy of %s did not complete in time", conf.serviceName)
		}
	}
}

// getTasks returns every task that ListTasks reports for conf.serviceName in
// conf.clusterName. ListTasks' default filter is desired status RUNNING.
//
// It follows NextToken until no further pages remain. Each ListTasks page is
// resolved with a single DescribeTasks call. Empty pages are skipped, because an
// empty Tasks list is not a valid DescribeTasks request.
func getTasks(ecsClient *ecs.ECS, conf *ServiceConfig) ([]*ecs.Task, error) {
	var tasks []*ecs.Task
	var nextToken *string
	for {
		listTasksOut, err := ecsClient.ListTasks(&ecs.ListTasksInput{
			Cluster:     &conf.clusterName,
			ServiceName: &conf.serviceName,
			// Without the token every call would return page 1 again and never terminate.
			NextToken: nextToken,
		})
		if err != nil {
			return nil, err
		}

		if len(listTasksOut.TaskArns) > 0 {
			out, err := ecsClient.DescribeTasks(&ecs.DescribeTasksInput{
				Cluster: &conf.clusterName,
				Tasks:   listTasksOut.TaskArns,
			})
			if err != nil {
				return nil, err
			}
			tasks = append(tasks, out.Tasks...)
		}

		nextToken = listTasksOut.NextToken
		if aws.StringValue(nextToken) == "" {
			return tasks, nil
		}
	}
}

// getContainerInstanceIPs returns a map from container instance ARN to the private IP
// of the EC2 host behind it, covering every container instance in conf.clusterName.
//
// ListContainerInstances is paginated. Each page (at most 100 ARNs by default) is
// resolved separately, which also keeps each DescribeContainerInstances call within
// its 100-instance limit. Instances that EC2 reports without a private IP are
// omitted from the map.
func getContainerInstanceIPs(ecsClient *ecs.ECS, ec2Client *ec2.EC2, conf *ServiceConfig) (map[string]string, error) {
	containerInstanceArnToIP := make(map[string]string)
	var nextToken *string
	for {
		listOut, err := ecsClient.ListContainerInstances(&ecs.ListContainerInstancesInput{
			Cluster:   &conf.clusterName,
			NextToken: nextToken,
		})
		if err != nil {
			return nil, err
		}

		if err := addContainerInstanceIPs(ecsClient, ec2Client, &conf.clusterName, listOut.ContainerInstanceArns, containerInstanceArnToIP); err != nil {
			return nil, err
		}

		nextToken = listOut.NextToken
		if aws.StringValue(nextToken) == "" {
			return containerInstanceArnToIP, nil
		}
	}
}

// addContainerInstanceIPs resolves one page of container instance ARNs to host
// private IPs and records them in dst (container instance ARN -> IP).
func addContainerInstanceIPs(ecsClient *ecs.ECS, ec2Client *ec2.EC2, cluster *string, containerInstanceArns []*string, dst map[string]string) error {
	if len(containerInstanceArns) == 0 {
		return nil
	}

	describeOut, err := ecsClient.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{
		Cluster:            cluster,
		ContainerInstances: containerInstanceArns,
	})
	if err != nil {
		return err
	}

	// Build the list of EC2 instance IDs to look up, and a map of instance ID to
	// container instance ARN.
	instanceIDs := make([]*string, 0, len(describeOut.ContainerInstances))
	instanceIDToContainerInstanceArn := make(map[string]string, len(describeOut.ContainerInstances))
	for _, containerInstance := range describeOut.ContainerInstances {
		id := aws.StringValue(containerInstance.Ec2InstanceId)
		if id == "" {
			continue
		}
		instanceIDs = append(instanceIDs, containerInstance.Ec2InstanceId)
		instanceIDToContainerInstanceArn[id] = aws.StringValue(containerInstance.ContainerInstanceArn)
	}

	// An empty InstanceIds makes DescribeInstances return every instance in the
	// account, so never send that request.
	if len(instanceIDs) == 0 {
		return nil
	}

	var ec2Token *string
	for {
		describeInstancesOut, err := ec2Client.DescribeInstances(&ec2.DescribeInstancesInput{
			InstanceIds: instanceIDs,
			NextToken:   ec2Token,
		})
		if err != nil {
			return err
		}

		for _, reservation := range describeInstancesOut.Reservations {
			for _, instance := range reservation.Instances {
				containerInstanceArn, ok := instanceIDToContainerInstanceArn[aws.StringValue(instance.InstanceId)]
				if !ok || instance.PrivateIpAddress == nil {
					continue
				}
				dst[containerInstanceArn] = *instance.PrivateIpAddress
			}
		}

		ec2Token = describeInstancesOut.NextToken
		if aws.StringValue(ec2Token) == "" {
			return nil
		}
	}
}

// getStatsiteHosts returns one "{host_private_ip}:{udp_host_port}" string for each
// task of conf's service whose task definition ARN equals taskDefArn.
//
// Each matching task must have exactly one container and at least one udp network
// binding. If there are several udp bindings, the last one wins. A task whose
// container instance has no known private IP, or which violates either requirement,
// causes an error rather than producing an empty or partial host entry.
func getStatsiteHosts(ecsClient *ecs.ECS, ec2Client *ec2.EC2, conf *ServiceConfig, taskDefArn string) ([]string, error) {
	allTasks, err := getTasks(ecsClient, conf)
	if err != nil {
		return nil, err
	}

	// Filter to tasks that are running the specified task definition.
	tasks := make([]*ecs.Task, 0, len(allTasks))
	for _, t := range allTasks {
		if aws.StringValue(t.TaskDefinitionArn) == taskDefArn {
			tasks = append(tasks, t)
		}
	}

	containerInstanceIPs, err := getContainerInstanceIPs(ecsClient, ec2Client, conf)
	if err != nil {
		return nil, err
	}

	statsiteHosts := make([]string, 0, len(tasks))
	for _, task := range tasks {
		taskArn := aws.StringValue(task.TaskArn)
		switch {
		case len(task.Containers) == 0:
			return nil, fmt.Errorf("statsite task %s has no containers", taskArn)
		case len(task.Containers) > 1:
			return nil, errors.New("only expected one container per statsite task")
		}

		containerInstanceArn := aws.StringValue(task.ContainerInstanceArn)
		containerIP, ok := containerInstanceIPs[containerInstanceArn]
		if !ok {
			return nil, fmt.Errorf("container instance %s not found in statsite cluster", containerInstanceArn)
		}

		// StatsdProxy sends via udp, even though tcp is also available.
		var fullHost string
		for _, binding := range task.Containers[0].NetworkBindings {
			if aws.StringValue(binding.Protocol) == "udp" {
				fullHost = fmt.Sprintf("%s:%d", containerIP, aws.Int64Value(binding.HostPort))
			}
		}
		if fullHost == "" {
			// An empty entry would end up as a blank address in the proxy's host list.
			return nil, fmt.Errorf("statsite task %s has no udp network binding", taskArn)
		}
		statsiteHosts = append(statsiteHosts, fullHost)
	}

	return statsiteHosts, nil
}
