package cmd

import (
	"context"
	"fmt"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/athena"
	"github.com/aws/aws-sdk-go/service/athena/athenaiface"
	"github.com/pkg/errors"
	"github.com/spf13/cobra"

	"code.justin.tv/eventbus/admin-cli/internal/environment"
	"code.justin.tv/eventbus/admin-cli/internal/output"
)

// Destinations for the auth-field-analysis command-line flags; they are bound
// to their flags in init.
var (
	startTimeString string // --start: duration back from now where the window begins
	endTimeString   string // --end: duration back from now where the window ends
	resultCount     int    // --count: max rows printed per publisher/subscriber table
)

// athenaQueryTimeout caps how long the command waits for the Athena query to
// finish before reporting a timeout.
const athenaQueryTimeout = 2 * time.Minute

// init binds the auth-field-analysis flags to their package-level destinations
// and attaches the command to the root command.
func init() {
	flags := authFieldAnalysisCmd.Flags()
	flags.StringVarP(&startTimeString, "start", "s", "1h", "start time, as a duration back in time from the current time (e.g. 1h or 6h30m)")
	flags.StringVarP(&endTimeString, "end", "e", "0s", "end time, as a duration back in time from the current time (e.g. 1h or 6h30m)")
	flags.IntVarP(&resultCount, "count", "n", 10, "number of results to return in publisher and subscriber tables")

	rootCmd.AddCommand(authFieldAnalysisCmd)
}

// athenaQueryTemplate is the Athena SQL that counts KMS calls carrying an
// authorized-field encryption context (EventType, Environment, MessageName,
// FieldName), grouped per caller account and KMS operation.
//
// It is rendered with fmt.Sprintf in AthenaQueryRunner.Start; the %s verbs are
// filled in this order:
//  1. environment name, used to build the CloudTrail table name
//  2. CMK ARN, matched against resources[1].arn
//  3. lower eventtime bound (exclusive)
//  4. upper eventtime bound (exclusive)
//
// Result columns, in order: eventname, account ID, eventtype, environment,
// messagename, fieldname, kmscalls. AthenaQueryRunner.Results reads row.Data
// by position, so keep the two in sync. Rows are ordered by kmscalls
// descending.
var athenaQueryTemplate = `
SELECT 
         eventname,
         useridentity.accountid,
         JSON_EXTRACT_SCALAR(requestparameters, '$.encryptionContext.EventType') as eventtype,
         JSON_EXTRACT_SCALAR(requestparameters, '$.encryptionContext.Environment') as environment,
         JSON_EXTRACT_SCALAR(requestparameters, '$.encryptionContext.MessageName') as messagename,
         JSON_EXTRACT_SCALAR(requestparameters, '$.encryptionContext.FieldName') as fieldname,
         COUNT(*) as kmscalls
FROM cloudtrail_logs_eventbus_%s_authorized_fields_cloudtrail
WHERE eventsource = 'kms.amazonaws.com'
         AND resources[1].arn = '%s'
		 AND eventtime > '%s'
		 AND eventtime < '%s'
         AND JSON_EXTRACT_SCALAR(requestparameters, '$.encryptionContext.EventType') IS NOT NULL
         AND JSON_EXTRACT_SCALAR(requestparameters, '$.encryptionContext.Environment') IS NOT NULL
         AND JSON_EXTRACT_SCALAR(requestparameters, '$.encryptionContext.MessageName') IS NOT NULL
         AND JSON_EXTRACT_SCALAR(requestparameters, '$.encryptionContext.FieldName') IS NOT NULL
GROUP BY 
    useridentity.accountid, 
    eventname,
    JSON_EXTRACT_SCALAR(requestparameters, '$.encryptionContext.EventType'),    
    JSON_EXTRACT_SCALAR(requestparameters, '$.encryptionContext.Environment'), 
    JSON_EXTRACT_SCALAR(requestparameters, '$.encryptionContext.MessageName'), 
    JSON_EXTRACT_SCALAR(requestparameters, '$.encryptionContext.FieldName')
ORDER BY kmscalls DESC
`

