package main

import (
	"fmt"
	"log"
	"net/http"
	"time"

	"github.com/alexflint/go-arg"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/ecs"
)

type (
	// Args holds the command-line arguments parsed by go-arg in main.
	Args struct {
		EnvName  string `arg:"positional" help:"Target environment: [dev,prod]"`
		FirstSvc string `arg:"-s" default:"pathfinder" help:"First service to restart: [greeter,pathfinder,threshold,source]. Services to the right are also restarted in order."`
	}

	// envConf describes one target environment: the AWS shared-credentials
	// profile used to build the session, the AWS account (informational; main
	// does not read it), and the ECS services to restart keyed by short name
	// ("greeter", "pathfinder", "threshold", "source").
	envConf struct {
		AwsProfile, AwsAccount string
		Services               map[string]svcConf
	}

	// svcConf identifies one ECS service by cluster and service name, and
	// bounds how long main waits for it to come back after its tasks are
	// stopped.
	svcConf struct {
		Cluster, Service string
		RestartTimeout   time.Duration
	}
)

func main() {
	var args Args
	arg.MustParse(&args)
	var conf envConf
	if args.EnvName == "prod" {
		conf = envConf{
			AwsProfile: "twitch-eml-prod",
			AwsAccount: "342135511598",
			Services: map[string]svcConf{
				"greeter": svcConf{
					Cluster:        "ProdGreeterService-ClusterEB0386A7-14F243CYT8VPN",
					Service:        "ProdGreeterService-FargateService7B4DE80D-1GD0WR6GD96YC",
					RestartTimeout: 10 * time.Minute,
				},
				"pathfinder": svcConf{
					Cluster:        "ProdPathfinderService-ClusterEB0386A7-1I254F53RW2S0",
					Service:        "ProdPathfinderService-FargateService7B4DE80D-1X0Z9MHKGZWLY",
					RestartTimeout: 10 * time.Minute,
				},
				"threshold": svcConf{
					Cluster:        "ProdThresholdService-ClusterEB0386A7-28PXMIVABFVD",
					Service:        "ProdThresholdService-EC2ServiceF0CE72D0-NF0L7V6QN2OJ",
					RestartTimeout: 20 * time.Minute,
				},
				"source": svcConf{
					Cluster:        "ProdSourceService-ClusterEB0386A7-M3EP3Q356332",
					Service:        "ProdSourceService-EC2ServiceF0CE72D0-S09TNE632MCM",
					RestartTimeout: 10 * time.Minute,
				},
			},
		}
	} else if args.EnvName == "dev" {
		conf = envConf{
			AwsProfile: "twitch-eml-dev",
			AwsAccount: "565915620853",
			Services: map[string]svcConf{
				"greeter": svcConf{
					Cluster:        "DevGreeterService-ClusterEB0386A7-QKRWL0HD6EC1",
					Service:        "DevGreeterService-FargateService7B4DE80D-OS2NPGSKPIU6",
					RestartTimeout: 10 * time.Minute,
				},
				"pathfinder": svcConf{
					Cluster:        "DevPathfinderService-ClusterEB0386A7-MAPA266127HT",
					Service:        "DevPathfinderService-FargateService7B4DE80D-J4DHKNQEYVGG",
					RestartTimeout: 10 * time.Minute,
				},
				"threshold": svcConf{
					Cluster:        "DevThresholdService-ClusterEB0386A7-dogaofKHJ01C",
					Service:        "DevThresholdService-EC2ServiceF0CE72D0-1QBFMZ101G191",
					RestartTimeout: 20 * time.Minute,
				},
				"source": svcConf{
					Cluster:        "DevSourceService-ClusterEB0386A7-lC6I6ps5FNYy",
					Service:        "DevSourceService-EC2ServiceF0CE72D0-1PC2N4NKVIKGL",
					RestartTimeout: 10 * time.Minute,
				},
			},
		}
	} else {
		log.Fatalf("Invalid envName: %q. Use -h for help", args.EnvName)
	}

	// Validate FirstSvc service name if provided
	if !(args.FirstSvc == "greeter" || args.FirstSvc == "pathfinder" || args.FirstSvc == "threshold" || args.FirstSvc == "source") {
		log.Fatalf("Invalid service name: %q. Use -h for help", args.FirstSvc)
	}

	// Initialize AWS client
	sess := session.Must(session.NewSessionWithOptions(session.Options{
		Config: aws.Config{
			Region:      aws.String("us-west-2"),
			Credentials: credentials.NewSharedCredentials("", conf.AwsProfile),
			HTTPClient:  &http.Client{Timeout: 10 * time.Second},
		},
	}))
	ecsCli := ecs.New(sess)

	// Restart services in order, starting from args.FirstSvc
	found := false
	for _, svc := range []string{"greeter", "pathfinder", "threshold", "source"} {
		if svc != args.FirstSvc && !found {
			continue
		}
		found = true

		svcConf := conf.Services[svc]
		cluster := aws.String(svcConf.Cluster)
		service := aws.String(svcConf.Service)
		log.Printf("Restarting %s (%+v)\n", svc, svcConf)

		// Describe service
		services, err := ecsCli.DescribeServices(&ecs.DescribeServicesInput{
			Cluster:  cluster,
			Services: []*string{service},
		})
		mustNoError(err)
		if len(services.Failures) > 0 {
			log.Fatalf("Failed to describe services for %s: %s", svc, services)
		}
		if len(services.Services) != 1 {
			log.Fatalf("Invalid number of services for %s. Expected 1, found %d.\n%s", svc, len(services.Services), services)
		}
		ecsSvc := services.Services[0]

		// Check if service is ready to restart
		if *ecsSvc.Status != "ACTIVE" {
			log.Fatalf("Error: Service %s status is not ACTIVE.", svc)
		}
		desiredCount := *ecsSvc.DesiredCount
		runningCount := *ecsSvc.RunningCount
		pendingCount := *ecsSvc.PendingCount
		if desiredCount != runningCount || pendingCount != 0 {
			log.Fatalf("Error: Service %s is already being deployed or restarting. Please try again later. {desired: %d, running: %d, pending: %d}", svc, desiredCount, runningCount, pendingCount)
		}

		// We could restart with zero-downtime by using something like `aws ecs update-service --force-new-deployment` like the ecs-restart.sh script,
		// but that is too slow (E2ML services take very long to shut down for some reason), and when we need to restart is because we are down already.
		// Instead of that, force-stop all the tasks on the service, and wait until it comes back.

		tasks, err := ecsCli.ListTasks(&ecs.ListTasksInput{
			Cluster:     cluster,
			ServiceName: service,
		})
		mustNoError(err)
		log.Printf("Stop %d tasks in %s\n", len(tasks.TaskArns), svc)

		for _, taskARN := range tasks.TaskArns {
			_, err := ecsCli.StopTask(&ecs.StopTaskInput{
				Cluster: cluster,
				Task:    taskARN,
				Reason:  aws.String("Hard restart from script"),
			})
			mustNoError(err)
		}
		log.Printf("Wait until %d tasks restart in %s\n", len(tasks.TaskArns), svc)

		// Wait until new tasks are started on this service before moving to the next one
		start := time.Now()
		for {
			// Wait a few seconds ...
			if time.Since(start) > svcConf.RestartTimeout {
				log.Fatalf("Timeout Error: Service %s is taking too long to restart.", svc)
			}
			time.Sleep(10 * time.Second)
			fmt.Printf(".")

			// Check if restart is done
			services, err := ecsCli.DescribeServices(&ecs.DescribeServicesInput{
				Cluster:  cluster,
				Services: []*string{service},
			})
			mustNoError(err)
			ecsSvc = services.Services[0]
			desiredCount := *ecsSvc.DesiredCount
			runningCount := *ecsSvc.RunningCount
			pendingCount := *ecsSvc.PendingCount

			if runningCount == desiredCount && pendingCount == 0 {
				fmt.Printf(" Restarted in %s\n", time.Since(start))
				break // restart done, move to next service
			}
		}
	}

	log.Printf("DONE")
}

// mustNoError terminates the program through log.Fatalf when err is non-nil
// and does nothing otherwise.
func mustNoError(err error) {
	if err == nil {
		return
	}
	log.Fatalf("Error: %v", err)
}
