package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"sort"
	"strings"
	"text/template"

	"context"
	"regexp"
	"time"

	"code.justin.tv/feeds/errors"
	"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
	"github.com/aws/aws-sdk-go/service/ecs"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/hashicorp/terraform/terraform"
	"github.com/urfave/cli"
)

// taskRegex captures the task ID from "(task <id>)" fragments such as those in
// ECS service event messages. Only lowercase letters, digits and '-' are
// accepted as ID characters.
var taskRegex = regexp.MustCompile(`\(task ([a-z0-9-]+)\)`)

// findTasksInMessage returns every task ID referenced as "(task <id>)" in msg,
// in order of appearance. When nothing matches it returns an empty, non-nil
// slice.
func findTasksInMessage(msg string) []string {
	matches := taskRegex.FindAllStringSubmatch(msg, -1)
	ids := make([]string, len(matches))
	for i, m := range matches {
		// m[0] is the whole "(task ...)" fragment; m[1] is the captured ID.
		ids[i] = m[1]
	}
	return ids
}

// getBody reads the whole response body and returns it as a string. A read
// failure is not returned as an error; it is rendered as "err: <message>" so
// the result can always be embedded in a diagnostic message. The body is not
// closed here; that stays the caller's responsibility.
func getBody(resp *http.Response) string {
	raw, readErr := ioutil.ReadAll(resp.Body)
	if readErr == nil {
		return string(raw)
	}
	return fmt.Sprintf("err: %s", readErr.Error())
}

// errExit adapts a CLI action so that any error it returns is converted into
// a cli exit error with status code 1 and the error's message. A nil result
// passes through unchanged.
func errExit(f func(c *cli.Context) error) func(c *cli.Context) error {
	wrapped := func(cliCtx *cli.Context) error {
		if err := f(cliCtx); err != nil {
			return cli.NewExitError(err.Error(), 1)
		}
		return nil
	}
	return wrapped
}

// envToIndex gives the promotion order of deployment environments, from
// "latest" (1) to "production" (5). Environments not listed here look up as 0
// and therefore sort ahead of every known environment.
var envToIndex = map[string]int{
	"latest":      1,
	"integration": 2,
	"staging":     3,
	"canary":      4,
	"production":  5,
}

// sortEnvs sorts envs in place by their envToIndex rank and returns the same
// slice. It accepts a (slice, error) pair so that a call returning
// ([]string, error) can be passed straight through. When err is non-nil, envs
// is returned untouched together with err.
//
// The sort is stable, so environments with equal rank (in particular all
// unknown ones) keep their input order and the result is deterministic.
func sortEnvs(envs []string, err error) ([]string, error) {
	if err != nil {
		return envs, err
	}
	sort.SliceStable(envs, func(i, j int) bool {
		return envToIndex[envs[i]] < envToIndex[envs[j]]
	})
	return envs, nil
}

// uniq returns the distinct strings of in, keeping the first occurrence of
// each and preserving input order. The result is always a non-nil slice.
func uniq(in []string) []string {
	seen := make(map[string]bool, len(in))
	out := make([]string, 0, len(in))
	for _, item := range in {
		if seen[item] {
			continue
		}
		seen[item] = true
		out = append(out, item)
	}
	return out
}

// trimmeKeys strips one trailing "/" from every key, then removes the first
// prefix in prefix (checked in order) that actually shortens it. If removing
// that prefix leaves an empty string, the key is dropped entirely. A key that
// no prefix matches is kept as-is, minus its trailing slash. An empty prefix
// never counts as a match.
func trimmeKeys(prefix []string, keys []string) []string {
	out := make([]string, 0, len(keys))
nextKey:
	for _, raw := range keys {
		key := strings.TrimSuffix(raw, "/")
		for _, p := range prefix {
			rest := strings.TrimPrefix(key, p)
			if rest == key {
				continue
			}
			if rest != "" {
				out = append(out, rest)
			}
			continue nextKey
		}
		out = append(out, key)
	}
	return out
}

