package main

import (
	"context"
	"flag"
	"fmt"
	"io"
	"log"
	"os"
	"strings"
	"sync"
	"time"

	"golang.org/x/sync/errgroup"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
	"github.com/aws/aws-sdk-go/service/ecs"
)

func main() {
	var (
		region  = flag.String("region", "us-west-2", "AWS region")
		cluster = flag.String("cluster", "ecs_dev", "ECS cluster name")
		family  = flag.String("family", "xenial_command", "ECS task family")
	)
	flag.Parse()
	argv := flag.Args()

	ctx := context.Background()
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	conf := aws.NewConfig().WithRegion(*region)
	sess := session.Must(session.NewSession(conf))
	ec := ecs.New(sess)
	logs := cloudwatchlogs.New(sess)

	t := &thing{
		cluster: *cluster,
		family:  *family,
		argv:    argv,
	}

	exitCode, err := t.do(ctx, ec, logs)
	if err != nil {
		log.Fatalf("err = %v", err)
	}
	os.Exit(exitCode)
}

// thing holds the parameters of one command invocation: the ECS cluster to run
// on, the task definition family to launch, and the command line used to
// override the "main" container's command.
type thing struct {
	cluster, family string
	argv            []string
}

// do starts one copy of task family t.family on cluster t.cluster, with the
// container named "main" overridden to run t.argv, then follows each started
// task until ECS reports it STOPPED. Events from the task's CloudWatch Logs
// stream "command/main/<task id>" are printed to stdout while it runs, and the
// rest of the stream is drained once it stops.
//
// It returns the largest container exit code across the started tasks. When
// the error is non-nil the returned code is 0 and should be ignored.
func (t *thing) do(ctx context.Context, ec *ecs.ECS, logs *cloudwatchlogs.CloudWatchLogs) (int, error) {
	runResp, err := ec.RunTaskWithContext(ctx, &ecs.RunTaskInput{
		Cluster:        aws.String(t.cluster),
		Count:          aws.Int64(1),
		TaskDefinition: aws.String(t.family),
		Overrides: &ecs.TaskOverride{
			ContainerOverrides: []*ecs.ContainerOverride{
				{
					Name:    aws.String("main"),
					Command: aws.StringSlice(t.argv),
				},
			},
		},
	})
	if err != nil {
		return 0, err
	}
	// RunTask can succeed as a call while still listing per-task failures;
	// treat the first one as fatal before starting any watchers.
	if len(runResp.Failures) > 0 {
		fail := runResp.Failures[0]
		return 0, fmt.Errorf("RunTask failure arn=%q reason=%q", aws.StringValue(fail.Arn), aws.StringValue(fail.Reason))
	}

	var (
		exitCodeMu  sync.Mutex
		maxExitCode int
	)
	eg, egCtx := errgroup.WithContext(ctx)
	for _, task := range runResp.Tasks {
		// Declared inside the loop body so each goroutine gets its own copy.
		taskARN := aws.StringValue(task.TaskArn)
		eg.Go(func() error {
			code, err := t.watchTask(egCtx, ec, logs, taskARN)
			if err != nil {
				return err
			}
			exitCodeMu.Lock()
			if code > maxExitCode {
				maxExitCode = code
			}
			exitCodeMu.Unlock()
			return nil
		})
	}
	if err := eg.Wait(); err != nil {
		return 0, err
	}
	return maxExitCode, nil
}

// watchTask polls taskARN about once a second until ECS reports it STOPPED.
// While the task is past PENDING it prints new log events as they arrive.
// After the task stops it drains the rest of the log stream and then returns
// the task's exit code (see taskExitCode).
func (t *thing) watchTask(ctx context.Context, ec *ecs.ECS, logs *cloudwatchlogs.CloudWatchLogs, taskARN string) (int, error) {
	var taskID string
	if i := strings.LastIndex(taskARN, "/"); i >= 0 {
		taskID = taskARN[i+len("/"):]
	}
	streamName := fmt.Sprintf("command/main/%s", taskID)

	var logToken string
	for {
		taskResp, err := ec.DescribeTasksWithContext(ctx, &ecs.DescribeTasksInput{
			Cluster: aws.String(t.cluster),
			Tasks:   aws.StringSlice([]string{taskARN}),
		})
		if err != nil {
			return 0, err
		}
		// A task that DescribeTasks cannot return would never be observed as
		// STOPPED, so the loop would poll indefinitely; fail fast instead.
		if len(taskResp.Failures) > 0 {
			fail := taskResp.Failures[0]
			return 0, fmt.Errorf("DescribeTasks failure arn=%q reason=%q", aws.StringValue(fail.Arn), aws.StringValue(fail.Reason))
		}
		if len(taskResp.Tasks) == 0 {
			return 0, fmt.Errorf("DescribeTasks returned no task for arn=%q", taskARN)
		}
		task := taskResp.Tasks[0]

		switch aws.StringValue(task.LastStatus) {
		case ecs.DesiredStatusStopped:
			if err := drainLogs(ctx, logs, streamName, logToken); err != nil {
				return 0, err
			}
			return taskExitCode(task)
		case ecs.DesiredStatusPending:
			// Nothing to read yet; just wait for the next poll.
		default:
			if logToken, err = pollLogs(ctx, logs, streamName, logToken); err != nil {
				return 0, err
			}
		}

		if err := sleep(ctx, time.Second); err != nil {
			return 0, err
		}
	}
}

