package dynamocursor

import (
	"context"
	"encoding/json"
	"sync"

	"time"

	"code.justin.tv/hygienic/errors"
	"github.com/aws/aws-sdk-go/service/dynamodb"
	"github.com/cep21/circuit"
)

// Logger accepts errors while the factory's Start method is running: Start reports every failed
// background Refresh through it. keyvals appears to be alternating key/value pairs, matching
// the "err", err pairing used at the Start call site.
type Logger interface {
	Log(keyvals ...interface{})
}

// Factory can create cursors for DynamoDB tables. It caches key-schema information, loaded
// with DescribeTable, for the tables returned by TableNames. It uses that cache to convert
// between items, Cursor values, and ExclusiveStartKey maps.
//
// Configuration fields:
//   - Client issues the DescribeTable requests made by Refresh.
//   - TableNames returns the tables to load on each Refresh; when nil, Refresh does nothing.
//   - Circuit wraps every DescribeTable request.
//   - Log optionally receives errors from background refreshes in Start; nil disables logging.
//
// Factory contains a sync.RWMutex and sync.Once, so it must not be copied after first use.
type Factory struct {
	Client     *dynamodb.DynamoDB
	TableNames func() []string
	Circuit    *circuit.Circuit
	Log        Logger

	// RefreshDuration is the wait between refreshes in Start; zero selects a five-minute default.
	RefreshDuration time.Duration

	// mu guards loadedTableInfo, which Refresh replaces wholesale.
	mu sync.RWMutex
	// keyed on table name
	loadedTableInfo map[string]*dynamodbTableInfo

	// once lazily creates onClose (from Setup, Start, or Close); Close closes onClose to stop Start.
	once    sync.Once
	onClose chan struct{}
}

// dynamodbTableInfo is the cached key information for one table, built from a DescribeTable
// response. It is also what MarshalJSON encodes for each table.
type dynamodbTableInfo struct {
	// TableDescription is the raw DescribeTable result; TableKeys is the table's own primary key.
	TableDescription *dynamodb.TableDescription
	TableKeys        PrimaryKey
	// keyed on index name; holds both global and local secondary indexes
	Indexes map[string]PrimaryKey
}

// PrimaryKey controls how items are identified as unique in dynamodb. It converts between an
// item's key attributes and the string key values carried in a Cursor.
type PrimaryKey interface {
	// ExtractKeys reads the partition and sort key values from item, returned as strings in
	// that order. Implementations without a sort key return "" for the sort value.
	ExtractKeys(item map[string]*dynamodb.AttributeValue) (string, string, error)
	// ExtractAttributeValues goes the other way: it builds a key attribute map, keyed by
	// attribute name, from string key values. Implementations without a sort key ignore sortValue.
	ExtractAttributeValues(partitionValue string, sortValue string) map[string]*dynamodb.AttributeValue
}

// compositePrimaryKey is a PrimaryKey made of a partition (HASH) key and a sort (RANGE) key.
type compositePrimaryKey struct {
	Partition Key
	Sort      Key
}

// ExtractKeys returns the string forms of item's partition and sort key values, in that order.
// If either attribute cannot be extracted, both strings are empty and the error from the
// corresponding Key is returned unchanged.
func (c compositePrimaryKey) ExtractKeys(item map[string]*dynamodb.AttributeValue) (string, string, error) {
	partitionVal, partitionErr := c.Partition.ExtractKey(item)
	if partitionErr != nil {
		return "", "", partitionErr
	}
	sortVal, sortErr := c.Sort.ExtractKey(item)
	if sortErr != nil {
		return "", "", sortErr
	}
	return partitionVal, sortVal, nil
}

// ExtractAttributeValues builds a two-entry attribute map: the partition attribute set from
// partitionValue and the sort attribute set from sortValue.
func (c compositePrimaryKey) ExtractAttributeValues(partitionValue string, sortValue string) map[string]*dynamodb.AttributeValue {
	attrs := make(map[string]*dynamodb.AttributeValue, 2)
	attrs[c.Partition.AttrName()] = c.Partition.AttrValue(partitionValue)
	attrs[c.Sort.AttrName()] = c.Sort.AttrValue(sortValue)
	return attrs
}

