package ondemand

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"text/template"
	"time"

	"math/rand"

	"code.justin.tv/feeds/distconf"
	"code.justin.tv/feeds/errors"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awsutil"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/ec2"
	"github.com/aws/aws-sdk-go/service/ecs"
	"github.com/aws/aws-sdk-go/service/sts"
)

// TaskCreationConfig holds the distconf-backed settings used by TaskCreation.
type TaskCreationConfig struct {
	// TaskRemovalAge is the age passed to the periodic cleanup run by
	// TaskCreation.Start; tasks at least this old are removed.
	TaskRemovalAge *distconf.Duration
}

// Load populates the configuration from dconf. The removal age is read from
// the "ecs-on-demand.removal_age" key and defaults to 72 hours (three days).
// It always returns nil.
func (t *TaskCreationConfig) Load(dconf *distconf.Distconf) error {
	const defaultRemovalAge = 72 * time.Hour
	t.TaskRemovalAge = dconf.Duration("ecs-on-demand.removal_age", defaultRemovalAge)
	return nil
}

// TaskCreation launches, tracks and removes on-demand ECS tasks.
type TaskCreation struct {
	// TaskCreationConfig supplies the removal age used by the cleanup loop in Start.
	TaskCreationConfig *TaskCreationConfig
	// DeployDB is queried for deployment info and deployed versions.
	DeployDB *DeployDB
	// TaskDB stores, lists and deletes the records of running tasks.
	TaskDB *TaskDB
	// DockerHelper provides the docker host name and image existence checks.
	DockerHelper *DockerHelper
	// closeChan is created by Setup and closed by Close to make Start return.
	closeChan chan struct{}
}

// forwardTo is the host and port a request should be forwarded to in order to
// reach a running task.
type forwardTo struct {
	// host is taken from the private IP address of the EC2 instance, or from a
	// stored RunningTask's ForwardHost.
	host string
	// port is the host port of the matching container network binding, or a
	// stored RunningTask's ForwardPort.
	port int32
}

// removeOldTasks removes every task returned by TaskDB.GetTasks whose creation
// time lies at least age in the past, and returns the tasks it removed.
//
// The sweep stops at the first RemoveTask failure; in that case only the error
// is returned, even if earlier tasks were already removed.
func (t *TaskCreation) removeOldTasks(ctx context.Context, age time.Duration) ([]*RunningTask, error) {
	allTasks, err := t.TaskDB.GetTasks(ctx)
	if err != nil {
		return nil, err
	}
	removed := make([]*RunningTask, 0, len(allTasks))
	for _, candidate := range allTasks {
		if time.Since(candidate.creationTime) >= age {
			if removeErr := t.RemoveTask(ctx, candidate); removeErr != nil {
				return nil, removeErr
			}
			removed = append(removed, candidate)
		}
	}
	return removed, nil
}

// reloadTask fetches the current ECS description of the running task r.
//
// The deployment info for r's team, service and environment determines the
// region, role and profile of the ECS client. An error is returned if the
// lookup, client creation or DescribeTasks call fails, or if ECS does not
// return exactly one task.
func (t *TaskCreation) reloadTask(ctx context.Context, r *RunningTask) (*ecs.Task, error) {
	info, err := t.DeployDB.getDeploymentInfo(r.Team, r.Service, r.Environment)
	if err != nil {
		return nil, err
	}
	client, err := NewECSClient(info.Region, info.DeployAWSRole, info.Profile)
	if err != nil {
		return nil, err
	}
	describeReq, describeOut := client.DescribeTasksRequest(&ecs.DescribeTasksInput{
		Cluster: &r.TaskCluster,
		Tasks:   []*string{&r.TaskARN},
	})
	describeReq.SetContext(ctx)
	if sendErr := describeReq.Send(); sendErr != nil {
		return nil, sendErr
	}
	if found := len(describeOut.Tasks); found != 1 {
		return nil, errors.Errorf("unable to find task %s", r.TaskARN)
	}
	return describeOut.Tasks[0], nil
}

