package dynamodb

import (
	"context"
	"fmt"
	"time"

	"code.justin.tv/cb/sauron/activity"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/dynamodb"
	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
	"github.com/pkg/errors"

	log "github.com/sirupsen/logrus"
)

// Tunables for the alert queue and for batch writes to DynamoDB.
const (
	// QueueDuration is the alert queue window, in hours. GetAlertQueue only
	// considers events that occurred within this many hours of the current time.
	QueueDuration = 2 // in hours
	// MaxAlertLimit is the largest number of alerts GetAlertQueue returns;
	// limits that are negative or exceed it are replaced by this value.
	MaxAlertLimit = 3000

	// maxRetryAttempts is the number of BatchWriteItem attempts made by
	// sendBatchWriteRequest before it gives up on unprocessed items.
	// maxDynamoBatchSize is the number of activities DeleteAlertQueue sends per
	// batch write; 25 is also the documented per-call request maximum of
	// DynamoDB's BatchWriteItem.
	maxRetryAttempts   = 3
	maxDynamoBatchSize = 25
)

// validActivityTypes is the set of activity types whose alert status may be
// changed through SetAlertStatus. An activity of any other type is rejected
// with an error before the update is attempted.
var validActivityTypes = map[string]struct{}{
	activity.TypeAutoHostStart:                 {},
	activity.TypeBitsUsage:                     {},
	activity.TypeFollow:                        {},
	activity.TypeHostStart:                     {},
	activity.TypePrimeResubscriptionSharing:    {},
	activity.TypePrimeSubscription:             {},
	activity.TypeRaiding:                       {},
	activity.TypeResubscriptionSharing:         {},
	activity.TypeSubscription:                  {},
	activity.TypeSubscriptionGiftingCommunity:  {},
	activity.TypeSubscriptionGiftingIndividual: {},
}

// SetAlertStatus sets the alert status of an existing activity and returns the
// updated activity as reported back by DynamoDB.
//
// The activity is first looked up by channelID and activityID through the
// ChannelEventIndex index in order to recover its timestamp, which together
// with channel_id forms the key used for the update.
//
// An error is returned when:
//   - the lookup query fails,
//   - the lookup does not match exactly one item (zero means the pair does not
//     exist; more than one makes the update target ambiguous),
//   - the activity's type is not in validActivityTypes,
//   - the update fails, or
//   - a DynamoDB item cannot be converted to an Activity.
//
// On error the returned *Activity is always nil.
func (c *Client) SetAlertStatus(ctx context.Context, channelID string, activityID string, newStatus string) (*Activity, error) {
	// We use a query instead of a get here so we can specify the index to use
	// (that field is not available in GetItemInput)
	queryInput := &dynamodb.QueryInput{
		KeyConditionExpression: aws.String("#C = :channelID AND #I = :activityID"),
		ExpressionAttributeNames: map[string]*string{
			"#C": aws.String("channel_id"),
			"#I": aws.String("id"),
		},
		ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
			":channelID":  {S: aws.String(channelID)},
			":activityID": {S: aws.String(activityID)},
		},
		IndexName: aws.String("ChannelEventIndex"),
		TableName: aws.String(c.activityTable),
	}

	queryOutput, err := c.dynamoDB.QueryWithContext(ctx, queryInput)
	if err != nil {
		return nil, errors.Wrap(err, "dynamodb: failed to get timestamp for alert update")
	}

	// If we get 0 items, the channelID/activityID pair does not exist. But if we get more than
	// one item, we're not going to be sure which one to update, so we need to fail.
	//
	// err is nil at this point, so a fresh error must be created: wrapping a nil
	// error yields nil and would report success with a nil activity.
	if found := len(queryOutput.Items); found != 1 {
		return nil, errors.Errorf("dynamodb: no (or multiple) items found when querying for activity timestamp (found %d)", found)
	}

	// Named "existing" rather than "activity" so the imported activity package
	// is not shadowed.
	var existing Activity
	if err := dynamodbattribute.UnmarshalMap(queryOutput.Items[0], &existing); err != nil {
		return nil, errors.Wrap(err, "dynamodb: failed to convert to activity")
	}

	// Ensure that we only set the alert status for an activity type that is allowed to show alerts.
	if _, ok := validActivityTypes[existing.Type]; !ok {
		return nil, fmt.Errorf("dynamodb: cannot set alert status for invalid activity type '%s'", existing.Type)
	}

	input := &dynamodb.UpdateItemInput{
		TableName: aws.String(c.activityTable),
		Key: map[string]*dynamodb.AttributeValue{
			"channel_id": {S: aws.String(channelID)},
			"timestamp":  {S: aws.String(existing.Timestamp.Format(time.RFC3339Nano))},
		},
		ExpressionAttributeNames: map[string]*string{
			"#S": aws.String("alert_status"),
		},
		ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
			":status": {S: aws.String(newStatus)},
		},
		UpdateExpression: aws.String("SET #S = :status"),
		// ALL_NEW asks DynamoDB to return the item as it looks after the update.
		ReturnValues: aws.String("ALL_NEW"),
	}

	result, err := c.dynamoDB.UpdateItemWithContext(ctx, input)
	if err != nil {
		return nil, errors.Wrap(err, "dynamodb: failed to set alert status")
	}

	var updated Activity
	if err := dynamodbattribute.UnmarshalMap(result.Attributes, &updated); err != nil {
		return nil, errors.Wrap(err, "dynamodb: failed to convert to activity")
	}

	return &updated, nil
}