// dockerImageExists checks whether the image team/service:tag exists. It
// sends a HEAD request to <docker>/artifactory/docker/<team>/<service>/<tag>,
// where <docker> comes from the global "docker" flag, falling back to the
// command-level "docker" flag when the global one is empty.
//
// It returns nil on 200 and a "MISSING image" error on 404. Any other status
// or a transport failure is returned as an error.
func dockerImageExists(c *cli.Context, team string, service string, tag string) error {
	// Bound the request so an unresponsive registry cannot hang the command.
	const requestTimeout = 30 * time.Second

	dockerHost := c.GlobalString("docker")
	if dockerHost == "" {
		dockerHost = c.String("docker")
	}
	imageURL := fmt.Sprintf("%s/artifactory/docker/%s/%s/%s", dockerHost, team, service, tag)
	req, err := http.NewRequest(http.MethodHead, imageURL, nil)
	if err != nil {
		return err
	}
	client := http.Client{Timeout: requestTimeout}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	// Always close the body, otherwise the underlying connection leaks.
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusOK:
		return nil
	case http.StatusNotFound:
		return errors.Errorf("MISSING image %s/%s:%s", team, service, tag)
	}
	return errors.Errorf("Unexpected status code %d", resp.StatusCode)
}

// createECSTask renders the deployment's task-definition input (see
// createTaskDefinitionInput) and registers it with ECS. The ECS client uses
// the deployment's region, role and profile. It returns the TaskDefinition
// from the registration response.
func createECSTask(c *cli.Context, deploymentInfo *deploymentInfo, awsapi *awsapi, logger *logger, templateInfo map[string]string) (*ecs.TaskDefinition, error) {
	taskInput, err := createTaskDefinitionInput(c, deploymentInfo, awsapi, logger, templateInfo)
	if err != nil {
		return nil, err
	}
	logger.verbose(c, "task input created", taskInput)

	ecsClient, err := awsapi.getECSClient(deploymentInfo.Region, deploymentInfo.getRole(c), deploymentInfo.Profile)
	if err != nil {
		return nil, err
	}
	req, registered := ecsClient.RegisterTaskDefinitionRequest(taskInput)
	if sendErr := req.Send(); sendErr != nil {
		return nil, sendErr
	}
	return registered.TaskDefinition, nil
}

// createTaskDefinitionInput builds an ECS RegisterTaskDefinitionInput in three
// steps:
//  1. Choose the template: deploymentInfo.TaskTemplate, or
//     defaultTaskTemplateText when that is empty.
//  2. Execute it with text/template. The parameters are templateInfo merged
//     with the deployment's terraform state values; on key collisions the
//     terraform values win.
//  3. Decode the rendered output as JSON into the input struct.
func createTaskDefinitionInput(c *cli.Context, deploymentInfo *deploymentInfo, awsapi *awsapi, logger *logger, templateInfo map[string]string) (*ecs.RegisterTaskDefinitionInput, error) {
	containerTemplate := deploymentInfo.TaskTemplate
	if containerTemplate == "" {
		containerTemplate = defaultTaskTemplateText
	}

	currentState, err := deploymentInfo.readTerraformState(c, awsapi)
	if err != nil {
		return nil, err
	}

	// Terraform state is copied last so it overrides caller-supplied values.
	params := make(map[string]interface{})
	for key, val := range templateInfo {
		params[key] = val
	}
	for key, val := range currentState {
		params[key] = val
	}

	tmpl, err := template.New("containers").Parse(containerTemplate)
	if err != nil {
		return nil, errors.Wrapf(err, "Unable to parse containers template %s", containerTemplate)
	}
	var rendered bytes.Buffer
	if execErr := tmpl.Execute(&rendered, params); execErr != nil {
		return nil, errors.Wrap(execErr, "Unable to process template")
	}
	logger.verbose(c, "template created", rendered.String(), params)

	var input ecs.RegisterTaskDefinitionInput
	if decodeErr := json.NewDecoder(&rendered).Decode(&input); decodeErr != nil {
		return nil, errors.Wrap(decodeErr, "Unable to read container template file as containers")
	}
	return &input, nil
}