// Setup allocates the channel that Close closes to make Start return. Call it
// before Close: closing a channel that was never allocated panics. It always
// returns nil.
func (t *TaskCreation) Setup() error {
	stop := make(chan struct{})
	t.closeChan = stop
	return nil
}

// Close signals Start to return by closing the stop channel. It always returns
// nil. The channel is closed unconditionally, so calling Close a second time
// panics.
func (t *TaskCreation) Close() error {
	stop := t.closeChan
	close(stop)
	return nil
}

// Start runs the periodic cleanup loop until Close is called, then returns nil.
//
// It first waits a random whole number of seconds in [0, 600), then sweeps
// every 600 seconds (measured from the end of the previous sweep), removing
// tasks older than the configured TaskRemovalAge. Results are printed to
// standard output; a failed sweep is logged and does not stop the loop.
func (t *TaskCreation) Start() error {
	defer fmt.Println("Task creation done")

	const sweepInterval = 600 * time.Second

	// Random initial delay before the first sweep interval begins.
	jitter := time.Duration(rand.Intn(600)) * time.Second
	select {
	case <-t.closeChan:
		return nil
	case <-time.After(jitter):
	}

	for {
		select {
		case <-t.closeChan:
			return nil
		case <-time.After(sweepInterval):
		}

		removed, err := t.removeOldTasks(context.Background(), t.TaskCreationConfig.TaskRemovalAge.Get())
		switch {
		case err != nil:
			fmt.Println("Error removing tasks", err)
		case len(removed) == 0:
			fmt.Println("no tasks to remove")
		default:
			fmt.Println("Removed tasks")
			for _, task := range removed {
				fmt.Println(awsutil.Prettify(task))
			}
		}
	}
}

// CreateTask returns the forward target for the task described by p, launching
// the task first if TaskDB has no record of it.
//
// When a record already exists its stored host and port are returned without
// contacting AWS. Otherwise the method registers a task definition, runs it,
// resolves the host and port it is reachable on, and stores a RunningTask
// record before returning the target. Any failing step aborts with its error.
func (t *TaskCreation) CreateTask(ctx context.Context, p *parsedHost) (*forwardTo, error) {
	running, err := t.TaskDB.GetTask(ctx, p.team, p.service, p.imageName)
	if err != nil {
		return nil, err
	}
	if running != nil {
		// Already launched: reuse the stored forward target.
		return &forwardTo{host: running.ForwardHost, port: running.ForwardPort}, nil
	}

	info, err := t.DeployDB.getDeploymentInfo(p.team, p.service, p.environment)
	if err != nil {
		return nil, err
	}
	ecsAPI, err := NewECSClient(info.Region, info.DeployAWSRole, info.Profile)
	if err != nil {
		return nil, err
	}
	ec2API, err := NewEC2Client(info.Region, info.DeployAWSRole, info.Profile)
	if err != nil {
		return nil, err
	}

	creation := &taskCreationInput{
		ecsClient:    ecsAPI,
		team:         p.team,
		service:      p.service,
		env:          p.environment,
		tag:          p.imageName,
		taskTemplate: info.TaskTemplate,
		profile:      info.Profile,
		taskData:     info.Data,
	}
	definition, err := t.createTask(creation)
	if err != nil {
		return nil, errors.Wrap(err, "unable to create ECS task")
	}
	launched, err := t.RunTask(ecsAPI, p.team, info, definition)
	if err != nil {
		return nil, err
	}
	target, err := t.findForwardPort(ecsAPI, ec2API, info.ServiceName, launched)
	if err != nil {
		return nil, err
	}

	record := &RunningTask{
		Team:        p.team,
		Service:     p.service,
		Environment: p.environment,
		ImageName:   p.imageName,
		TaskCluster: *launched.ClusterArn,
		TaskARN:     *launched.TaskArn,
		Version:     1,
		ForwardHost: target.host,
		ForwardPort: target.port,
	}
	// The record is stored with a background context rather than ctx.
	// NOTE(review): presumably so a cancelled request does not lose track of a
	// task that is already running — confirm.
	if storeErr := t.TaskDB.StoreTask(context.Background(), record); storeErr != nil {
		return nil, storeErr
	}
	return target, nil
}