// pollLogs prints the events in streamName that follow token to stdout and
// returns the token to resume from. A ResourceNotFoundException (the stream has
// not been created, presumably because the container has not logged yet) is
// treated as "no new events", and token is returned unchanged.
func pollLogs(ctx context.Context, logs *cloudwatchlogs.CloudWatchLogs, streamName, token string) (string, error) {
	next, err := printLogs(ctx, os.Stdout, logs, streamName, token)
	switch {
	case isResourceNotFound(err):
		return token, nil
	case err != nil:
		return token, err
	}
	return next, nil
}

// drainLogs prints every event remaining in streamName after token. One
// GetLogEvents call returns a single page, so it keeps following the forward
// token until CloudWatch Logs returns the token it was given, which signals
// the end of the stream.
func drainLogs(ctx context.Context, logs *cloudwatchlogs.CloudWatchLogs, streamName, token string) error {
	for {
		next, err := pollLogs(ctx, logs, streamName, token)
		if err != nil {
			return err
		}
		if next == "" || next == token {
			return nil
		}
		token = next
	}
}

// isResourceNotFound reports whether err carries the AWS error code
// ResourceNotFoundException. It matches on the Code method of aws-sdk-go
// errors (awserr.Error) through an anonymous interface, so no extra package is
// needed.
func isResourceNotFound(err error) bool {
	coded, ok := err.(interface{ Code() string })
	return ok && coded.Code() == cloudwatchlogs.ErrCodeResourceNotFoundException
}

// taskExitCode returns the largest exit code among the stopped task's
// containers, with a floor of 0. A container without an exit code (for example
// one that never started) yields an error that includes ECS's stop reasons,
// instead of being counted as a successful 0.
func taskExitCode(task *ecs.Task) (int, error) {
	maxCode := 0
	for _, cont := range task.Containers {
		if cont.ExitCode == nil {
			return 0, fmt.Errorf("task %s stopped without an exit code for container %q: stoppedReason=%q reason=%q",
				aws.StringValue(task.TaskArn), aws.StringValue(cont.Name),
				aws.StringValue(task.StoppedReason), aws.StringValue(cont.Reason))
		}
		if code := int(*cont.ExitCode); code > maxCode {
			maxCode = code
		}
	}
	return maxCode, nil
}

// printLogs fetches one page of events from the CloudWatch Logs stream
// streamName in log group "ecs". It reads from the start of the stream when
// logToken is empty, otherwise from logToken. Each event is written to w as
// "<timestamp, RFC 3339 with milliseconds, local zone> <streamName>: <message>".
//
// It returns the page's NextForwardToken, which the caller passes back as
// logToken on the next call. On an API error the returned token is empty. On a
// write error it is the token that follows the page being written.
func printLogs(ctx context.Context, w io.Writer, logs *cloudwatchlogs.CloudWatchLogs,
	streamName string, logToken string) (string, error) {

	var logReq cloudwatchlogs.GetLogEventsInput
	logReq.SetLogGroupName("ecs")
	logReq.SetLogStreamName(streamName)
	// Always read oldest-first. GetLogEvents requires startFromHead=true when
	// nextToken is a previous nextForwardToken. Without a token, true starts
	// the read at the head of the stream.
	logReq.SetStartFromHead(true)
	if logToken != "" {
		logReq.SetNextToken(logToken)
	}
	logResp, err := logs.GetLogEventsWithContext(ctx, &logReq)
	if err != nil {
		return "", err
	}
	logToken = aws.StringValue(logResp.NextForwardToken)

	const rfc3339milli = "2006-01-02T15:04:05.000Z07:00"
	for _, ev := range logResp.Events {
		// Event timestamps are milliseconds since the Unix epoch.
		nanos := aws.Int64Value(ev.Timestamp) * int64(time.Millisecond/time.Nanosecond)
		ts := time.Unix(0, nanos).Format(rfc3339milli)
		msg := aws.StringValue(ev.Message)
		if _, err := fmt.Fprintf(w, "%s %s: %s\n", ts, streamName, msg); err != nil {
			return logToken, err
		}
	}

	return logToken, nil
}

// sleep blocks for d, or until ctx is done, whichever comes first. It returns
// nil after the full duration, or ctx.Err() if the context ended first.
func sleep(ctx context.Context, d time.Duration) error {
	// Honor the requested duration (previously it was hard-coded to one
	// second). Stop the timer on every return path.
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}
