package main

import (
	"fmt"
	"log"
	"os"

	"strings"

	"time"

	"code.justin.tv/feeds/errors"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/ecs"
	"github.com/aws/aws-sdk-go/service/sts"
)

// Deployment environments accepted in the ENVIRONMENT variable. loadConfig
// lower-cases ENVIRONMENT before comparing it with these values, and the
// selected value is embedded in the names of the ECS cluster, task family,
// IAM task role and log group that this tool deploys to.
const (
	// EnvIntegration selects the integration deployment.
	EnvIntegration = "integration"

	// EnvStaging selects the staging deployment.
	EnvStaging = "staging"

	// EnvCanary selects the canary deployment.
	EnvCanary = "canary"

	// EnvProduction selects the production deployment.
	EnvProduction = "production"
)

// main deploys Gea to ECS. It reads its settings from environment variables
// (see loadConfig) and registers a new task definition that points at the
// Docker image for GIT_COMMIT. It then updates Gea's ECS service to use that
// task definition and waits for the service to become stable. Any failure is
// logged and terminates the process through log.Fatal/log.Fatalf.
func main() {
	// Named conf rather than config so the config type is not shadowed.
	conf, err := loadConfig()
	if err != nil {
		log.Fatal(err)
	}
	log.Println("Config")
	log.Printf("Environment: %s", conf.env)
	log.Printf("Git Commit: %s", conf.gitCommit)

	// Create an ECS client that acts as Gea's build role.
	ecsClient, err := newEcsClient(conf)
	if err != nil {
		log.Fatalf("Creating ECS client failed: %v", err)
	}

	// Register a new task definition that points to the new docker image.
	log.Println("Registering task definition.")
	taskDefOutput, err := registerTaskDefinition(ecsClient, conf)
	if err != nil {
		log.Fatalf("Registering task definition failed: %v", err)
	}
	log.Println("Registering task definition succeeded.")
	log.Println(taskDefOutput.String())

	// Update Gea's ECS service to use the new task definition.
	log.Println("Updating ECS service.")
	updateServiceOutput, err := updateService(ecsClient, conf, taskDefOutput.TaskDefinition.TaskDefinitionArn)
	if err != nil {
		log.Fatalf("Updating ECS service failed: %v", err)
	}
	log.Println("Updating ECS service succeeded.")
	log.Println(updateServiceOutput.String())
}

// config holds the deployment settings that loadConfig reads from the
// environment.
type config struct {
	// env is the lower-cased ENVIRONMENT value, validated against the Env* constants.
	env string

	// gitCommit is the GIT_COMMIT value; it is used as the Docker image tag and
	// passed into the container's environment.
	gitCommit string

	// awsConfigFile and awsProfile (AWS_CONFIG_FILE, AWS_PROFILE) locate the
	// IAM user credentials that are used to assume the build role.
	awsConfigFile string
	awsProfile    string

	// awsAssumeRoleARN (AWS_ASSUME_ROLE_ARN) is the role assumed to talk to ECS.
	awsAssumeRoleARN string
}

// loadConfig reads the deployment settings from environment variables.
//
// ENVIRONMENT is matched case-insensitively and must be one of EnvIntegration,
// EnvStaging, EnvCanary or EnvProduction. GIT_COMMIT, AWS_CONFIG_FILE,
// AWS_PROFILE and AWS_ASSUME_ROLE_ARN must each be set to a non-empty value.
// The first problem found is returned as an error, with a nil config.
func loadConfig() (*config, error) {
	env := strings.ToLower(os.Getenv("ENVIRONMENT"))
	switch env {
	case EnvIntegration, EnvStaging, EnvCanary, EnvProduction:
		// Known environment.
	default:
		// Include the (lower-cased) value actually seen so a typo in the job
		// configuration is obvious from the log.
		return nil, errors.Errorf("ENVIRONMENT env variable was not set to %s, %s, %s, or %s (got %q)",
			EnvIntegration, EnvStaging, EnvCanary, EnvProduction, env)
	}

	conf := &config{env: env}

	// Required variables, checked in this order, and the field each one fills.
	required := []struct {
		key string
		dst *string
	}{
		{"GIT_COMMIT", &conf.gitCommit},
		{"AWS_CONFIG_FILE", &conf.awsConfigFile},
		{"AWS_PROFILE", &conf.awsProfile},
		{"AWS_ASSUME_ROLE_ARN", &conf.awsAssumeRoleARN},
	}
	for _, r := range required {
		value, err := loadEnvVariable(r.key)
		if err != nil {
			return nil, err
		}
		*r.dst = value
	}
	return conf, nil
}