// taskCreationInput bundles the arguments of createTask.
type taskCreationInput struct {
	// ecsClient is the client used to register the task definition.
	ecsClient *ecs.ECS
	// team, service and env identify what is being launched; they become the
	// "team", "service" and "environment" template parameters.
	team    string
	service string
	env     string
	// tag is the image tag; createTask replaces it with the deployed version
	// when DeployDB reports a non-empty one.
	tag string
	// taskTemplate is the task definition template text; an empty value makes
	// createTaskDefinitionInput fall back to defaultTaskTemplateText.
	taskTemplate string
	// profile is copied from the deployment info by CreateTask.
	profile string
	// taskData holds extra template parameters that override the defaults.
	taskData map[string]string
}

// findForwardFromECS picks the host and port on instance that reach task.
//
// Every network binding of every container in task is compared against each
// load balancer entry of service. A binding whose container port and container
// name both equal the load balancer's is preferred, and the last such binding
// wins. If no binding matches, the first binding encountered is used. The host
// is always the instance's private IP address.
//
// An error is returned when no binding is visited at all, which includes the
// case of a service without load balancers.
func findForwardFromECS(service *ecs.Service, task *ecs.Task, instance *ec2.Instance) (*forwardTo, error) {
	var chosen *forwardTo
	for _, balancer := range service.LoadBalancers {
		for _, container := range task.Containers {
			for _, binding := range container.NetworkBindings {
				exact := *binding.ContainerPort == *balancer.ContainerPort &&
					*container.Name == *balancer.ContainerName
				// An exact match always replaces the current choice; anything
				// else is only accepted as the initial fallback.
				if exact || chosen == nil {
					chosen = &forwardTo{
						host: *instance.PrivateIpAddress,
						port: int32(*binding.HostPort),
					}
				}
			}
		}
	}
	if chosen != nil {
		return chosen, nil
	}
	return nil, errors.New("unable to find a forward location port")
}

// findForwardPort resolves the host and port that traffic for task should be
// forwarded to.
//
// It describes the container instance the task runs on, the EC2 instance
// backing that container instance, and the ECS service named serviceName in
// the task's cluster, then delegates the selection to findForwardFromECS.
//
// An error is returned if any AWS call fails or if a lookup does not yield
// exactly one container instance, reservation, instance or service.
func (t *TaskCreation) findForwardPort(ecsClient *ecs.ECS, ec2Client *ec2.EC2, serviceName string, task *ecs.Task) (*forwardTo, error) {
	descContainerOut, err := ecsClient.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{
		Cluster:            task.ClusterArn,
		ContainerInstances: []*string{task.ContainerInstanceArn},
	})
	if err != nil {
		return nil, errors.Wrap(err, "cannot describe containers")
	}
	// The count can be zero as well as greater than one, so report what was
	// actually returned instead of claiming "too many".
	if n := len(descContainerOut.ContainerInstances); n != 1 {
		return nil, errors.Errorf("expected exactly one container instance, got %d", n)
	}

	containerInstance := descContainerOut.ContainerInstances[0]
	descInstancesOut, err := ec2Client.DescribeInstances(&ec2.DescribeInstancesInput{
		InstanceIds: []*string{containerInstance.Ec2InstanceId},
	})
	if err != nil {
		return nil, err
	}
	if len(descInstancesOut.Reservations) != 1 {
		return nil, errors.New("unable to find reservation task is running on")
	}
	if len(descInstancesOut.Reservations[0].Instances) != 1 {
		return nil, errors.New("unable to find instance task is running on")
	}
	instance := descInstancesOut.Reservations[0].Instances[0]

	descOut, err := ecsClient.DescribeServices(&ecs.DescribeServicesInput{
		Cluster:  task.ClusterArn,
		Services: []*string{&serviceName},
	})
	// aws.StringValue keeps error formatting from panicking on a nil cluster ARN.
	if err != nil {
		return nil, errors.Wrapf(err, "cannot describe service %s in %s", serviceName, aws.StringValue(task.ClusterArn))
	}
	if len(descOut.Services) != 1 {
		return nil, errors.Errorf("unable to find the service %s in %s", serviceName, aws.StringValue(task.ClusterArn))
	}
	return findForwardFromECS(descOut.Services[0], task, instance)
}

