package server

import (
	"context"
	"fmt"
	"time"

	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/endpoints"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/cloudwatchlogs"

	"code.justin.tv/creator-collab/log/errors"
)

// CWLogs wraps an AWS SDK CloudWatch Logs client.
type CWLogs struct {
	// cwlogsClient is the underlying SDK client. It is assigned by
	// initInnerClient, which may be called again to replace it.
	cwlogsClient *cloudwatchlogs.CloudWatchLogs
}

// NewCWLogs builds a CWLogs and initializes its inner CloudWatch Logs client.
// It returns a nil *CWLogs and the underlying error if initialization fails.
func NewCWLogs() (*CWLogs, error) {
	var wrapper CWLogs
	if err := wrapper.initInnerClient(); err != nil {
		return nil, err
	}
	return &wrapper, nil
}

// initInnerClient creates a new AWS session for us-west-2 (with regional STS
// endpoints) and stores a CloudWatch Logs client built from it on c.
// On failure the existing client, if any, is left untouched and a wrapped
// error is returned.
func (c *CWLogs) initInnerClient() error {
	awsConfig := &aws.Config{
		Region:              aws.String("us-west-2"),
		STSRegionalEndpoint: endpoints.RegionalSTSEndpoint,
	}

	awsSession, sessionErr := session.NewSession(awsConfig)
	if sessionErr != nil {
		return errors.Wrap(sessionErr, "creating aws session failed")
	}

	c.cwlogsClient = cloudwatchlogs.New(awsSession)
	return nil
}

// runQuery starts the CloudWatch Logs query described by startQueryInput and
// polls for its results until the query reaches a terminal status.
//
// It returns the query results once the status is Complete. It returns an
// error if the query cannot be started, if no query ID comes back, if fetching
// results fails for a reason other than the per-attempt timeout, if the query
// ends as Failed or Cancelled, if an unrecognized status is reported, or if
// ctx is done before the query finishes.
func (c *CWLogs) runQuery(ctx context.Context, startQueryInput *cloudwatchlogs.StartQueryInput) (
	*cloudwatchlogs.GetQueryResultsOutput, error) {

	const (
		// pollInterval is how long to wait before each status check.
		pollInterval = 5 * time.Second
		// pollTimeout bounds a single GetQueryResults attempt.
		pollTimeout = 5 * time.Second
	)

	startQueryResponse, err := c.startQueryWithContext(ctx, startQueryInput)
	if err != nil {
		return nil, errors.Wrap(err, "runQuery - StartQueryWithContext request failed")
	}
	if startQueryResponse.QueryId == nil {
		return nil, errors.New("runQuery - startQueryResponse.QueryId was nil")
	}

	queryID := *startQueryResponse.QueryId

	// Poll until the query has completed running.
	for {
		// Wait between polls, but give up promptly if the caller's context
		// is cancelled or its deadline passes.
		timer := time.NewTimer(pollInterval)
		select {
		case <-ctx.Done():
			timer.Stop()
			return nil, errors.Wrap(ctx.Err(), "runQuery - context ended while waiting for query to complete")
		case <-timer.C:
		}

		innerCtx, cancel := context.WithTimeout(ctx, pollTimeout)
		queryResults, pollErr := c.getQueryResultsWithContext(innerCtx, &cloudwatchlogs.GetQueryResultsInput{
			QueryId: aws.String(queryID),
		})
		// Read the attempt's context state BEFORE cancel(): once cancel has
		// run, innerCtx.Err() is always non-nil and would mask real errors.
		attemptCtxErr := innerCtx.Err()
		cancel()
		if pollErr != nil {
			// Only a per-attempt timeout is retried; if the caller's context
			// is the one that ended, or the request failed for any other
			// reason, surface the error instead of looping forever.
			if attemptCtxErr != nil && ctx.Err() == nil {
				continue
			}

			return nil, errors.Wrap(pollErr, "runQuery - GetQueryResultsWithContext failed")
		}

		if queryResults == nil {
			return nil, errors.New("queryResults was nil")
		}
		if queryResults.Status == nil {
			return nil, errors.New("queryResults.Status was nil")
		}
		queryStatus := *queryResults.Status

		switch queryStatus {
		case cloudwatchlogs.QueryStatusScheduled, cloudwatchlogs.QueryStatusRunning:
			// Not finished yet; poll again after the next interval.
			continue

		case cloudwatchlogs.QueryStatusFailed, cloudwatchlogs.QueryStatusCancelled:
			return nil, errors.New(fmt.Sprintf("query ended with \"%s\" status", queryStatus))

		case cloudwatchlogs.QueryStatusComplete:
			return queryResults, nil

		default:
			return nil, errors.New("unrecognized query status", errors.Fields{
				"status": queryStatus,
			})
		}
	}
}

// startQueryWithContext issues StartQuery on the inner client, routing the
// call through retryIfTokenExpires. It returns the output and error from the
// last attempt made.
func (c *CWLogs) startQueryWithContext(ctx context.Context, startQueryInput *cloudwatchlogs.StartQueryInput) (*cloudwatchlogs.StartQueryOutput, error) {
	var (
		output  *cloudwatchlogs.StartQueryOutput
		callErr error
	)

	// The closure reads c.cwlogsClient on every invocation, so a retry uses
	// whatever client is current at that point.
	attempt := func() error {
		output, callErr = c.cwlogsClient.StartQueryWithContext(ctx, startQueryInput)
		return callErr
	}
	c.retryIfTokenExpires(attempt)

	return output, callErr
}

// getQueryResultsWithContext issues GetQueryResults on the inner client,
// routing the call through retryIfTokenExpires. It returns the output and
// error from the last attempt made.
func (c *CWLogs) getQueryResultsWithContext(ctx context.Context, getQueryResultsInput *cloudwatchlogs.GetQueryResultsInput) (
	*cloudwatchlogs.GetQueryResultsOutput, error) {

	var (
		output  *cloudwatchlogs.GetQueryResultsOutput
		callErr error
	)

	// The closure reads c.cwlogsClient on every invocation, so a retry uses
	// whatever client is current at that point.
	attempt := func() error {
		output, callErr = c.cwlogsClient.GetQueryResultsWithContext(ctx, getQueryResultsInput)
		return callErr
	}
	c.retryIfTokenExpires(attempt)

	return output, callErr
}

// retryIfTokenExpires runs queryFunc and retries it once if it fails because
// the client's credentials expired (AWS error code "ExpiredTokenException").
// This is meant to prevent failures when users are using ada to provide
// short-lived credentials.
//
// Before a retry the inner client is rebuilt via initInnerClient. If that
// rebuild fails, no retry is attempted, since the stale client is still in
// place. queryFunc's error is not returned here; callers capture it through
// the closure they pass in.
func (c *CWLogs) retryIfTokenExpires(queryFunc func() error) {
	const maxAttempts = 2

	for attempt := 0; attempt < maxAttempts; attempt++ {
		err := queryFunc()

		awsErr, isAWSErr := err.(awserr.Error)
		if !isAWSErr || awsErr.Code() != "ExpiredTokenException" {
			// Success, or a failure unrelated to credential expiry.
			return
		}

		// Refresh the client so the next call picks up new credentials.
		if reinitErr := c.initInnerClient(); reinitErr != nil {
			// The old client is still installed; retrying would not help.
			return
		}
	}
}