// authFieldAnalysisCmd implements `auth-field-analysis`. It runs an Athena
// query over the authorized-field CloudTrail table for the window described by
// --start/--end. It then prints the top --count KMS callers, split into
// publishers (GenerateDataKey) and subscribers (Decrypt).
var authFieldAnalysisCmd = &cobra.Command{
	Use:   "auth-field-analysis",
	Short: "Prints a readout of authorized field access calls against KMS",
	Run: func(cmd *cobra.Command, args []string) {
		ctx := context.Background()
		config, err := environment.Resolve()
		if err != nil {
			output.FatalError(err)
		}

		// Both flags are durations measured backwards from now, so --start is
		// expected to be the larger duration.
		startTimeDiff, err := time.ParseDuration(startTimeString)
		if err != nil {
			output.FatalError(errors.Wrap(err, "invalid start time"))
		}
		endTimeDiff, err := time.ParseDuration(endTimeString)
		if err != nil {
			output.FatalError(errors.Wrap(err, "invalid end time"))
		}
		// A non-positive count would run the full query only to print empty tables.
		if resultCount < 1 {
			output.FatalError(errors.New("count must be at least 1"))
		}

		now := time.Now().UTC()
		start := now.Add(-startTimeDiff)
		end := now.Add(-endTimeDiff)
		// Compare full-precision instants (whole Unix seconds would let a
		// sub-second inversion through). Also reject an empty window: both
		// eventtime bounds in the query are exclusive, so it could match nothing.
		if !start.Before(end) {
			output.FatalError(errors.New("start time must be before end time"))
		}

		runner := NewAthenaQueryRunner(config.Environment, config.AuthorizedFieldS3Bucket, config.AuthorizedFieldCMK)

		fmt.Println("Starting query!")
		if err := runner.Start(ctx, start, end); err != nil {
			output.FatalError(err)
		}

		fmt.Print("Waiting on query execution...")
		if err := runner.Wait(ctx, athenaQueryTimeout); err != nil {
			output.FatalError(err)
		}

		fmt.Println("Fetching results")
		results, err := runner.Results(ctx, resultCount)
		if err != nil {
			output.FatalError(err)
		}
		output.Newline()

		// Both tables share the same columns, matching KMSUsageData.AsTableRow.
		headers := []string{"AWS Account ID", "EventType", "Environment", "MessageName", "FieldName", "KMS Calls"}
		printTable := func(title string, data []*KMSUsageData) {
			fmt.Println(title)
			if err := output.Table(headers, KMSUsageDataTableFormatter(data)); err != nil {
				output.FatalError(err)
			}
		}

		printTable("Publisher KMS Usage Data", results.PublisherData)
		output.Newline()
		printTable("Subscriber KMS Usage Data", results.SubscriberData)
	},
}

type (
	// QueryResults holds the rows returned by AthenaQueryRunner.Results,
	// bucketed by the KMS operation the caller performed.
	QueryResults struct {
		PublisherData  []*KMSUsageData // rows whose eventname is GenerateDataKey
		SubscriberData []*KMSUsageData // rows whose eventname is Decrypt
	}

	// KMSUsageData is one aggregated query row: how many KMS calls a single
	// AWS account made for one authorized field.
	KMSUsageData struct {
		AWSAccountID string
		EventType    string
		Environment  string
		MessageName  string
		FieldName    string
		CallCount    string // string because this is the type Athena returns
	}

	// KMSUsageDataTableFormatter adapts a slice of KMSUsageData to the
	// table-formatter interface consumed by output.Table, so the rows can be
	// rendered as a table on stdout.
	KMSUsageDataTableFormatter []*KMSUsageData
)

// AsTableRow returns the row's fields as table cells, in the order: account ID,
// event type, environment, message name, field name, call count.
func (d *KMSUsageData) AsTableRow() []string {
	cells := make([]string, 0, 6)
	cells = append(cells,
		d.AWSAccountID,
		d.EventType,
		d.Environment,
		d.MessageName,
		d.FieldName,
		d.CallCount,
	)
	return cells
}