// RemoveTask stops the ECS task described by r and then deletes its record
// from TaskDB.
//
// The ECS client is built from the deployment info of r's team, service and
// environment. If the lookup, client creation or StopTask call fails, the
// error is returned and the TaskDB record is left in place.
func (t *TaskCreation) RemoveTask(ctx context.Context, r *RunningTask) error {
	info, err := t.DeployDB.getDeploymentInfo(r.Team, r.Service, r.Environment)
	if err != nil {
		return err
	}
	client, err := NewECSClient(info.Region, info.DeployAWSRole, info.Profile)
	if err != nil {
		return err
	}
	stopReq, _ := client.StopTaskRequest(&ecs.StopTaskInput{
		Cluster: &r.TaskCluster,
		Task:    &r.TaskARN,
		Reason:  aws.String("Stop in UI"),
	})
	stopReq.SetContext(ctx)
	if stopErr := stopReq.Send(); stopErr != nil {
		return stopErr
	}
	return t.TaskDB.DeleteTask(ctx, r)
}

// RunTask starts one instance of taskDefinition in the cluster named by
// deployStatus and waits for it to leave the PENDING state.
//
// The placement constraints and strategy are copied from the existing service
// deployStatus.ServiceName. The task group is "<team>-<service>-on-demand" and
// the task is started by "ecs-deploy-on-demand".
//
// An error is returned if the service cannot be described, if RunTask fails or
// does not return exactly one task, or if the task never leaves PENDING (see
// blockForTaskToRun). When ECS reports failures, for example because no
// container instance can host the task, they are included in the error.
func (t *TaskCreation) RunTask(ecsClient *ecs.ECS, team string, deployStatus *deploymentInfo, taskDefinition *ecs.TaskDefinition) (*ecs.Task, error) {
	descOut, err := ecsClient.DescribeServices(&ecs.DescribeServicesInput{
		Cluster:  &deployStatus.ClusterName,
		Services: []*string{&deployStatus.ServiceName},
	})
	if err != nil {
		return nil, err
	}
	if len(descOut.Services) != 1 {
		return nil, errors.Errorf("unable to find the service %s", deployStatus.ServiceName)
	}
	service := descOut.Services[0]
	curUser := "on-demand"
	groupName := fmt.Sprintf("%s-%s-%s", team, deployStatus.ServiceName, curUser)
	taskStartedBy := fmt.Sprintf("ecs-deploy-%s", curUser)
	taskIdentifier := fmt.Sprintf("%s:%d", *taskDefinition.Family, *taskDefinition.Revision)

	runOut, err := ecsClient.RunTask(&ecs.RunTaskInput{
		Cluster:              &deployStatus.ClusterName,
		Count:                aws.Int64(1),
		Group:                &groupName,
		PlacementConstraints: service.PlacementConstraints,
		PlacementStrategy:    service.PlacementStrategy,
		StartedBy:            &taskStartedBy,
		TaskDefinition:       &taskIdentifier,
	})
	if err != nil {
		return nil, err
	}
	if len(runOut.Tasks) != 1 {
		// A call that succeeds at the API level can still place nothing; the
		// reasons are only available in Failures, so surface them.
		if len(runOut.Failures) > 0 {
			return nil, errors.Errorf("could not find ran task in RunTask output: %s", awsutil.Prettify(runOut.Failures))
		}
		return nil, errors.New("could not find ran task in RunTask output")
	}
	return t.blockForTaskToRun(ecsClient, runOut.Tasks[0])
}