// readTerraformState downloads the terraform state object at s3://bucket/key
// and returns the root module's outputs as strings. Outputs whose value is not
// a string (lists, maps, numbers) map to "".
//
// It returns an error when the object cannot be fetched or decoded, or when
// the decoded state contains no modules. Without that check RootModule would
// panic.
func readTerraformState(s3client *s3.S3, bucket string, key string) (map[string]string, error) {
	out, err := s3client.GetObject(&s3.GetObjectInput{
		Bucket: &bucket,
		Key:    &key,
	})
	if err != nil {
		return nil, err
	}
	// The S3 body is a live network stream and must be closed.
	defer out.Body.Close()

	var state terraform.State
	if err := json.NewDecoder(out.Body).Decode(&state); err != nil {
		return nil, err
	}
	if len(state.Modules) == 0 {
		// NOTE(review): this is likely a state file in a format without a
		// "modules" list, which this decoder does not understand.
		return nil, errors.Errorf("terraform state s3://%s/%s has no modules", bucket, key)
	}

	outputs := state.RootModule().Outputs
	terraformOutputs := make(map[string]string, len(outputs))
	for k, v := range outputs {
		if v == nil {
			terraformOutputs[k] = ""
			continue
		}
		// Non-string outputs deliberately collapse to "".
		s, _ := v.Value.(string)
		terraformOutputs[k] = s
	}
	return terraformOutputs, nil
}

// deployTask points the deployment's ECS service at the task definition
// identified by task (passed as UpdateService's TaskDefinition). The ECS
// client uses the deployment's region, role and profile.
func deployTask(c *cli.Context, task string, awsapi *awsapi, deploymentInfo *deploymentInfo) error {
	ecsClient, err := awsapi.getECSClient(deploymentInfo.Region, deploymentInfo.getRole(c), deploymentInfo.Profile)
	if err != nil {
		return err
	}
	cluster, service := deploymentInfo.ClusterName, deploymentInfo.ServiceName
	req, _ := ecsClient.UpdateServiceRequest(&ecs.UpdateServiceInput{
		Cluster:        &cluster,
		Service:        &service,
		TaskDefinition: &task,
	})
	return req.Send()
}

// cleanUpTasks runs the project's CleanupTasks helper against taskFamily with
// TasksToKeep 20 and Dryrun false. The ECS client uses the deployment's
// region, role and profile. Presumably this deregisters all but the newest 20
// task definitions of the family; confirm against CleanupTasks.Run.
//
// On success it returns whatever Run reports as removed. On error it returns
// nil and the error.
func cleanUpTasks(c *cli.Context, awsapi *awsapi, deploymentInfo *deploymentInfo, taskFamily string) ([]string, error) {
	ecsClient, err := awsapi.getECSClient(deploymentInfo.Region, deploymentInfo.getRole(c), deploymentInfo.Profile)
	if err != nil {
		return nil, err
	}
	cleaner := CleanupTasks{
		ECSClient:   ecsClient,
		TasksToKeep: 20,
		Dryrun:      false,
	}
	removed, err := cleaner.Run(taskFamily)
	if err != nil {
		// Discard any partial result so callers only see removals on success.
		return nil, err
	}
	return removed, nil
}

// parseAllTasks collects the task IDs mentioned as "(task <id>)" in the
// service's events, in event order. The list is capped at 100 entries because
// ECS DescribeTasks accepts at most 100 tasks per call. Events without a
// message are skipped. IDs are not de-duplicated.
func parseAllTasks(service *ecs.Service) []*string {
	const maxTasks = 100
	allTasks := make([]*string, 0, maxTasks)
	for _, event := range service.Events {
		if event == nil || event.Message == nil {
			continue
		}
		for _, task := range findTasksInMessage(*event.Message) {
			// Copy before taking the address: before Go 1.22 the range
			// variable is shared, so &task would make every pointer alias
			// the last ID of this event.
			task := task
			allTasks = append(allTasks, &task)
		}
		if len(allTasks) > maxTasks {
			allTasks = allTasks[:maxTasks]
			break
		}
	}
	return allTasks
}