// simplePrimaryKey is a PrimaryKey made of a partition (HASH) key only, with no sort key.
type simplePrimaryKey struct {
	Partition Key
}

// ExtractKeys returns the string form of item's partition key value and an empty sort value,
// since a simple key has no sort attribute. Extraction errors are returned unchanged.
func (c simplePrimaryKey) ExtractKeys(item map[string]*dynamodb.AttributeValue) (string, string, error) {
	var partitionVal string
	var err error
	if partitionVal, err = c.Partition.ExtractKey(item); err != nil {
		return "", "", err
	}
	return partitionVal, "", nil
}

// ExtractAttributeValues builds a one-entry attribute map holding the partition attribute set
// from partitionValue. sortValue is accepted only to satisfy PrimaryKey and is ignored.
func (c simplePrimaryKey) ExtractAttributeValues(partitionValue string, sortValue string) map[string]*dynamodb.AttributeValue {
	attrs := make(map[string]*dynamodb.AttributeValue, 1)
	attrs[c.Partition.AttrName()] = c.Partition.AttrValue(partitionValue)
	return attrs
}

// KeyType identifies where a key schema comes from: the table itself, a global secondary
// index, or a local secondary index.
type KeyType string

// Key schema sources passed to extractKeysFromTable.
const (
	// TableKey marks the key schema of the table itself.
	TableKey = KeyType("table")
	// GsiKey marks the key schema of a global secondary index.
	GsiKey = KeyType("gsi")
	// LsiKey marks the key schema of a local secondary index.
	LsiKey = KeyType("lsi")
)

// Compile-time check that *Factory implements json.Marshaler.
var _ json.Marshaler = (*Factory)(nil)

// MarshalJSON encodes the factory as JSON: the cached per-table info keyed by table name. A
// factory that has never been refreshed encodes as null. The read lock is held while encoding
// so a concurrent Refresh cannot swap the map mid-encode.
func (f *Factory) MarshalJSON() ([]byte, error) {
	f.mu.RLock()
	defer f.mu.RUnlock()
	tables := f.loadedTableInfo
	return json.Marshal(tables)
}

// refreshDuration returns how long Start waits between refreshes: RefreshDuration when it is
// positive, otherwise a five-minute default. Negative values also get the default. A negative
// wait would fire immediately and turn Start into a tight DescribeTable loop.
func (f *Factory) refreshDuration() time.Duration {
	if f.RefreshDuration <= 0 {
		return 5 * time.Minute
	}
	return f.RefreshDuration
}

// ExclusiveStartKey generates an ExclusiveStartKey parameter for a query or scan from the
// passed-in cursor.
//
// A zero cursor yields a nil map, meaning no start key. The map always contains the table's
// key attributes. When indexName is non-empty, it also contains that index's key attributes.
// It returns an error if the table or the index was not loaded by the last successful Refresh.
func (f *Factory) ExclusiveStartKey(cursor Cursor, tableName string, indexName string) (map[string]*dynamodb.AttributeValue, error) {
	if cursor.IsZero() {
		return nil, nil
	}
	// Refresh replaces loadedTableInfo under f.mu, so reading it requires the read lock.
	f.mu.RLock()
	defer f.mu.RUnlock()
	tableInfo, ok := f.loadedTableInfo[tableName]
	if !ok {
		return nil, errors.Errorf("unable to find cursor for table %s", tableName)
	}
	startKey := tableInfo.TableKeys.ExtractAttributeValues(cursor.TablePartitionValue, cursor.TableSortValue)
	if indexName == "" {
		return startKey, nil
	}
	indexKeys, ok := tableInfo.Indexes[indexName]
	if !ok {
		return nil, errors.Errorf("unable to find index %s for table %s", indexName, tableName)
	}
	for k, v := range indexKeys.ExtractAttributeValues(cursor.IndexPartitionValue, cursor.IndexSortValue) {
		startKey[k] = v
	}
	return startKey, nil
}

