package source

import (
	"context"
	"fmt"
	"math"
	"net/http"
	"strconv"
	"strings"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
	"github.com/aws/aws-sdk-go/aws/endpoints"
	"github.com/aws/aws-sdk-go/aws/session"
	ddb "github.com/aws/aws-sdk-go/service/dynamodb"
	"github.com/cep21/circuit/v3"
	"golang.org/x/sync/errgroup"

	logging "code.justin.tv/amzn/TwitchLogging"

	"code.justin.tv/amzn/TwitchFeatureStoreClient/clients"
	"code.justin.tv/amzn/TwitchFeatureStoreClient/metadata"
	"code.justin.tv/amzn/TwitchFeatureStoreClient/types"
)

const (
	// batchSize is the maximum number of entity keys sent in one BatchGetItem call. BatchGetItem accepts up to
	// 100 items or 16MB per call and a DDB item can be up to 400KB, so 40 items keeps a call within the limit.
	batchSize = 40

	// schemaKey is the name of the primary key attribute of the items read from the DDB table.
	schemaKey = "ofs_id"
)

// IdentifiableFeatureSource is a FeatureSourceBulkAccess that can additionally report an identifier.
// This is a special type and used by the default DDB feature source.
//
//go:generate counterfeiter -o ../fakes/sourcefake/identifiable_feature_source.go . IdentifiableFeatureSource
type IdentifiableFeatureSource interface {
	FeatureSourceBulkAccess
	// GetIdentifier returns the identifier of the source as a string.
	GetIdentifier() string
}

// dynamoDBOnlineSource is a feature source that reads feature values from a single DynamoDB table using
// BatchGetItem.
type dynamoDBOnlineSource struct {
	// DynamoDBSource describes the backing table; its GetTable, GetIdentifier and GetReadRole methods are
	// used by this type.
	metadata.DynamoDBSource
	// logger is the logger supplied to CreateDynamoDBOnlineSource.
	logger logging.Logger
	// namespace is passed to types.NewFeatureKey when mapping returned attribute names back to feature keys.
	namespace string
	// client issues the BatchGetItem calls.
	client clients.DDBSubset
	// metadata holds, per feature key, the provider that declares the feature's value data shape and type.
	metadata map[types.FeatureKey]metadata.Provider
}

// Compile-time check that *dynamoDBOnlineSource implements IdentifiableFeatureSource.
var _ IdentifiableFeatureSource = (*dynamoDBOnlineSource)(nil)

// CreateDynamoDBOnlineSource builds a DynamoDB-backed online feature source.
//
// It creates an AWS session for the given region (using httpClient and the regional STS endpoint), derives
// credentials by assuming the read role reported by source, and constructs the DDB client with those
// credentials, the circuit manager and the circuit config. The client is labelled with the source identifier,
// suffixed with ":" and clientIdentifier when clientIdentifier is not empty.
//
// An error is returned only when the AWS session cannot be created.
//
// TODO what happens if caller does not have permission to assume the table read role?
func CreateDynamoDBOnlineSource(region string, source metadata.DynamoDBSource, namespace string,
	metadata map[types.FeatureKey]metadata.Provider, manager *circuit.Manager,
	config circuit.Config, httpClient *http.Client, clientIdentifier string, logger logging.Logger) (*dynamoDBOnlineSource, error) {

	awsConfig := &aws.Config{
		Region:              aws.String(region),
		STSRegionalEndpoint: endpoints.RegionalSTSEndpoint,
		HTTPClient:          httpClient,
	}

	sess, err := session.NewSession(awsConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to initiate AWS session for DynamoDB online source: %w", err)
	}

	// Reads go through the table's read role rather than the caller's own credentials.
	readRoleCreds := stscreds.NewCredentials(sess, source.GetReadRole())

	ddbIdentifier := source.GetIdentifier()
	if clientIdentifier != "" {
		ddbIdentifier = ddbIdentifier + ":" + clientIdentifier
	}

	return &dynamoDBOnlineSource{
		DynamoDBSource: source,
		logger:         logger,
		namespace:      namespace,
		client:         clients.NewDDBClient(sess, awsConfig.WithCredentials(readRoleCreds), manager, config, ddbIdentifier),
		metadata:       metadata,
	}, nil
}