// findATask polls the service every 5 seconds until one of the tasks
// mentioned in its recent events (see parseAllTasks) runs taskDefinition, and
// returns that task. It returns early with the error if any ECS call fails, or
// with ctx.Err() once ctx is done.
//
// Every iteration waits before retrying, including when the service lookup
// returned an unexpected number of services, so the API is never polled in a
// tight loop.
func findATask(ctx context.Context, ecsClient *ecs.ECS, clusterName string, serviceName string, taskDefinition *ecs.TaskDefinition) (*ecs.Task, error) {
	const pollInterval = 5 * time.Second
	for {
		out, err := ecsClient.DescribeServicesWithContext(ctx, &ecs.DescribeServicesInput{
			Cluster:  &clusterName,
			Services: []*string{&serviceName},
		})
		if err != nil {
			return nil, err
		}
		if len(out.Services) == 1 {
			// Skip DescribeTasks when no IDs were found: the call needs at
			// least one task and would only produce an error.
			if allTasks := parseAllTasks(out.Services[0]); len(allTasks) > 0 {
				out2, err := ecsClient.DescribeTasksWithContext(ctx, &ecs.DescribeTasksInput{
					Cluster: &clusterName,
					Tasks:   allTasks,
				})
				if err != nil {
					return nil, err
				}
				for _, task := range out2.Tasks {
					if task.TaskDefinitionArn != nil && *task.TaskDefinitionArn == *taskDefinition.TaskDefinitionArn {
						return task, nil
					}
				}
			}
		}
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-time.After(pollInterval):
		}
	}
}

// taskInformation is the value passed to runningTaskTemplate in
// backgroundLogTaskInformation. It wraps one ECS task plus the clients its
// methods need. The template (not visible here) presumably references the
// methods TaskARN, Containers, Status and ContainerLogs, and possibly promoted
// *ecs.Task fields, so check it before renaming or removing anything.
type taskInformation struct {
	// Task is embedded so its fields and methods are promoted onto this type.
	*ecs.Task
	// ecsClient is used by ContainerLogs to describe the task definition.
	ecsClient *ecs.ECS
	// logClient is used by ContainerLogs to fetch CloudWatch log events.
	logClient *cloudwatchlogs.CloudWatchLogs
}

// TaskARN returns the ARN of the wrapped task. It dereferences TaskArn
// without a nil check, so the task must carry an ARN.
func (t *taskInformation) TaskARN() string {
	arn := t.TaskArn
	return *arn
}

// Containers returns the container list of the wrapped task. It must read
// t.Task.Containers explicitly: inside this type, t.Containers resolves to
// this method, which shadows the promoted *ecs.Task field of the same name.
func (t *taskInformation) Containers() []*ecs.Container {
	return t.Task.Containers
}

// Status returns the task's stop reason once the task has both a StoppedAt
// time and a non-empty StoppedReason. Otherwise it returns "Running".
func (t *taskInformation) Status() string {
	// Both fields are optional pointers in the SDK struct, so each is
	// nil-checked independently before StoppedReason is dereferenced.
	stopped := t.Task.StoppedAt != nil &&
		t.Task.StoppedReason != nil &&
		*t.Task.StoppedReason != ""
	if !stopped {
		return "Running"
	}
	return *t.Task.StoppedReason
}

// keyVal returns the string that in[key] points to. It returns "" when the
// key is absent, when its pointer is nil, or when the map itself is nil.
func keyVal(in map[string]*string, key string) string {
	if val, ok := in[key]; ok && val != nil {
		return *val
	}
	return ""
}