// Cursor generates a cursor for a table from an item's returned values. It can optionally
// use an index, if indexName is not empty.
//
// The table's key values are always filled in. When indexName is set, the index's key values
// are filled in too. A nil item yields the zero Cursor and no error. Whenever an error is
// returned, the Cursor is the zero value.
func (f *Factory) Cursor(item map[string]*dynamodb.AttributeValue, tableName string, indexName string) (Cursor, error) {
	if item == nil {
		return Cursor{}, nil
	}
	f.mu.RLock()
	defer f.mu.RUnlock()
	tableInfo, ok := f.loadedTableInfo[tableName]
	if !ok {
		return Cursor{}, errors.Errorf("unable to find cursor for table %s", tableName)
	}
	var ret Cursor
	var err error
	ret.TablePartitionValue, ret.TableSortValue, err = tableInfo.TableKeys.ExtractKeys(item)
	if err != nil {
		return Cursor{}, err
	}
	if indexName == "" {
		return ret, nil
	}
	// Reuse tableInfo instead of looking the table up in loadedTableInfo a second time.
	indexInfo, ok := tableInfo.Indexes[indexName]
	if !ok {
		return Cursor{}, errors.Errorf("unable to find hash key for index %s", indexName)
	}
	ret.IndexPartitionValue, ret.IndexSortValue, err = indexInfo.ExtractKeys(item)
	if err != nil {
		return Cursor{}, err
	}
	return ret, nil
}

// Setup initially populates the factory with one synchronous Refresh and returns its error.
// It also makes sure the close channel exists, so that Start and Close work afterwards.
func (f *Factory) Setup() error {
	initOnClose := func() {
		f.onClose = make(chan struct{})
	}
	f.once.Do(initOnClose)
	ctx := context.Background()
	return f.Refresh(ctx)
}

// Start is intended to be run in a background goroutine. It calls Refresh, waiting
// RefreshDuration after each pass (see refreshDuration for the default), until Close is
// called, and then returns nil.
//
// A failed Refresh does not stop the loop. The error is reported to Log when Log is non-nil.
func (f *Factory) Start() error {
	f.once.Do(func() {
		f.onClose = make(chan struct{})
	})
	// Reuse one timer across iterations instead of allocating a new one with time.After each
	// pass. The wait still begins after the previous Refresh finishes.
	timer := time.NewTimer(f.refreshDuration())
	defer timer.Stop()
	for {
		select {
		case <-f.onClose:
			return nil
		case <-timer.C:
			if err := f.Refresh(context.Background()); err != nil && f.Log != nil {
				// Log takes key/value pairs; give the message its own key so the pairs stay balanced.
				f.Log.Log("err", err, "msg", "unable to refresh table indexes")
			}
			// The timer has fired and its channel was drained above, so Reset is safe here.
			timer.Reset(f.refreshDuration())
		}
	}
}

// Close will end the Start loop and stop any new ones: a Start called after Close returns nil
// immediately. Close is idempotent. Later calls return nil instead of panicking with
// "close of closed channel".
func (f *Factory) Close() error {
	f.once.Do(func() {
		f.onClose = make(chan struct{})
	})
	// Hold the lock so the already-closed check and the close happen as one step across
	// concurrent Close calls.
	f.mu.Lock()
	defer f.mu.Unlock()
	select {
	case <-f.onClose:
		// Already closed by an earlier Close.
	default:
		close(f.onClose)
	}
	return nil
}

// Refresh the table information from Tables()
func (f *Factory) Refresh(ctx context.Context) error {
	if f.TableNames == nil {
		return nil
	}
	tableNames := f.TableNames()
	loadedTableInfo := make(map[string]*dynamodbTableInfo, len(tableNames))
	for _, tableName := range tableNames {
		var err error
		loadedTableInfo[tableName], err = f.dynamodbTableInfo(ctx, tableName)
		if err != nil {
			return errors.Wrapf(err, "Unable to load table info: %s", tableName)
		}
	}
	f.mu.Lock()
	defer f.mu.Unlock()
	f.loadedTableInfo = loadedTableInfo
	return nil
}