// featureRecord groups the requested feature instances that share one entity key.
//
// entityStringKey is the shared entity key; it is sent to DynamoDB as the value of the schemaKey attribute.
// mapFeatureInstance holds the instances requested for that entity, indexed by their feature key.
type featureRecord struct {
	entityStringKey    string
	mapFeatureInstance map[types.FeatureKey]*types.FeatureInstance
}

// convertToFeatureRecords groups the given feature instances by entity key, producing one featureRecord per
// distinct key.
//
// The returned slice and map contain the same set of records but serve different purposes: the slice keeps
// the records in first-seen order so that it can be cut into batches, and the map allows lookup by entity
// key. We are trading space for time here. When two instances share both entity key and feature key, the
// later one replaces the earlier one in the record.
func convertToFeatureRecords(features []*types.FeatureInstance) ([]*featureRecord, map[string]*featureRecord) {
	var ordered []*featureRecord
	byEntity := make(map[string]*featureRecord)

	for _, instance := range features {
		entityKey := instance.GetEntityStringKey()
		rec, seen := byEntity[entityKey]
		if !seen {
			rec = &featureRecord{
				entityStringKey:    entityKey,
				mapFeatureInstance: make(map[types.FeatureKey]*types.FeatureInstance),
			}
			byEntity[entityKey] = rec
			ordered = append(ordered, rec)
		}
		rec.mapFeatureInstance[instance.FeatureKey] = instance
	}

	return ordered, byEntity
}

// BulkGet fetches the values of the given feature instances from the DynamoDB table of this source and stores
// every fetched value in the Value field of the matching instance.
//
// The instances are grouped by entity key, the distinct entity keys are cut into batches of at most batchSize,
// and each batch is fetched by its own goroutine through BatchGetItem. Responses are only applied after all
// calls have succeeded. Attributes that come back but were not requested for that entity are ignored, and
// instances whose item or attribute is absent from the response are left untouched.
//
// An error is returned when a BatchGetItem call fails, when a response reports unprocessed keys, when the
// table or the primary key is missing from a response, when a returned item does not belong to any requested
// entity, or when an attribute cannot be converted into its declared feature value type.
func (d *dynamoDBOnlineSource) BulkGet(ctx context.Context, features []*types.FeatureInstance) error {
	// By this time, all features live in one table.

	// Each record in the records slice has a unique entityKey
	sliceRecords, mapRecords := convertToFeatureRecords(features)
	total := len(sliceRecords)
	if total == 0 {
		// Nothing to fetch.
		return nil
	}

	g, ctx := errgroup.WithContext(ctx)
	// One buffered slot per batch: the responses are only drained after g.Wait, so a worker must never block
	// on its send.
	out := make(chan *ddb.BatchGetItemOutput, int(math.Ceil(float64(total)/float64(batchSize))))

	// BatchGetItem has limit of 16MB or 100 items per call.
	// https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_BatchGetItem.html
	// DDB limit per item size within 400KB. To simplify, we will limit the batch size to 40 items per call.
	for start := 0; start < total; start += batchSize {
		end := start + batchSize
		if end > total {
			end = total
		}
		// batch is declared inside the loop body so that every goroutine captures its own slice.
		batch := sliceRecords[start:end]
		g.Go(func() error {
			resp, err := d.client.BatchGetItemWithContext(ctx, d.buildBatchGetItemRequest(batch))
			if err != nil {
				return err
			}
			out <- resp
			return nil
		})
	}

	if err := g.Wait(); err != nil {
		return fmt.Errorf("received errors from BatchGetItemWithContext:%w", err)
	}
	close(out)

	table := d.GetTable()
	// TODO this can be processed in parallel as well
	for o := range out {
		// TODO should we fail when there are UnprocessedItems?
		if len(o.UnprocessedKeys) > 0 {
			// UnprocessedKeys is indexed by table name, so count the keys inside each entry rather than the
			// number of tables.
			unprocessed := 0
			for _, pending := range o.UnprocessedKeys {
				if pending != nil {
					unprocessed += len(pending.Keys)
				}
			}
			return fmt.Errorf("failed to get all features: there are %d unprocessed keys from %s",
				unprocessed, d.GetIdentifier())
		}

		items, ok := o.Responses[table]
		if !ok {
			return fmt.Errorf("unexpected error: could not find table %s from respose:%+v", d.GetIdentifier(), o.Responses)
		}

		for _, mapAttributes := range items {
			pk, ok := mapAttributes[schemaKey]
			if !ok || pk.S == nil || len(*pk.S) == 0 {
				return fmt.Errorf("unexpecter error: missing primary key from the ddb response: %s", d.GetIdentifier())
			}
			record, ok := mapRecords[*pk.S]
			if !ok {
				// Guard the lookup: a missing record would otherwise be dereferenced as nil below.
				return fmt.Errorf("unexpected error: the ddb response contains an item that was not requested: %s",
					d.GetIdentifier())
			}

			for name, value := range mapAttributes {
				if name == schemaKey {
					continue
				}
				// The rest of the attributes excluding schemaKey should be "{feature_id}@{version}"
				featureKey, err := types.NewFeatureKey(name, d.namespace)
				if err != nil {
					return fmt.Errorf("unexpected error: %w", err)
				}
				instance, ok := record.mapFeatureInstance[featureKey]
				if !ok {
					// Some features are fetched from OFS but user does not ask for it
					// This can happen because of this: https://sage.amazon.com/posts/1083124
					continue
				}
				instance.Value, err = d.convertToFeatureValue(featureKey, value)
				if err != nil {
					return err
				}
			}
		}
	}
	return nil
}