// loadEnvVariable returns the value of the environment variable named key.
// It returns an error when the variable is unset or set to the empty string;
// the two cases are not distinguished.
func loadEnvVariable(key string) (string, error) {
	if v := os.Getenv(key); v != "" {
		return v, nil
	}
	return "", errors.Errorf("%s env variable was not set", key)
}

// newEcsClient builds an ECS client for us-west-2 that acts as the IAM role
// conf.awsAssumeRoleARN.
//
// The base session authenticates with the shared credentials for profile
// conf.awsProfile in the file conf.awsConfigFile. That session is used to
// build the STS client that assumes the role. The returned ECS client signs
// its requests with the assumed-role credentials.
func newEcsClient(conf *config) (*ecs.ECS, error) {
	// The file and profile are passed explicitly; otherwise the SDK would pick
	// up the default credentials provided by the Jenkins environment.
	// NOTE(review): NewSharedCredentials parses credentials-file style sections
	// ("[name]"), not config-file style ("[profile name]"); confirm the file
	// named by AWS_CONFIG_FILE uses the former layout.
	userCreds := credentials.NewSharedCredentials(conf.awsConfigFile, conf.awsProfile)
	sess, err := session.NewSession(
		aws.NewConfig().WithRegion("us-west-2").WithCredentials(userCreds),
	)
	if err != nil {
		return nil, err
	}

	// Assume Gea's build role, which has permission to update ECS. The 10s
	// ExpiryWindow has the SDK refresh the role credentials shortly before
	// they actually expire.
	roleProvider := &stscreds.AssumeRoleProvider{
		Client:       sts.New(sess),
		RoleARN:      conf.awsAssumeRoleARN,
		ExpiryWindow: 10 * time.Second,
	}
	roleConfig := aws.NewConfig().WithCredentials(credentials.NewCredentials(roleProvider))

	return ecs.New(sess, roleConfig), nil
}