// dynamodbTableInfo describes tableName via DescribeTable, run through f.Circuit with the
// circuit's context. It then derives the table's primary key and the keys of all its
// secondary indexes. Errors from the request or from key extraction are returned as-is.
func (f *Factory) dynamodbTableInfo(ctx context.Context, tableName string) (*dynamodbTableInfo, error) {
	req, out := f.Client.DescribeTableRequest(&dynamodb.DescribeTableInput{
		TableName: &tableName,
	})
	send := func(runCtx context.Context) error {
		req.SetContext(runCtx)
		return req.Send()
	}
	if err := f.Circuit.Run(ctx, send); err != nil {
		return nil, err
	}

	table := out.Table
	tableKeys, err := extractKeysFromTable(table.KeySchema, table.AttributeDefinitions, TableKey)
	if err != nil {
		return nil, err
	}
	indexKeys, err := extractKeysFromIndex(table.GlobalSecondaryIndexes, table.LocalSecondaryIndexes, table.AttributeDefinitions)
	if err != nil {
		return nil, err
	}
	return &dynamodbTableInfo{
		TableDescription: table,
		TableKeys:        tableKeys,
		Indexes:          indexKeys,
	}, nil
}

// extractKeysFromTable builds the PrimaryKey described by a key schema. keyType says whether
// the schema belongs to a table, a GSI, or an LSI. Each key attribute's type is resolved from
// defs.
//
// A HASH element is required. With no RANGE element the result is a simplePrimaryKey. That is
// valid for tables and for global secondary indexes, which may have a partition-only key. It
// is an error for local secondary indexes, which always pair the table's partition key with
// their own sort key.
func extractKeysFromTable(ks []*dynamodb.KeySchemaElement, defs []*dynamodb.AttributeDefinition, keyType KeyType) (PrimaryKey, error) {
	var partition Key
	var sort Key
	for _, k := range ks {
		attrType, err := extractAttributeType(*k.AttributeName, defs)
		if err != nil {
			return nil, err
		}
		dk, err := newDynamodbKey(attrType, *k.AttributeName)
		if err != nil {
			return nil, err
		}
		switch *k.KeyType {
		case "HASH":
			partition = dk
		case "RANGE":
			sort = dk
		}
	}
	if partition == nil {
		// Every table and index key schema must include a partition (HASH) key.
		return nil, errors.New("unable to create keys for table")
	}
	if sort == nil {
		// Only LSIs are required to have a sort key; partition-only GSIs are legal and must not
		// make the whole table fail to load.
		if keyType == LsiKey {
			return nil, errors.New("an LSI requires a range key")
		}
		return simplePrimaryKey{
			Partition: partition,
		}, nil
	}
	return compositePrimaryKey{
		Partition: partition,
		Sort:      sort,
	}, nil
}

// extractAttributeType returns the AttributeType declared in defs for the attribute called
// name. It returns an error if no definition has that name.
func extractAttributeType(name string, defs []*dynamodb.AttributeDefinition) (string, error) {
	for i := range defs {
		if d := defs[i]; *d.AttributeName == name {
			return *d.AttributeType, nil
		}
	}
	return "", errors.Errorf("unable to find attribute %s", name)
}

// extractKeysFromIndex builds a PrimaryKey for every global and local secondary index, keyed
// by index name. Global indexes are processed first, then local ones. The first index whose key
// schema cannot be resolved aborts the call with that error and a nil map.
func extractKeysFromIndex(gs []*dynamodb.GlobalSecondaryIndexDescription, ls []*dynamodb.LocalSecondaryIndexDescription, defs []*dynamodb.AttributeDefinition) (map[string]PrimaryKey, error) {
	byName := make(map[string]PrimaryKey, len(gs)+len(ls))
	// record resolves one index's key schema and stores the result under name.
	record := func(name string, schema []*dynamodb.KeySchemaElement, kind KeyType) error {
		pk, err := extractKeysFromTable(schema, defs, kind)
		if err != nil {
			return err
		}
		byName[name] = pk
		return nil
	}

	for _, g := range gs {
		if err := record(*g.IndexName, g.KeySchema, GsiKey); err != nil {
			return nil, err
		}
	}
	for _, l := range ls {
		if err := record(*l.IndexName, l.KeySchema, LsiKey); err != nil {
			return nil, err
		}
	}
	return byName, nil
}