// GetAlertQueue returns the activities of a channel whose alert status is
// queued, ordered oldest first.
//
// Despite its name, before acts as the lower bound of the query: only
// activities with a timestamp strictly greater than before are considered. If
// before is earlier than QueueDuration hours ago it is raised to that point, so
// the queue never reaches further back than the QueueDuration window.
//
// limit caps the number of activities returned. A limit that is zero, negative
// or greater than MaxAlertLimit is replaced by MaxAlertLimit. Zero is included
// because DynamoDB does not accept a query Limit below 1.
//
// The result is nil when no queued activity matches. An error is returned when
// a query fails or an item cannot be converted to an Activity.
func (c *Client) GetAlertQueue(ctx context.Context, channelID string, before time.Time, limit int) ([]Activity, error) {
	oldestAllowableTime := time.Now().Add(-(time.Hour * QueueDuration))
	if before.Before(oldestAllowableTime) {
		before = oldestAllowableTime
	}

	if limit <= 0 || limit > MaxAlertLimit {
		limit = MaxAlertLimit
	}

	input := &dynamodb.QueryInput{
		KeyConditionExpression: aws.String("#C = :channelID AND #T > :oldest"),
		ExpressionAttributeNames: map[string]*string{
			"#C": aws.String("channel_id"),
			"#T": aws.String("timestamp"),
			"#A": aws.String("alert_status"),
		},
		ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
			":channelID": {S: aws.String(channelID)},
			":oldest":    {S: aws.String(before.UTC().Format(time.RFC3339Nano))},
			":queued":    {S: aws.String(string(AlertStatusQueued))},
		},
		ExclusiveStartKey: nil,
		Limit:             aws.Int64(int64(limit)),
		FilterExpression:  aws.String("#A = :queued"),
		ScanIndexForward:  aws.Bool(true), // ascending order: oldest items first
		TableName:         aws.String(c.activityTable),
	}

	var activities []Activity

	// DynamoDB applies Limit to the items it evaluates before the filter
	// expression runs, so a single page can contain fewer queued items than
	// limit. Keep paging until the limit is reached or the results run out.
	for {
		output, err := c.dynamoDB.QueryWithContext(ctx, input)
		if err != nil {
			return nil, errors.Wrap(err, "dynamodb: failed to get activities")
		}

		for _, item := range output.Items {
			// Named "queued" rather than "activity" so the imported activity
			// package is not shadowed.
			var queued Activity

			if err := dynamodbattribute.UnmarshalMap(item, &queued); err != nil {
				return nil, errors.Wrap(err, "dynamodb: failed to convert to activity")
			}

			activities = append(activities, queued)

			if len(activities) == limit {
				return activities, nil
			}
		}

		// A nil LastEvaluatedKey means there are no further pages.
		if output.LastEvaluatedKey == nil {
			break
		}

		input.SetExclusiveStartKey(output.LastEvaluatedKey)
	}

	return activities, nil
}