// blockForTaskToRun polls ECS once per second until task is no longer in the
// PENDING state and returns the most recent description of it.
//
// At most 30 polls are made. If the task is still PENDING after the last poll,
// the error "task stayed pending" is returned. A task that leaves PENDING on
// the final poll is returned successfully. An error is also returned if
// DescribeTasks fails or does not return exactly one task.
func (t *TaskCreation) blockForTaskToRun(ecsClient *ecs.ECS, task *ecs.Task) (*ecs.Task, error) {
	const maxPolls = 30
	// aws.StringValue treats a missing status as "not pending" instead of
	// panicking on a nil pointer.
	for polls := 0; aws.StringValue(task.LastStatus) == ecs.DesiredStatusPending; polls++ {
		// Decide on the task's actual status, not on the poll counter alone, so
		// a transition observed on the last poll is not reported as a timeout.
		if polls == maxPolls {
			return nil, errors.New("task stayed pending")
		}
		fmt.Println("Sleeping till task not in PENDING state...")
		time.Sleep(time.Second)
		taskDesc, err := ecsClient.DescribeTasks(&ecs.DescribeTasksInput{
			Cluster: task.ClusterArn,
			Tasks:   []*string{task.TaskArn},
		})
		if err != nil {
			return nil, err
		}
		if len(taskDesc.Tasks) != 1 {
			return nil, errors.New("could not find task in search")
		}
		task = taskDesc.Tasks[0]
	}
	return task, nil
}

// defaultTemplateParams builds the parameter map handed to the task definition
// template.
//
// The defaults are derived from team, service, env and tag: "task_name" is
// "<team>-<service>" and "image" is "<DockerHost>/<team>/<service>:<tag>".
// CPU, memory and container port default to "1024", "512" and "8080". Entries
// in extraData are applied last and override defaults with the same key.
func (t *TaskCreation) defaultTemplateParams(team string, service string, env string, tag string, extraData map[string]string) map[string]string {
	ret := map[string]string{
		"team":           team,
		"service":        service,
		"environment":    env,
		"git_commit":     tag,
		"task_cpu":       "1024",
		"task_mem":       "512",
		"container_port": "8080",
		// Use the team argument; a quoted "team" here would give every team the
		// same "team-<service>" task name.
		"task_name": team + "-" + service,
		"image":     t.DockerHelper.DockerHost + "/" + team + "/" + service + ":" + tag,
	}
	for k, v := range extraData {
		ret[k] = v
	}
	return ret
}

// createTask registers an ECS task definition for the image described by in.
//
// If DeployDB reports a non-empty deployed version for in.tag, in.tag is
// overwritten with that version; a failed lookup is ignored and the requested
// tag is kept. The docker image must exist, otherwise an error is returned.
// The template parameters are the defaults for in, overridden by in.taskData.
func (t *TaskCreation) createTask(in *taskCreationInput) (*ecs.TaskDefinition, error) {
	// Best-effort resolution of the tag to a deployed version.
	if deployed, lookupErr := t.DeployDB.deployedVersion(in.team, in.service, in.tag); lookupErr == nil && deployed != "" {
		in.tag = deployed
	}
	if existsErr := t.DockerHelper.dockerImageExists(in.team, in.service, in.tag); existsErr != nil {
		return nil, errors.Wrap(existsErr, "unable to validate docker image still exists")
	}
	params := t.defaultTemplateParams(in.team, in.service, in.env, in.tag, in.taskData)
	return t.createECSTask(in.ecsClient, in.taskTemplate, params)
}