// taskID extracts the task ID (the final "/"-separated segment) from an ECS
// task ARN, or returns s unchanged when it contains no "/". Both ARN formats
// are handled:
//
//	arn:aws:ecs:<region>:<account>:task/<id>            (old format)
//	arn:aws:ecs:<region>:<account>:task/<cluster>/<id>  (new format)
//
// The ID is what awslogs uses as the last component of a log stream name.
func taskID(s string) string {
	if i := strings.LastIndex(s, "/"); i >= 0 {
		return s[i+1:]
	}
	return s
}

// combineMessageOutput concatenates the messages of out.Events in order with
// no separator. It stops after the first event that pushes the total past 500
// bytes; that event is still included in full, so the result can exceed 500
// bytes. Every event is assumed to have a non-nil Message.
func combineMessageOutput(out *cloudwatchlogs.GetLogEventsOutput) string {
	var sb strings.Builder
	for i := range out.Events {
		sb.WriteString(*out.Events[i].Message)
		if sb.Len() > 500 {
			break
		}
	}
	return sb.String()
}

// ContainerLogs returns CloudWatch log output for the container at position
// index in the task's task definition, as produced by combineMessageOutput.
// It never fails. Every problem is rendered as a human-readable message
// instead, because the result is meant for display. Problems include:
//   - an API error;
//   - a missing or non-awslogs log configuration;
//   - an out-of-range index;
//   - a log group in a region other than the ECS client's.
func (t *taskInformation) ContainerLogs(index int) string {
	def, err := t.ecsClient.DescribeTaskDefinition(&ecs.DescribeTaskDefinitionInput{
		TaskDefinition: t.Task.TaskDefinitionArn,
	})
	if err != nil {
		return fmt.Sprintf("<ERROR getting task definition: %s>", err.Error())
	}
	if def.TaskDefinition == nil {
		return "<ERROR getting task definition: empty response>"
	}
	containerDefs := def.TaskDefinition.ContainerDefinitions
	if index < 0 {
		return fmt.Sprintf("<invalid container index: %d>", index)
	}
	if len(containerDefs) <= index {
		return fmt.Sprintf("<container len too small: %d>", len(containerDefs))
	}
	containerDef := containerDefs[index]
	logConfig := containerDef.LogConfiguration
	if logConfig == nil || logConfig.LogDriver == nil {
		return "<No logs>"
	}
	if *logConfig.LogDriver != "awslogs" {
		return fmt.Sprintf("Unable to parse log type %s", *logConfig.LogDriver)
	}
	region := keyVal(logConfig.Options, "awslogs-region")
	group := keyVal(logConfig.Options, "awslogs-group")
	prefix := keyVal(logConfig.Options, "awslogs-stream-prefix")
	if region == "" || group == "" {
		return fmt.Sprintf("Invalid config <%s> <%s>", region, group)
	}
	// The log client is assumed to share the ECS client's region, so a log
	// group in any other region cannot be read with it.
	if clientRegion := t.ecsClient.Config.Region; clientRegion == nil || region != *clientRegion {
		return fmt.Sprintf("Unable to get config out of another region: %s", region)
	}
	// awslogs names streams <prefix>/<container name>/<task ID>.
	streamName := fmt.Sprintf("%s/%s/%s", prefix, *containerDef.Name, taskID(*t.Task.TaskArn))
	out, err := t.logClient.GetLogEvents(&cloudwatchlogs.GetLogEventsInput{
		LogGroupName:  &group,
		LogStreamName: &streamName,
	})
	if err != nil {
		if strings.Contains(err.Error(), "AccessDeniedException") {
			return fmt.Sprintf("<Permission denied to get awslog output for %s>", streamName)
		}
		if strings.Contains(err.Error(), "ResourceNotFoundException") {
			return fmt.Sprintf("<Log stream %s not yet created>", streamName)
		}
		return fmt.Sprintf("Unable to get log events for %s: %s", streamName, err.Error())
	}
	return combineMessageOutput(out)
}

