// Package ecs enables fetching of metadata concerning where an ECS-based application is running.
package ecs

import (
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws/ec2metadata"
)

// Getter is the minimal interface needed to fetch EC2 instance metadata.
//
// GetMetadata uses it only to obtain the instance identity document, from which
// it copies the InstanceID and Region fields. Keeping the dependency this narrow
// lets callers supply any implementation, such as an ec2metadata client
// (presumably *ec2metadata.EC2Metadata — confirm) or a test double.
type Getter interface {
	// GetInstanceIdentityDocument returns an EC2 instance identity document, or
	// an error if it could not be obtained.
	GetInstanceIdentityDocument() (ec2metadata.EC2InstanceIdentityDocument, error)
}

// Metadata contains information about where this application is running within an AWS ECS environment.
// Values of this type are built by GetMetadata.
type Metadata struct {
	// InstanceID is the identifying portion of the container instance's EC2 ARN.
	// GetMetadata copies it from the InstanceID field of the EC2 instance
	// identity document.
	InstanceID string
	// TaskID is the identifying portion of the container's Task ARN, i.e. the
	// text following the ARN's last "/".
	TaskID string
	// Region is the AWS region the instance is running in, as reported by the
	// Region field of the EC2 instance identity document.
	Region string
}

// GetMetadata collects Metadata from the ECS Task Metadata Endpoint and the
// ec2metadata API.
//
// The uri is expected to be the value pulled from the ECS_CONTAINER_METADATA_URI_V4 environment
// variable, which is set by the ECS container agent. It is used to look up the
// task ID. ec2 supplies the instance identity document, from which InstanceID
// and Region are taken; it must not be nil.
//
// Errors from either source are wrapped with context describing which lookup
// failed. The underlying error remains reachable via errors.Is / errors.As.
func GetMetadata(uri string, ec2 Getter) (*Metadata, error) {
	// Fail fast instead of doing a network round-trip and then panicking on a
	// nil interface method call.
	if ec2 == nil {
		return nil, errors.New("expected non-nil EC2 metadata getter")
	}

	taskID, err := getTaskID(uri)
	if err != nil {
		return nil, fmt.Errorf("getting ECS task ID: %w", err)
	}

	doc, err := ec2.GetInstanceIdentityDocument()
	if err != nil {
		return nil, fmt.Errorf("getting EC2 instance identity document: %w", err)
	}

	return &Metadata{
		InstanceID: doc.InstanceID,
		TaskID:     taskID,
		Region:     doc.Region,
	}, nil
}

// getTaskID returns the container's task ID using the Task Metadata Endpoint.
// https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-metadata-endpoint-v4.html
//
// uri is the base endpoint URI. The task document is requested from
// "<uri>/task", with any trailing slashes on uri ignored. The task ID is the
// part of the response's TaskARN after its last "/".
//
// An error is returned if uri is empty, the request fails, the endpoint answers
// with a non-200 status, the body is not valid JSON, or the response does not
// carry a usable TaskARN.
func getTaskID(uri string) (string, error) {
	if uri == "" {
		return "", errors.New("expected non-empty task metadata URI")
	}

	client := http.Client{
		Timeout: 5 * time.Second,
	}

	// Tolerate a trailing slash so we never request "<uri>//task".
	endpoint := strings.TrimRight(uri, "/") + "/task"

	resp, err := client.Get(endpoint)
	if err != nil {
		return "", fmt.Errorf("requesting task metadata: %w", err)
	}
	defer resp.Body.Close()

	// Drain the body before inspecting the status so the underlying connection
	// can be reused by the transport.
	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("reading task metadata response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("unexpected task metadata response: %s", resp.Status)
	}

	var t task
	if err := json.Unmarshal(b, &t); err != nil {
		return "", fmt.Errorf("decoding task metadata: %w", err)
	}

	// A missing/empty TaskARN (or one ending in "/") would otherwise surface as
	// an empty TaskID with a nil error.
	id := t.id()
	if id == "" {
		return "", fmt.Errorf("task metadata has no usable TaskARN: %q", t.ARN)
	}
	return id, nil
}

// task is the subset of the Task Metadata Endpoint's /task response used by
// this package.
type task struct {
	// ARN is the task's full ARN, decoded from the response's TaskARN field.
	ARN string `json:"TaskARN"`
}

// id returns the portion of the task ARN after its last "/". For an ARN such as
// "arn:aws:ecs:region:account:task/cluster/ID" this yields "ID". An ARN with
// no "/" is returned unchanged, and an empty ARN (or one ending in "/") yields "".
func (t task) id() string {
	// LastIndexByte returns -1 when there is no "/", so +1 selects the whole
	// string. Slicing avoids building the intermediate []string that
	// strings.Split would allocate.
	return t.ARN[strings.LastIndexByte(t.ARN, '/')+1:]
}