// AsTableData converts every element into a table row, preserving order. An
// empty formatter yields an empty, non-nil slice.
func (t KMSUsageDataTableFormatter) AsTableData() [][]string {
	rows := make([][]string, len(t))
	for i, usage := range t {
		rows[i] = usage.AsTableRow()
	}
	return rows
}

// AthenaQueryRunner drives a single Athena query through its lifecycle. Start
// submits it and records queryID; Wait and Results then operate on that ID.
type AthenaQueryRunner struct {
	client      athenaiface.AthenaAPI
	queryID     string // execution ID recorded by Start
	cmkARN      string // KMS key ARN the query filters on
	s3Bucket    string // bucket receiving Athena output under query-results/
	environment string // interpolated into the CloudTrail table name
}

// NewAthenaQueryRunner returns a runner whose Athena client targets us-west-2.
// Session creation goes through session.Must, so it panics if an AWS session
// cannot be built.
func NewAthenaQueryRunner(env, s3Bucket, cmkARN string) *AthenaQueryRunner {
	awsCfg := &aws.Config{Region: aws.String("us-west-2")}
	sess := session.Must(session.NewSession(awsCfg))

	return &AthenaQueryRunner{
		environment: env,
		s3Bucket:    s3Bucket,
		cmkARN:      cmkARN,
		client:      athena.New(sess),
	}
}

// Start renders athenaQueryTemplate for the runner's environment and CMK over
// the exclusive window (start, end) and submits it to Athena in the "default"
// database. Output is written under s3://<s3Bucket>/query-results. The
// resulting execution ID is stored on the runner for Wait and Results.
//
// start and end are converted to UTC before formatting. The layout ends in a
// literal "Z", so formatting a non-UTC time directly would label a local
// wall-clock time as UTC and shift the query window.
func (a *AthenaQueryRunner) Start(ctx context.Context, start, end time.Time) error {
	const timeFormat = "2006-01-02T15:04:05Z"

	query := fmt.Sprintf(athenaQueryTemplate,
		a.environment,
		a.cmkARN,
		start.UTC().Format(timeFormat),
		end.UTC().Format(timeFormat),
	)

	res, err := a.client.StartQueryExecutionWithContext(ctx, &athena.StartQueryExecutionInput{
		QueryExecutionContext: &athena.QueryExecutionContext{
			Database: aws.String("default"),
		},
		ResultConfiguration: &athena.ResultConfiguration{
			OutputLocation: aws.String(fmt.Sprintf("s3://%s/query-results", a.s3Bucket)),
		},
		QueryString: aws.String(query),
	})
	if err != nil {
		return err
	}
	a.queryID = aws.StringValue(res.QueryExecutionId)
	return nil
}

// Wait polls the query started by Start every two seconds, printing one dot per
// poll. It returns nil once the query succeeds. It returns an error if the
// query fails or is cancelled, if polling fails, if maxWait elapses, or if ctx
// is done. Every exit path ends the dots line with a newline.
func (a *AthenaQueryRunner) Wait(ctx context.Context, maxWait time.Duration) error {
	checker := time.NewTicker(2 * time.Second)
	defer checker.Stop()
	// The deadline only needs to fire once, so use a timer rather than a ticker.
	timeout := time.NewTimer(maxWait)
	defer timeout.Stop()

	for {
		select {
		case <-checker.C:
			fmt.Print(".")
			done, err := a.queryComplete(ctx, a.queryID)
			if err != nil {
				fmt.Println("") // drop line from dots
				return err
			}
			if done {
				fmt.Println("") // drop line from dots
				return nil
			}
		case <-timeout.C:
			fmt.Println("") // drop line from dots
			return errors.New("timeout waiting for query execution")
		case <-ctx.Done():
			fmt.Println("") // drop line from dots
			return ctx.Err()
		}
	}
}