// assertNotNil returns an error naming expectedType and fk when a holds no value, and nil otherwise.
//
// Callers pass fields of ddb.AttributeValue, which are typed pointers, slices and maps. Once such a value is
// stored in an interface, comparing the interface with nil is false even when the pointer, slice or map
// itself is nil, so the dynamic type has to be inspected. All field types of ddb.AttributeValue are covered;
// a value of any other type is only reported when the interface itself is nil.
func assertNotNil(a interface{}, expectedType string, fk types.FeatureKey) error {
	missing := false
	switch val := a.(type) {
	case nil:
		missing = true
	case *string:
		missing = val == nil
	case *bool:
		missing = val == nil
	case []byte:
		missing = val == nil
	case [][]byte:
		missing = val == nil
	case []*string:
		missing = val == nil
	case []*ddb.AttributeValue:
		missing = val == nil
	case map[string]*ddb.AttributeValue:
		missing = val == nil
	}
	if missing {
		return fmt.Errorf("expected feature value type:%s, got nil for feature %+v", expectedType, fk)
	}
	return nil
}

// convertToFeatureValue converts a DynamoDB attribute value into the feature value type declared by the
// metadata registered for fk.
//
// Supported combinations of value data shape and value data type:
//   - SCALAR: STRING is read from S, INTEGER and FLOAT are parsed from N.
//   - LIST and VECTOR: STRING is read from the S of every element of L, INTEGER and FLOAT are parsed from
//     the N of every element of L.
//   - BLOB: the bytes in B are returned as they are, whatever the data type.
//
// An error is returned when no metadata is registered for fk, when v is nil, when the attribute field
// required by the declared type is not set, when a number cannot be parsed, or when the shape and type
// combination is not one of the above.
func (d *dynamoDBOnlineSource) convertToFeatureValue(fk types.FeatureKey, v *ddb.AttributeValue) (types.FeatureValue, error) {
	s, ok := d.metadata[fk]
	if !ok {
		return nil, fmt.Errorf("failed to found feature metadata for feature:%+v", fk)
	}
	if v == nil {
		return nil, fmt.Errorf("expected feature value type:%s, got nil for feature %+v", "attribute value", fk)
	}

	switch s.GetValueDataShape() {
	case types.SCALAR:
		switch s.GetValueDataType() {
		case types.STRING:
			if err := assertNotNil(v.S, "string", fk); err != nil {
				return nil, err
			}
			return &types.StringFeature{Val: *v.S}, nil
		case types.INTEGER:
			if err := assertNotNil(v.N, "number", fk); err != nil {
				return nil, err
			}
			i, err := strconv.ParseInt(*v.N, 0, 64)
			if err != nil {
				return nil, fmt.Errorf("failed to convert feature value of %+v as integer: %w", fk, err)
			}
			return &types.IntFeature{Val: i}, nil
		case types.FLOAT:
			if err := assertNotNil(v.N, "number", fk); err != nil {
				return nil, err
			}
			f, err := strconv.ParseFloat(*v.N, 64)
			if err != nil {
				return nil, fmt.Errorf("failed to convert feature value of %+v as float: %w", fk, err)
			}
			return &types.FloatFeature{Val: f}, nil
		}
	case types.LIST, types.VECTOR:
		if err := assertNotNil(v.L, "list", fk); err != nil {
			return nil, err
		}
		switch s.GetValueDataType() {
		case types.STRING:
			sSlice := make([]string, len(v.L))
			for i, a := range v.L {
				// Checked on the typed pointer directly: an element that is not a string has a nil S and
				// must not be dereferenced.
				if a == nil || a.S == nil {
					return nil, fmt.Errorf("expected feature value type:%s, got nil for feature %+v", "string list", fk)
				}
				sSlice[i] = *a.S
			}
			return &types.StringSliceFeature{Val: sSlice}, nil
		case types.INTEGER:
			iSlice := make([]int64, len(v.L))
			for i, a := range v.L {
				if err := assertNotNil(a.N, "number list", fk); err != nil {
					return nil, err
				}
				iVal, err := strconv.ParseInt(*a.N, 0, 64)
				if err != nil {
					// Keep the parse error in the chain so callers can see why the element was rejected.
					return nil, fmt.Errorf("failed to convert feature value of %+v as integer slice, error value: %s: %w", fk, *a, err)
				}
				iSlice[i] = iVal
			}
			return &types.IntSliceFeature{Val: iSlice}, nil
		case types.FLOAT:
			fSlice := make([]float64, len(v.L))
			for i, a := range v.L {
				if err := assertNotNil(a.N, "number list", fk); err != nil {
					return nil, err
				}
				fVal, err := strconv.ParseFloat(*a.N, 64)
				if err != nil {
					// Keep the parse error in the chain so callers can see why the element was rejected.
					return nil, fmt.Errorf("failed to convert feature value of %+v as float slice, error value: %s: %w", fk, *a, err)
				}
				fSlice[i] = fVal
			}
			return &types.FloatSliceFeature{Val: fSlice}, nil
		}
	case types.BLOB:
		return &types.BlobFeature{Val: v.B}, nil
	}

	// Reached when the shape is unknown or the data type is not supported for the shape.
	return nil, fmt.Errorf("undefined feature value of %+v data type:%d, data shape:%d", fk, s.GetValueDataType(), s.GetValueDataShape())
}

