package dynamo

import (
	"context"
	"errors"
	"fmt"
	"strings"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/dynamodb"
	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
)

// UpdateAllowlistEntryRequest describes a partial update of one allowlist
// entry. GameID and ClientID select the entry; the remaining fields are the
// values to write, and an empty value means "leave this attribute unchanged".
type UpdateAllowlistEntryRequest struct {
	// Key of the entry being updated.
	GameID, ClientID string

	// Optional new values; empty string / empty slice are skipped.
	OrganizationID string
	Products       []string
}

// UpdateAllowlistEntry writes the non-empty fields of r to the existing
// allowlist entry keyed by r.GameID and r.ClientID and returns the entry as it
// is stored after the update (ReturnValues ALL_NEW).
//
// The write is conditional on the entry already existing; when it does not,
// a *NotFoundError is returned. When r carries no fields to update, an error
// is returned without calling DynamoDB. Any other DynamoDB error is returned
// as-is.
func (a *allowlist) UpdateAllowlistEntry(ctx context.Context, r *UpdateAllowlistEntryRequest) (*AllowlistEntry, error) {
	const conditionExpression = "attribute_exists(GameID) AND attribute_exists(ClientID)"

	attributeValues, updateExpression := r.AttributeValuesAndUpdateExpression()
	// With nothing to SET, DynamoDB would reject the call (empty
	// ExpressionAttributeValues / incomplete "SET " expression) with an opaque
	// ValidationException; fail early with a descriptive error instead.
	if len(attributeValues) == 0 {
		return nil, fmt.Errorf("no fields to update for game ID %q and client ID %q", r.GameID, r.ClientID)
	}

	input := &dynamodb.UpdateItemInput{
		TableName: &a.tableName,
		Key: map[string]*dynamodb.AttributeValue{
			"GameID":   {S: aws.String(r.GameID)},
			"ClientID": {S: aws.String(r.ClientID)},
		},
		ConditionExpression:       aws.String(conditionExpression),
		ReturnValues:              aws.String(dynamodb.ReturnValueAllNew),
		UpdateExpression:          aws.String(updateExpression),
		ExpressionAttributeValues: attributeValues,
	}
	output, err := a.dynamo.UpdateItemWithContext(ctx, input)
	if err != nil {
		// The condition only checks that the key attributes exist, so a
		// failed condition means the entry is missing.
		var conditionErr *dynamodb.ConditionalCheckFailedException
		if errors.As(err, &conditionErr) {
			return nil, &NotFoundError{
				gameID:   r.GameID,
				clientID: r.ClientID,
			}
		}
		// Returned unwrapped so callers that type-assert awserr.Error keep working.
		return nil, err
	}

	var entry AllowlistEntry
	if err := dynamodbattribute.UnmarshalMap(output.Attributes, &entry); err != nil {
		return nil, fmt.Errorf("unmarshalling updated allowlist entry: %w", err)
	}

	return &entry, nil
}

// AttributeValuesAndUpdateExpression builds the ExpressionAttributeValues map
// and the UpdateExpression for a DynamoDB UpdateItem call from the non-empty
// fields of r.
//
// OrganizationID is written when it is non-empty. Products is written as a
// string set (SS) when it has at least one element; duplicate products are
// dropped (first occurrence kept) because DynamoDB rejects string sets that
// contain duplicates.
//
// When r has nothing to update, the returned map is empty and the expression
// is "" (a bare "SET " is not a valid update expression).
func (r *UpdateAllowlistEntryRequest) AttributeValuesAndUpdateExpression() (map[string]*dynamodb.AttributeValue, string) {
	attributeValues := map[string]*dynamodb.AttributeValue{}
	var statements []string

	if r.OrganizationID != "" {
		attributeValues[":organization_id"] = &dynamodb.AttributeValue{S: aws.String(r.OrganizationID)}
		statements = append(statements, "OrganizationID = :organization_id")
	}
	if len(r.Products) != 0 {
		// Collapse duplicates while preserving the caller's order.
		seen := make(map[string]struct{}, len(r.Products))
		products := make([]string, 0, len(r.Products))
		for _, p := range r.Products {
			if _, dup := seen[p]; dup {
				continue
			}
			seen[p] = struct{}{}
			products = append(products, p)
		}
		attributeValues[":products"] = &dynamodb.AttributeValue{SS: aws.StringSlice(products)}
		statements = append(statements, "Products = :products")
	}

	if len(statements) == 0 {
		return attributeValues, ""
	}
	return attributeValues, "SET " + strings.Join(statements, ", ")
}

// NotFoundError reports that no allowlist entry exists for a game ID /
// client ID pair. UpdateAllowlistEntry returns it when its existence
// condition fails.
type NotFoundError struct {
	gameID, clientID string
}

// Error implements the error interface; both key values are quoted.
func (nf *NotFoundError) Error() string {
	const format = "entry for game ID %q and client ID %q not found"
	return fmt.Sprintf(format, nf.gameID, nf.clientID)
}
