package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"os"

	payload "code.justin.tv/cb/achievements/validator/model/lambda"
	"github.com/aws/aws-lambda-go/lambda"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/ecs"
)

// HandleRequest is the Lambda entry point. It launches a single Fargate task
// on the ECS cluster named by the CLUSTER_ARN environment variable, using the
// task definition in TASK_DEFINITION, and passes params.QueryName to the
// container named CONTAINER_NAME through the QUERY_NAME environment variable.
//
// The task is placed in the subnets listed (as a JSON array of strings) in
// SUBNET_IDS, and the AWS region is taken from REGION.
//
// It returns an error if the AWS session cannot be created, if SUBNET_IDS
// cannot be parsed, if the RunTask call itself fails, or if ECS reports
// failures for the request in the RunTask response.
func HandleRequest(ctx context.Context, params payload.Payload) error {
	queryName := params.QueryName
	log.Print(queryName)

	// Return the error rather than using session.Must: a panic in a handler
	// gives the caller a much less useful failure than a wrapped error.
	sess, err := session.NewSession(&aws.Config{
		Region: aws.String(os.Getenv("REGION")),
	})
	if err != nil {
		return fmt.Errorf("create AWS session: %w", err)
	}

	subnetIDs, err := parseSubnetIDs(os.Getenv("SUBNET_IDS"))
	if err != nil {
		return fmt.Errorf("parse SUBNET_IDS: %w", err)
	}

	ecsClient := ecs.New(sess)
	// Use the invocation context so the Lambda deadline and cancellation apply
	// to the ECS call (the original passed context.Background()).
	out, err := ecsClient.RunTaskWithContext(ctx, &ecs.RunTaskInput{
		Cluster:    aws.String(os.Getenv("CLUSTER_ARN")),
		Count:      aws.Int64(1),
		LaunchType: aws.String(ecs.LaunchTypeFargate),
		NetworkConfiguration: &ecs.NetworkConfiguration{
			AwsvpcConfiguration: &ecs.AwsVpcConfiguration{
				Subnets: subnetIDs,
			},
		},
		Overrides: &ecs.TaskOverride{
			ContainerOverrides: []*ecs.ContainerOverride{
				{
					Name: aws.String(os.Getenv("CONTAINER_NAME")),
					Environment: []*ecs.KeyValuePair{
						{
							Name:  aws.String("QUERY_NAME"),
							Value: aws.String(queryName),
						},
					},
				},
			},
		},
		TaskDefinition: aws.String(os.Getenv("TASK_DEFINITION")),
	})
	if err != nil {
		return fmt.Errorf("failed to run task: %w", err)
	}

	// RunTask can return successfully while still reporting that the task
	// could not be started; those problems are listed in Failures and would
	// otherwise be silently ignored.
	if len(out.Failures) > 0 {
		first := out.Failures[0]
		return fmt.Errorf("run task: ECS reported %d failure(s); first: reason=%q arn=%q",
			len(out.Failures), aws.StringValue(first.Reason), aws.StringValue(first.Arn))
	}

	return nil
}

// parseSubnetIDs decodes ids, a JSON array of subnet ID strings such as
// `["subnet-aaa","subnet-bbb"]`, into the []*string form used by the AWS SDK
// request structs.
//
// It returns an error when:
//   - ids is not valid JSON for a []string;
//   - the array is empty or JSON null;
//   - any entry is the empty string.
//
// A task using awsvpc networking needs at least one subnet, so failing here
// gives a clearer error than letting the ECS API reject the request later.
func parseSubnetIDs(ids string) ([]*string, error) {
	var subnetIDs []string
	if err := json.Unmarshal([]byte(ids), &subnetIDs); err != nil {
		return nil, err
	}
	// JSON `null` decodes to a nil slice without error, so this also covers it.
	if len(subnetIDs) == 0 {
		return nil, fmt.Errorf("no subnet IDs provided")
	}

	awsStrings := make([]*string, len(subnetIDs))
	for i := range subnetIDs {
		if subnetIDs[i] == "" {
			return nil, fmt.Errorf("subnet ID at index %d is empty", i)
		}
		// Pointing into the decoded slice is equivalent to aws.String(id):
		// the slice is local and never modified afterwards, so every pointer
		// refers to a distinct, stable value.
		awsStrings[i] = &subnetIDs[i]
	}

	return awsStrings, nil
}

func main() {
	lambda.Start(HandleRequest)
}