// backgroundLogTaskInformation repeatedly renders runningTaskTemplate for task
// and writes the result to c.App.Writer whenever it differs from the previous
// rendering. It refreshes the task from ECS every 20 seconds.
//
// It returns:
//   - an error once the task has a non-empty StoppedReason;
//   - the template execution error, if rendering fails;
//   - the describe error, or an error if ECS does not return exactly one task;
//   - ctx.Err() when ctx is cancelled.
//
// The refresh call uses ctx, so cancellation also aborts an in-flight request.
func backgroundLogTaskInformation(ctx context.Context, c *cli.Context, task *ecs.Task, ecsClient *ecs.ECS, logClient *cloudwatchlogs.CloudWatchLogs) error {
	prevOutput := ""
	for {
		ti := taskInformation{
			Task:      task,
			ecsClient: ecsClient,
			logClient: logClient,
		}
		buf := bytes.Buffer{}
		if err := runningTaskTemplate.Execute(&buf, &ti); err != nil {
			return err
		}
		output := buf.String()
		if output != prevOutput {
			fmt.Fprintln(c.App.Writer, output)
		}
		prevOutput = output
		if task.StoppedReason != nil && *task.StoppedReason != "" {
			return errors.New("Task is invalid")
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(time.Second * 20):
		}
		taskOut, err := ecsClient.DescribeTasksWithContext(ctx, &ecs.DescribeTasksInput{
			Cluster: task.ClusterArn,
			Tasks:   []*string{task.TaskArn},
		})
		if err != nil {
			return err
		}
		if len(taskOut.Tasks) != 1 {
			return errors.New("Did not find the right number of tasks back out")
		}
		task = taskOut.Tasks[0]
	}
}

// blockTillStable waits until the deployment's ECS service is stable
// (WaitUntilServicesStableWithContext) and returns that wait's error.
//
// While waiting, a background goroutine looks up a task running
// taskDefinition (via findATask) and streams its status and logs to
// c.App.Writer. Lookup failures are best-effort: the goroutine retries every
// second until the wait finishes.
//
// Before returning, the goroutine is cancelled and waited for, so nothing is
// written after this function returns. Waiting may take as long as one
// in-flight API call made from the log template, since those calls do not
// observe the context.
func blockTillStable(c *cli.Context, awsapi *awsapi, deploymentInfo *deploymentInfo, taskDefinition *ecs.TaskDefinition) error {
	roleToUse := deploymentInfo.getRole(c)
	ecsClient, err := awsapi.getECSClient(deploymentInfo.Region, roleToUse, deploymentInfo.Profile)
	if err != nil {
		return err
	}
	logClient, err := awsapi.getLogClient(deploymentInfo.Region, roleToUse, deploymentInfo.Profile)
	if err != nil {
		return err
	}
	clusterName := deploymentInfo.ClusterName
	serviceName := deploymentInfo.ServiceName
	reqInfo := ecs.DescribeServicesInput{
		Cluster:  &clusterName,
		Services: []*string{&serviceName},
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	logDone := make(chan struct{})
	go func() {
		defer close(logDone)
		for {
			task, findErr := findATask(ctx, ecsClient, clusterName, serviceName, taskDefinition)
			if findErr == nil {
				fmt.Fprintln(c.App.Writer, "Found deployed task", *task.TaskArn)
				logErr := backgroundLogTaskInformation(ctx, c, task, ecsClient, logClient)
				if logErr != nil && ctx.Err() == nil {
					fmt.Fprintln(c.App.ErrWriter, "Error logging task info", logErr.Error())
				}
				return
			}
			// Retry after a short pause, but stop immediately on cancellation.
			select {
			case <-ctx.Done():
				return
			case <-time.After(time.Second):
			}
		}
	}()

	waitErr := ecsClient.WaitUntilServicesStableWithContext(ctx, &reqInfo)
	cancel()
	<-logDone
	return waitErr
}