// NewAWSSession creates an AWS session for region using the shared config
// profile named profile.
//
// When awsRole is empty that session is returned as is. Otherwise a second
// session is created whose credentials come from assuming awsRole through STS
// (using the first session), with a 10 second expiry window.
func NewAWSSession(region string, awsRole string, profile string) (*session.Session, error) {
	opts := session.Options{
		SharedConfigState: session.SharedConfigEnable,
		Profile:           profile,
	}
	opts.Config.MergeIn(&aws.Config{Region: &region})

	baseSession, err := session.NewSessionWithOptions(opts)
	if err != nil {
		return nil, err
	}
	if awsRole == "" {
		return baseSession, nil
	}

	// Layer assume-role credentials on top of the same options.
	roleProvider := &stscreds.AssumeRoleProvider{
		Client:       sts.New(baseSession),
		RoleARN:      awsRole,
		ExpiryWindow: 10 * time.Second,
	}
	opts.Config.MergeIn(&aws.Config{
		Credentials: credentials.NewCredentials(roleProvider),
	})
	return session.NewSessionWithOptions(opts)
}

// NewECSClient returns an ECS client backed by a session created with
// NewAWSSession(region, awsRole, profile), or that call's error.
func NewECSClient(region string, awsRole string, profile string) (*ecs.ECS, error) {
	sess, err := NewAWSSession(region, awsRole, profile)
	if err == nil {
		return ecs.New(sess), nil
	}
	return nil, err
}

// NewEC2Client returns an EC2 client backed by a session created with
// NewAWSSession(region, awsRole, profile), or that call's error.
func NewEC2Client(region string, awsRole string, profile string) (*ec2.EC2, error) {
	sess, err := NewAWSSession(region, awsRole, profile)
	if err == nil {
		return ec2.New(sess), nil
	}
	return nil, err
}

// createECSTask renders taskTemplate with templateInfo and registers the
// resulting task definition with ECS through ecsClient.
//
// The rendered input is printed to standard output before it is sent. It
// returns the registered task definition, or an error if the template cannot
// be turned into a registration input or the registration request fails.
func (t *TaskCreation) createECSTask(ecsClient *ecs.ECS, taskTemplate string, templateInfo map[string]string) (*ecs.TaskDefinition, error) {
	taskInput, err := t.createTaskDefinitionInput(taskTemplate, templateInfo)
	if err != nil {
		return nil, err
	}
	fmt.Println(awsutil.Prettify(taskInput))
	req, out := ecsClient.RegisterTaskDefinitionRequest(taskInput)
	if err := req.Send(); err != nil {
		return nil, errors.Wrap(err, "unable to register task")
	}
	return out.TaskDefinition, nil
}

// createTaskDefinitionInput renders containerTemplate as a text/template with
// templateInfo as its data and decodes the output as JSON into an ECS
// RegisterTaskDefinitionInput.
//
// An empty containerTemplate selects defaultTaskTemplateText. An error is
// returned if the template fails to parse or execute, or if its output cannot
// be decoded.
func (t *TaskCreation) createTaskDefinitionInput(containerTemplate string, templateInfo map[string]string) (*ecs.RegisterTaskDefinitionInput, error) {
	if containerTemplate == "" {
		containerTemplate = defaultTaskTemplateText
	}

	// The template is executed against a map[string]interface{} copy of the
	// string parameters.
	data := make(map[string]interface{}, len(templateInfo))
	for key, value := range templateInfo {
		data[key] = value
	}

	parsed, err := template.New("containers").Parse(containerTemplate)
	if err != nil {
		return nil, errors.Wrapf(err, "Unable to parse containers template %s", containerTemplate)
	}

	var rendered bytes.Buffer
	if execErr := parsed.Execute(&rendered, data); execErr != nil {
		return nil, errors.Wrap(execErr, "Unable to process template")
	}

	input := &ecs.RegisterTaskDefinitionInput{}
	if decodeErr := json.NewDecoder(&rendered).Decode(input); decodeErr != nil {
		return nil, errors.Wrap(decodeErr, "Unable to read container template file as containers")
	}
	return input, nil
}