// buildBatchGetItemRequest builds the BatchGetItem input that reads the given records from the table of this
// source.
//
// One key is emitted per record (every record should have a unique EntityStringKey). The projection is the
// primary key attribute followed by the union of the attribute keys of all feature instances in all records:
// each feature instance may need a different feature id, but the BatchGetItem API only allows a single
// attribute mapping per table. Attribute names are referenced through "#f<n>" placeholders because "@" is
// considered an invalid token in a ProjectionExpression.
//
// TODO BatchGetItem can query multiple tables, should we group multiple tables in a single batch call and reduce the number of requests sent to DDB?
// TODO features are not deduped
func (d *dynamoDBOnlineSource) buildBatchGetItemRequest(records []*featureRecord) *ddb.BatchGetItemInput {
	keys := make([]map[string]*ddb.AttributeValue, 0, len(records))
	wanted := make(map[string]struct{})

	for _, rec := range records {
		keys = append(keys, map[string]*ddb.AttributeValue{
			schemaKey: {S: aws.String(rec.entityStringKey)},
		})
		for fk := range rec.mapFeatureInstance {
			wanted[fk.AttributeKey()] = struct{}{}
		}
	}

	// The PK leads the projection, otherwise there is no way to map returned items to PK.
	projection := make([]string, 0, len(wanted)+1)
	projection = append(projection, schemaKey)

	// "If any of the requested attributes are not found, they do not appear in the result."
	placeholders := make(map[string]*string)
	for attribute := range wanted {
		// The number of placeholders handed out so far is the next free index.
		placeholder := fmt.Sprintf("#f%d", len(placeholders))
		placeholders[placeholder] = aws.String(attribute)
		projection = append(projection, placeholder)
	}

	request := &ddb.KeysAndAttributes{
		Keys:                     keys,
		ExpressionAttributeNames: placeholders,
		ProjectionExpression:     aws.String(strings.Join(projection, ",")),
	}
	return &ddb.BatchGetItemInput{
		RequestItems: map[string]*ddb.KeysAndAttributes{d.GetTable(): request},
	}
}