// queryComplete reports whether the Athena query identified by queryID has
// succeeded. It returns (false, nil) for any non-terminal state (e.g. queued
// or running). It returns an error if the status cannot be read, or if the
// query ended as FAILED or CANCELLED; in that case the error includes Athena's
// state-change reason when one is provided.
func (a *AthenaQueryRunner) queryComplete(ctx context.Context, queryID string) (bool, error) {
	res, err := a.client.GetQueryExecutionWithContext(ctx, &athena.GetQueryExecutionInput{
		QueryExecutionId: aws.String(queryID),
	})
	if err != nil {
		return false, err
	}
	// Guard the pointer chain instead of panicking on a response without status.
	if res.QueryExecution == nil || res.QueryExecution.Status == nil {
		return false, errors.New("no execution status returned for query " + queryID)
	}

	status := res.QueryExecution.Status
	state := aws.StringValue(status.State)
	switch state {
	case athena.QueryExecutionStateSucceeded:
		return true, nil
	case athena.QueryExecutionStateFailed, athena.QueryExecutionStateCancelled:
		msg := "query returned with status " + state
		if reason := aws.StringValue(status.StateChangeReason); reason != "" {
			msg += ": " + reason
		}
		return false, errors.New(msg)
	default:
		return false, nil
	}
}

// Results pages through the finished query's result set. It returns at most
// maxResults rows for each bucket: GenerateDataKey rows go to PublisherData and
// Decrypt rows go to SubscriberData. Rows with any other eventname are ignored.
// athenaQueryTemplate orders rows by kmscalls descending, so each bucket holds
// the heaviest callers.
func (a *AthenaQueryRunner) Results(ctx context.Context, maxResults int) (*QueryResults, error) {
	publisherData := make([]*KMSUsageData, 0)
	subscriberData := make([]*KMSUsageData, 0)

	// The callback parameter is named page (not output) so it does not shadow
	// the imported output package.
	handlePage := func(page *athena.GetQueryResultsOutput, isLastPage bool) bool {
		if page.ResultSet == nil {
			return !isLastPage
		}
		for _, row := range page.ResultSet.Rows {
			// Column order must match the SELECT list in athenaQueryTemplate.
			if len(row.Data) != 7 {
				fmt.Println("WARNING: skipping result row with unexpected number of columns")
				continue
			}
			kmsUsage := &KMSUsageData{
				AWSAccountID: aws.StringValue(row.Data[1].VarCharValue),
				EventType:    aws.StringValue(row.Data[2].VarCharValue),
				Environment:  aws.StringValue(row.Data[3].VarCharValue),
				MessageName:  aws.StringValue(row.Data[4].VarCharValue),
				FieldName:    aws.StringValue(row.Data[5].VarCharValue),
				CallCount:    aws.StringValue(row.Data[6].VarCharValue),
			}
			// NOTE(review): Athena is expected to return column names as the first
			// row; its eventname cell matches neither case below, so it is dropped.
			switch aws.StringValue(row.Data[0].VarCharValue) {
			case "Decrypt":
				if len(subscriberData) < maxResults {
					subscriberData = append(subscriberData, kmsUsage)
				}
			case "GenerateDataKey":
				if len(publisherData) < maxResults {
					publisherData = append(publisherData, kmsUsage)
				}
			}
		}
		// Stop paging once both buckets are full. Rows are sorted by call count,
		// so any later rows would be discarded anyway.
		bucketsFull := len(publisherData) >= maxResults && len(subscriberData) >= maxResults
		return !isLastPage && !bucketsFull
	}

	err := a.client.GetQueryResultsPagesWithContext(ctx, &athena.GetQueryResultsInput{
		QueryExecutionId: aws.String(a.queryID),
	}, handlePage)
	if err != nil {
		return nil, err
	}
	return &QueryResults{
		PublisherData:  publisherData,
		SubscriberData: subscriberData,
	}, nil
}