// DeleteAlertQueue clears the current alert queue of a channel.
//
// The underlying activities are not deleted, since they still correspond to
// valid channel activity. Instead the queue is fetched through GetAlertQueue,
// every fetched activity has its AlertStatus pointed at AlertStatusPurged, and
// the activities are written back to the table in batches of at most
// maxDynamoBatchSize.
//
// Only a failure to fetch the queue is returned as an error. A failed batch
// write is logged and the remaining batches are still attempted.
func (c *Client) DeleteAlertQueue(ctx context.Context, channelID string) error {
	queued, err := c.GetAlertQueue(ctx, channelID, time.Time{}, -1)
	if err != nil {
		return err
	}

	// Work on a separate slice so the fetched queue itself is left untouched
	// while the statuses are rewritten.
	purged := AlertStatusPurged
	updatedQueue := append([]Activity(nil), queued...)
	for idx := range updatedQueue {
		updatedQueue[idx].AlertStatus = &purged
	}

	total := len(updatedQueue)
	for start := 0; start < total; start += maxDynamoBatchSize {
		stop := start + maxDynamoBatchSize
		if stop > total {
			stop = total
		}

		batch := updatedQueue[start:stop]
		if writeErr := c.sendBatchWriteRequest(ctx, batch); writeErr != nil {
			// Log here, but continue with future batches
			fields := log.Fields{
				"channel_id": channelID,
				"batch":      batch,
			}
			log.WithFields(fields).WithError(writeErr).Error("dynamodb: failed to batch write alert statuses")
		}
	}

	return nil
}

// sendBatchWriteRequest writes the given activities to the activity table with
// a single BatchWriteItem call made up of put requests, retrying any items
// DynamoDB reports as unprocessed.
//
// Up to maxRetryAttempts calls are made in total. Between calls it waits with a
// linear backoff of 2s, 4s, ... and stops waiting early if ctx is done. No wait
// happens after the final attempt.
//
// It returns an error when an activity cannot be marshalled, when a
// BatchWriteItem call fails, or when ctx is done while waiting to retry. If
// items are still unprocessed after the final attempt, this is logged and nil
// is returned: the write is best-effort. Calling it with no activities does
// nothing and returns nil.
func (c *Client) sendBatchWriteRequest(ctx context.Context, activities []Activity) error {
	// BatchWriteItem does not accept an empty request list, so there is nothing
	// to send.
	if len(activities) == 0 {
		return nil
	}

	writeRequests := make([]*dynamodb.WriteRequest, len(activities))

	for idx := range activities {
		item, err := dynamodbattribute.MarshalMap(activities[idx])
		if err != nil {
			return errors.Wrap(err, "dynamodb: failed to convert activity for batch write")
		}

		writeRequests[idx] = &dynamodb.WriteRequest{
			PutRequest: &dynamodb.PutRequest{
				Item: item,
			},
		}
	}

	input := &dynamodb.BatchWriteItemInput{
		RequestItems: map[string][]*dynamodb.WriteRequest{
			c.activityTable: writeRequests,
		},
	}

	for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
		result, err := c.dynamoDB.BatchWriteItemWithContext(ctx, input)
		if err != nil {
			return err
		}

		if len(result.UnprocessedItems) == 0 {
			return nil
		}

		// Only the items DynamoDB did not process are sent again.
		input = &dynamodb.BatchWriteItemInput{
			RequestItems: result.UnprocessedItems,
		}

		// No retry follows the final attempt, so waiting would only delay the
		// caller.
		if attempt == maxRetryAttempts {
			break
		}

		// Linear backoff that can be interrupted by context cancellation.
		backoff := time.NewTimer(time.Duration(2*attempt) * time.Second)
		select {
		case <-ctx.Done():
			backoff.Stop()
			return ctx.Err()
		case <-backoff.C:
		}
	}

	log.WithFields(log.Fields{
		"unprocessed_items": input.RequestItems,
	}).Error("dynamodb: max retries reached, but still found unprocessed items when deleting alert queue")

	return nil
}