// registerTaskDefinition registers a new revision of the task definition
// family "events-gea-<env>". Its single essential container, "gea", runs the
// image docker-registry.internal.justin.tv/twitch-events/gea:<gitCommit>.
//
// The container:
//   - receives ENVIRONMENT and GIT_COMMIT as environment variables;
//   - sends its logs through the awslogs driver to the group
//     "events-<env>-container-logs" in us-west-2, with stream prefix <env>;
//   - exposes container ports 8080 and 6060 over TCP.
//
// The task uses bridge networking and the task role events-gea-<env> in the
// AWS account that hosts env (see awsAccountIDForEnv). The request is logged
// before it is sent. An unknown environment yields an error without
// contacting ECS.
func registerTaskDefinition(ecsClient *ecs.ECS, conf *config) (*ecs.RegisterTaskDefinitionOutput, error) {
	awsAccountID, err := awsAccountIDForEnv(conf.env)
	if err != nil {
		return nil, err
	}

	taskRoleArn := fmt.Sprintf("arn:aws:iam::%s:role/events-gea-%s", awsAccountID, conf.env)
	family := fmt.Sprintf("events-gea-%s", conf.env)
	awsLogsGroup := fmt.Sprintf("events-%s-container-logs", conf.env)
	awsLogsStreamPrefix := conf.env
	image := fmt.Sprintf("docker-registry.internal.justin.tv/twitch-events/gea:%s", conf.gitCommit)

	input := &ecs.RegisterTaskDefinitionInput{
		ContainerDefinitions: []*ecs.ContainerDefinition{
			{
				Cpu: aws.Int64(1024),
				Environment: []*ecs.KeyValuePair{
					{
						Name:  aws.String("ENVIRONMENT"),
						Value: aws.String(conf.env),
					}, {
						Name:  aws.String("GIT_COMMIT"),
						Value: aws.String(conf.gitCommit),
					},
				},
				Essential: aws.Bool(true),
				Image:     aws.String(image),
				LogConfiguration: &ecs.LogConfiguration{
					LogDriver: aws.String("awslogs"),
					Options: map[string]*string{
						"awslogs-group":         aws.String(awsLogsGroup),
						"awslogs-region":        aws.String("us-west-2"),
						"awslogs-stream-prefix": aws.String(awsLogsStreamPrefix),
					},
				},
				MemoryReservation: aws.Int64(512),
				Name:              aws.String("gea"),
				// HostPort 0: with bridge networking ECS assigns the host
				// port dynamically.
				PortMappings: []*ecs.PortMapping{
					{
						ContainerPort: aws.Int64(8080),
						HostPort:      aws.Int64(0),
						Protocol:      aws.String("tcp"),
					}, {
						ContainerPort: aws.Int64(6060),
						HostPort:      aws.Int64(0),
						Protocol:      aws.String("tcp"),
					},
				},
				Ulimits: []*ecs.Ulimit{
					{
						HardLimit: aws.Int64(10000),
						Name:      aws.String("nofile"),
						SoftLimit: aws.Int64(10000),
					},
				},
			},
		},
		Family:      aws.String(family),
		NetworkMode: aws.String("bridge"),
		TaskRoleArn: aws.String(taskRoleArn),
	}

	log.Println("RegisterTaskDefinitionInput")
	log.Println(input.String())

	return ecsClient.RegisterTaskDefinition(input)
}

// awsAccountIDForEnv returns the ID of the AWS account Gea is deployed to in
// env. Production and canary share one account; staging and integration
// share another. Any other value is an error.
func awsAccountIDForEnv(env string) (string, error) {
	switch env {
	case EnvProduction, EnvCanary:
		return "914569885343", nil
	case EnvStaging, EnvIntegration:
		return "724951484461", nil
	}
	return "", fmt.Errorf("could not find params for env %q", env)
}

// updateService switches the "gea" ECS service in cluster "events-<env>-common"
// to taskDefinitionArn and forces a new deployment. It then blocks in
// WaitUntilServicesStable until ECS reports the service stable or the wait
// fails.
//
// On success it returns the UpdateService response. On failure it returns a
// nil output and an error that names the step that failed. This separates a
// rejected update request from an update that was accepted but whose rollout
// never stabilized.
func updateService(ecsClient *ecs.ECS, conf *config, taskDefinitionArn *string) (*ecs.UpdateServiceOutput, error) {
	cluster := fmt.Sprintf("events-%s-common", conf.env)
	service := "gea"

	// Errors are formatted with %v rather than %w: the caller (main) only logs
	// them, and %v keeps the message intact on pre-Go 1.13 toolchains.
	out, err := ecsClient.UpdateService(&ecs.UpdateServiceInput{
		Cluster:            aws.String(cluster),
		Service:            aws.String(service),
		TaskDefinition:     taskDefinitionArn,
		ForceNewDeployment: aws.Bool(true),
	})
	if err != nil {
		return nil, fmt.Errorf("updating ECS service %s in cluster %s: %v", service, cluster, err)
	}

	log.Println("Waiting on ECS service...")
	err = ecsClient.WaitUntilServicesStable(&ecs.DescribeServicesInput{
		Cluster:  aws.String(cluster),
		Services: []*string{aws.String(service)},
	})
	if err != nil {
		// UpdateService already succeeded, so this failure concerns the
		// rollout itself rather than the request.
		return nil, fmt.Errorf("ECS service %s in cluster %s did not become stable after update: %v", service, cluster, err)
	}
	return out, nil
}
