package adapters

import (
	"context"
	"fmt"
	"strconv"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/dynamodb"
	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"

	"code.justin.tv/cb/kinesis_processor/adapters/helper"
	"code.justin.tv/cb/kinesis_processor/models"
	"code.justin.tv/cb/kinesis_processor/utils"
	"github.com/aws/aws-sdk-go/aws/session"
	log "github.com/sirupsen/logrus"
)

// TableChannelSessions is the name of the DynamoDB table holding channel
// sessions. Items are keyed by ChannelID (partition key) and StartTime (sort
// key).
const TableChannelSessions = "CbChannelSessions"

// ChannelSessionAdapter reads and writes channel sessions stored in the
// CbChannelSessions DynamoDB table, where each session is keyed by ChannelID
// and StartTime.
type ChannelSessionAdapter interface {
	// BatchSave writes each session as a complete item keyed by its ChannelID
	// and StartTime, replacing any existing item with the same key.
	BatchSave(sessions []models.ChannelSession) error

	// Update writes obj's EndTime and broadcast IDs onto the stored session
	// keyed by obj.ChannelID and obj.StartTime.
	// Returns the storage error, if any.
	Update(ctx context.Context, obj models.ChannelSession) error

	// GetLast returns the most recent session (latest StartTime) of channelID,
	// or (nil, nil) when the channel has no sessions.
	GetLast(ctx context.Context, channelID int64) (*models.ChannelSession, error)

	// ExtendOrCreate folds a minute-broadcast event into the channel's
	// sessions. It extends the most recent session, or creates a new one when
	// the channel has none or the event is too long after the latest session
	// ended, and returns the session it wrote. It returns (nil, nil) without
	// writing anything when the event is earlier than the latest session's
	// StartTime.
	ExtendOrCreate(model models.MinuteBroadcast) (*models.ChannelSession, error)

	// GetLastList returns up to size of the most recent sessions of channelID,
	// newest first, or (nil, nil) when the channel has no sessions. Items that
	// cannot be decoded are skipped.
	GetLastList(ctx context.Context, channelID int64, size int64) ([]*models.ChannelSession, error)

	// GetSessionsByTimeForPromotions returns, ordered by StartTime, every
	// session of channelID that overlaps the window between startTime and
	// endTime: its StartTime is before endTime and its EndTime is after
	// startTime.
	GetSessionsByTimeForPromotions(ctx context.Context, channelID int64, startTime time.Time, endTime time.Time) ([]models.ChannelSession, error)
}

// channelSessionAdapter is the DynamoDB-backed implementation of
// ChannelSessionAdapter returned by NewChannelSessionAdapter.
//
// client is used for every read and write against the CbChannelSessions
// table. minuteAdapter and channelUpdateAdapter are populated by
// NewChannelSessionAdapter but are not referenced by any method visible in
// this file — NOTE(review): confirm whether they are still needed.
type channelSessionAdapter struct {
	client               dynamodbiface.DynamoDBAPI
	minuteAdapter        MinuteBroadcastAdapter
	channelUpdateAdapter ChannelUpdateAdapter
}

// GetLast returns the most recent session (the one with the greatest
// StartTime) stored for channelID.
//
// It returns (nil, nil) when the channel has no sessions, and the query or
// decoding error otherwise.
func (c *channelSessionAdapter) GetLast(ctx context.Context, channelID int64) (*models.ChannelSession, error) {
	input := &dynamodb.QueryInput{
		TableName:              aws.String(TableChannelSessions),
		KeyConditionExpression: aws.String("ChannelID = :channelID"),
		ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
			":channelID": {N: aws.String(strconv.FormatInt(channelID, 10))},
		},
		// Descending by StartTime, and only the first item: the newest session.
		ScanIndexForward: aws.Bool(false),
		Limit:            aws.Int64(1),
	}

	output, err := c.client.QueryWithContext(ctx, input)
	switch {
	case err != nil:
		return nil, err
	case len(output.Items) == 0:
		return nil, nil
	default:
		return c.buildModel(output.Items[0])
	}
}

// NewChannelSessionAdapter builds a DynamoDB-backed ChannelSessionAdapter for
// the given environment and AWS region. It also constructs the
// minute-broadcast and channel-update adapters stored on the result.
//
// session.New is deprecated in aws-sdk-go in favour of session.NewSession,
// which reports configuration errors eagerly; it is kept here because this
// constructor has no error return.
func NewChannelSessionAdapter(env string, region string) ChannelSessionAdapter {
	awsCreds := helper.NewCredentials(env, region)

	cfg := aws.NewConfig().
		WithS3ForcePathStyle(true).
		WithCredentials(awsCreds).
		WithRegion(region)

	adapter := &channelSessionAdapter{
		client: dynamodb.New(session.New(cfg)),
	}
	adapter.minuteAdapter = NewMinuteBroadcastAdapter(env, region)
	adapter.channelUpdateAdapter = NewChannelUpdateAdapter(env, region)

	return adapter
}

// GetLastList returns up to size of the most recent sessions stored for
// channelID, newest first.
//
// Items that cannot be decoded are logged and skipped rather than failing the
// whole call. It returns (nil, nil) when the channel has no sessions, and the
// query error if the read itself fails.
func (c *channelSessionAdapter) GetLastList(ctx context.Context, channelID int64, size int64) ([]*models.ChannelSession, error) {
	output, err := c.client.QueryWithContext(ctx, &dynamodb.QueryInput{
		TableName:              aws.String(TableChannelSessions),
		KeyConditionExpression: aws.String("ChannelID = :channelID"),
		ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
			":channelID": {N: aws.String(strconv.FormatInt(channelID, 10))},
		},
		// Descending StartTime so that Limit keeps the newest sessions.
		ScanIndexForward: aws.Bool(false),
		Limit:            aws.Int64(size),
	})
	if err != nil {
		return nil, err
	}

	if len(output.Items) == 0 {
		return nil, nil
	}

	list := make([]*models.ChannelSession, 0, len(output.Items))
	for _, item := range output.Items {
		model, err := c.buildModel(item)
		if err != nil {
			// Best effort: one corrupt item should not hide the others.
			log.WithFields(log.Fields{"channel_id": channelID, "item": item}).Error(err)
			continue
		}
		list = append(list, model)
	}

	return list, nil
}

// Update writes obj.EndTime and obj.BroadcastIDs onto the CbChannelSessions
// item keyed by obj.ChannelID and obj.StartTime, using ctx for the request.
//
// DynamoDB rejects empty number sets. When obj.BroadcastIDs is empty, only
// EndTime is written and any BroadcastIDs already stored are left untouched.
// Returns the error from DynamoDB, if any.
func (c *channelSessionAdapter) Update(ctx context.Context, obj models.ChannelSession) error {
	updateExpression := "SET EndTime = :EndTime"
	values := map[string]*dynamodb.AttributeValue{
		":EndTime": {S: aws.String(obj.EndTime.Format(utils.DbTimeFormat))},
	}
	if len(obj.BroadcastIDs) > 0 {
		updateExpression += ", BroadcastIDs = :BroadcastIDs"
		values[":BroadcastIDs"] = &dynamodb.AttributeValue{NS: utils.AwsArrayFromInt64List(obj.BroadcastIDs)}
	}

	_, err := c.client.UpdateItemWithContext(ctx, &dynamodb.UpdateItemInput{
		UpdateExpression:          aws.String(updateExpression),
		ExpressionAttributeValues: values,
		Key: map[string]*dynamodb.AttributeValue{
			"ChannelID": {N: aws.String(strconv.FormatInt(obj.ChannelID, 10))},
			"StartTime": {S: aws.String(obj.StartTime.Format(utils.DbTimeFormat))},
		},
		TableName: aws.String(TableChannelSessions),
	})

	return err
}

// BatchSave writes every session to CbChannelSessions as a complete item
// keyed by ChannelID and StartTime, replacing any existing item with the same
// key. BroadcastIDs is stored only when non-empty, because DynamoDB rejects
// empty number sets.
//
// BatchWriteItem accepts at most 25 write requests per call, so the sessions
// are written in chunks of that size. Items that DynamoDB reports as
// unprocessed (typically because of throttling) are resent with a short
// linear backoff. If some remain after the last attempt, an error is
// returned. A nil or empty slice is a no-op.
func (c *channelSessionAdapter) BatchSave(sessions []models.ChannelSession) error {
	// DynamoDB's per-call limit on BatchWriteItem write requests.
	const maxItemsPerBatch = 25

	if len(sessions) == 0 {
		return nil
	}

	requests := make([]*dynamodb.WriteRequest, 0, len(sessions))
	for _, s := range sessions {
		item := map[string]*dynamodb.AttributeValue{
			"ChannelID": {N: aws.String(strconv.FormatInt(s.ChannelID, 10))},
			"StartTime": {S: aws.String(s.StartTime.Format(utils.DbTimeFormat))},
			"EndTime":   {S: aws.String(s.EndTime.Format(utils.DbTimeFormat))},
		}
		if len(s.BroadcastIDs) > 0 {
			item["BroadcastIDs"] = &dynamodb.AttributeValue{NS: utils.AwsArrayFromInt64List(s.BroadcastIDs)}
		}
		requests = append(requests, &dynamodb.WriteRequest{
			PutRequest: &dynamodb.PutRequest{Item: item},
		})
	}

	for start := 0; start < len(requests); start += maxItemsPerBatch {
		end := start + maxItemsPerBatch
		if end > len(requests) {
			end = len(requests)
		}
		if err := c.batchWriteSessions(requests[start:end]); err != nil {
			return err
		}
	}

	return nil
}

// batchWriteSessions sends one BatchWriteItem call for requests (at most 25)
// and resends whatever DynamoDB returns as unprocessed, up to a fixed number
// of attempts.
func (c *channelSessionAdapter) batchWriteSessions(requests []*dynamodb.WriteRequest) error {
	const maxAttempts = 5

	pending := requests
	for attempt := 1; ; attempt++ {
		output, err := c.client.BatchWriteItem(&dynamodb.BatchWriteItemInput{
			RequestItems: map[string][]*dynamodb.WriteRequest{
				TableChannelSessions: pending,
			},
			ReturnConsumedCapacity: aws.String(dynamodb.ReturnConsumedCapacityTotal),
		})
		if err != nil {
			return err
		}
		if output == nil {
			// No output means nothing was reported as unprocessed.
			return nil
		}

		pending = output.UnprocessedItems[TableChannelSessions]
		if len(pending) == 0 {
			return nil
		}
		if attempt == maxAttempts {
			return fmt.Errorf("sessions: %d of %d batch writes still unprocessed after %d attempts", len(pending), len(requests), maxAttempts)
		}

		// Unprocessed items usually signal throttling; give the table a moment
		// before resending them.
		time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
	}
}

// ExtendOrCreate folds a MinuteBroadcast event into the channel's sessions and
// returns the session that was written.
//
// It reads the two most recent sessions for model.ChannelID and then:
//   - creates a new session starting and ending at model.Time when the channel
//     has none, or when the event is more than SessionMinuteGap minutes after
//     the latest session's EndTime;
//   - returns (nil, nil) without writing when the event is earlier than the
//     latest session's StartTime (an out-of-order event); diagnostics are
//     logged;
//   - otherwise extends the latest session: EndTime moves forward to the
//     event time (never backwards) and model.BroadcastID is added to
//     BroadcastIDs.
//
// Returns an error if a DynamoDB call or decoding the latest session fails.
func (c *channelSessionAdapter) ExtendOrCreate(model models.MinuteBroadcast) (*models.ChannelSession, error) {
	// The newest session is the one to extend; the one before it is only used
	// to diagnose out-of-order events.
	output, err := c.client.Query(&dynamodb.QueryInput{
		TableName:              aws.String(TableChannelSessions),
		KeyConditionExpression: aws.String("ChannelID = :channelID"),
		ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
			":channelID": {N: aws.String(strconv.FormatInt(model.ChannelID, 10))},
		},
		ScanIndexForward: aws.Bool(false),
		Limit:            aws.Int64(2),
	})
	if err != nil {
		return nil, err
	}

	// First session for this channel.
	if len(output.Items) == 0 {
		return c.startChannelSession(model)
	}

	lastSession, err := c.buildModel(output.Items[0])
	if err != nil {
		return nil, err
	}

	// Too long since the latest session ended: start a new one.
	if model.Time.Sub(lastSession.EndTime) > SessionMinuteGap*time.Minute {
		return c.startChannelSession(model)
	}

	// The event predates the latest session, so it arrived out of order. It is
	// logged and dropped.
	if model.Time.Before(lastSession.StartTime) {
		c.logSessionRace(model, lastSession, output.Items)
		return nil, nil
	}

	// A late event that falls inside the session must not pull EndTime back.
	if model.Time.After(lastSession.EndTime) {
		lastSession.EndTime = model.Time
	}
	lastSession.BroadcastIDs = utils.AppendUniqueInt64(lastSession.BroadcastIDs, model.BroadcastID)

	if err = c.updateChannelSession(*lastSession); err != nil {
		return nil, err
	}

	return lastSession, nil
}

// startChannelSession stores a new session for model.ChannelID that starts and
// ends at model.Time with model.BroadcastID as its only broadcast, and returns
// it.
func (c *channelSessionAdapter) startChannelSession(model models.MinuteBroadcast) (*models.ChannelSession, error) {
	newSession := &models.ChannelSession{
		ChannelID:    model.ChannelID,
		StartTime:    model.Time,
		EndTime:      model.Time,
		BroadcastIDs: []int64{model.BroadcastID},
	}
	if err := c.addChannelSession(*newSession); err != nil {
		return nil, err
	}

	return newSession, nil
}

// logSessionRace logs diagnostics for an event older than the latest session's
// start. items are the query results, newest first. When a previous session
// is present, the gaps between it, the event and the latest session are
// logged as well.
func (c *channelSessionAdapter) logSessionRace(model models.MinuteBroadcast, lastSession *models.ChannelSession, items []map[string]*dynamodb.AttributeValue) {
	log.Info("sessions: possible session race condition")
	if len(items) < 2 {
		return
	}

	previous, err := c.buildModel(items[1])
	if err != nil {
		log.WithField("channel_id", model.ChannelID).WithError(err).Warn("sessions: could not decode previous session")
		return
	}

	totalGap := lastSession.StartTime.Sub(previous.EndTime)
	prevGap := model.Time.Sub(previous.EndTime)
	currentGap := lastSession.StartTime.Sub(model.Time)

	log.Info(fmt.Sprintf("previous end: %s || current start: %s || event time: %s", previous.EndTime.Format(time.RFC3339), lastSession.StartTime.Format(time.RFC3339), model.Time.Format(time.RFC3339)))
	log.Info(fmt.Sprintf("prev/current gap: %v || prev/event gap: %v || current/event gap: %v", totalGap, prevGap, currentGap))
	log.Info(fmt.Sprintf("channel id: %d", model.ChannelID))
}

// buildModel decodes one CbChannelSessions item into a ChannelSession.
//
// ChannelID (number), StartTime (string) and EndTime (string) are required.
// A missing attribute, or one stored with the wrong type, is reported as an
// error rather than causing a nil-pointer panic. StartTime and EndTime are
// parsed with utils.DbTimeFormat. BroadcastIDs (number set) is optional and
// decodes to an empty, non-nil slice when the attribute is absent.
func (c *channelSessionAdapter) buildModel(value map[string]*dynamodb.AttributeValue) (*models.ChannelSession, error) {
	// parseNumber reads a required N attribute as an int64.
	parseNumber := func(name string) (int64, error) {
		attr := value[name]
		if attr == nil || attr.N == nil {
			return 0, fmt.Errorf("sessions: item has no number attribute %q", name)
		}
		return strconv.ParseInt(*attr.N, 10, 64)
	}

	// parseTime reads a required S attribute formatted with utils.DbTimeFormat.
	parseTime := func(name string) (time.Time, error) {
		attr := value[name]
		if attr == nil || attr.S == nil {
			return time.Time{}, fmt.Errorf("sessions: item has no string attribute %q", name)
		}
		return time.Parse(utils.DbTimeFormat, *attr.S)
	}

	channelID, err := parseNumber("ChannelID")
	if err != nil {
		return nil, err
	}

	startTime, err := parseTime("StartTime")
	if err != nil {
		return nil, err
	}

	endTime, err := parseTime("EndTime")
	if err != nil {
		return nil, err
	}

	broadcastIDs := []int64{}
	if ids := value["BroadcastIDs"]; ids != nil {
		broadcastIDs = utils.Int64ListFromAwsArray(ids.NS)
	}

	return &models.ChannelSession{
		ChannelID:    channelID,
		StartTime:    startTime,
		EndTime:      endTime,
		BroadcastIDs: broadcastIDs,
	}, nil
}

// addChannelSession stores obj as a complete item in CbChannelSessions,
// replacing any item with the same ChannelID/StartTime key.
//
// BroadcastIDs is always sent as a number set. NOTE(review): DynamoDB rejects
// empty sets, so callers are expected to pass a non-empty obj.BroadcastIDs.
func (c *channelSessionAdapter) addChannelSession(obj models.ChannelSession) error {
	channelID := strconv.FormatInt(obj.ChannelID, 10)
	start := obj.StartTime.Format(utils.DbTimeFormat)
	end := obj.EndTime.Format(utils.DbTimeFormat)

	item := make(map[string]*dynamodb.AttributeValue, 4)
	item["ChannelID"] = &dynamodb.AttributeValue{N: aws.String(channelID)}
	item["StartTime"] = &dynamodb.AttributeValue{S: aws.String(start)}
	item["EndTime"] = &dynamodb.AttributeValue{S: aws.String(end)}
	item["BroadcastIDs"] = &dynamodb.AttributeValue{NS: utils.AwsArrayFromInt64List(obj.BroadcastIDs)}

	input := &dynamodb.PutItemInput{
		TableName: aws.String(TableChannelSessions),
		Item:      item,
	}

	if _, err := c.client.PutItem(input); err != nil {
		return err
	}
	return nil
}

// updateChannelSession writes obj.EndTime and obj.BroadcastIDs onto the
// CbChannelSessions item keyed by obj.ChannelID and obj.StartTime. Unlike
// Update, it takes no context.
func (c *channelSessionAdapter) updateChannelSession(obj models.ChannelSession) error {
	key := map[string]*dynamodb.AttributeValue{
		"ChannelID": {N: aws.String(strconv.FormatInt(obj.ChannelID, 10))},
		"StartTime": {S: aws.String(obj.StartTime.Format(utils.DbTimeFormat))},
	}
	values := map[string]*dynamodb.AttributeValue{
		":EndTime":      {S: aws.String(obj.EndTime.Format(utils.DbTimeFormat))},
		":BroadcastIDs": {NS: utils.AwsArrayFromInt64List(obj.BroadcastIDs)},
	}

	input := &dynamodb.UpdateItemInput{
		TableName:                 aws.String(TableChannelSessions),
		Key:                       key,
		UpdateExpression:          aws.String("SET EndTime = :EndTime, BroadcastIDs = :BroadcastIDs"),
		ExpressionAttributeValues: values,
	}

	if _, err := c.client.UpdateItem(input); err != nil {
		return err
	}
	return nil
}

// GetSessionsByTimeForPromotions returns every session of channelID that
// overlaps the promotion window from startTime to endTime, ordered by
// ascending StartTime. It is intended for promotions, which need overlapping
// sessions rather than sessions fully contained in the window.
//
// The overlap test is strict on both ends. A session matches when its
// StartTime is before endTime (a key condition on the sort key) and its
// EndTime is after startTime (a filter expression). Both bounds are compared
// as strings formatted with utils.DbTimeFormat. NOTE(review): this relies on
// that format sorting lexicographically in time order.
//
// All result pages are read. An empty, non-nil slice is returned when nothing
// matches; any query or decoding error aborts the call.
func (c *channelSessionAdapter) GetSessionsByTimeForPromotions(ctx context.Context, channelID int64, startTime time.Time, endTime time.Time) ([]models.ChannelSession, error) {
	base := dynamodb.QueryInput{
		TableName:              aws.String(TableChannelSessions),
		KeyConditionExpression: aws.String("ChannelID = :channelID AND #StartTime < :endTime"),
		FilterExpression:       aws.String("#EndTime > :startTime"),
		ExpressionAttributeNames: map[string]*string{
			"#StartTime": aws.String("StartTime"),
			"#EndTime":   aws.String("EndTime"),
		},
		ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
			":channelID": {N: aws.String(strconv.FormatInt(channelID, 10))},
			":endTime":   {S: aws.String(endTime.Format(utils.DbTimeFormat))},
			":startTime": {S: aws.String(startTime.Format(utils.DbTimeFormat))},
		},
		ScanIndexForward: aws.Bool(true),
	}

	found := []models.ChannelSession{}
	var startKey map[string]*dynamodb.AttributeValue
	for {
		// Each page gets its own input value; only the start key differs.
		input := base
		input.ExclusiveStartKey = startKey

		page, err := c.client.QueryWithContext(ctx, &input)
		if err != nil {
			return nil, err
		}

		for _, item := range page.Items {
			s, err := c.buildModel(item)
			if err != nil {
				return nil, err
			}
			found = append(found, *s)
		}

		// A nil LastEvaluatedKey marks the final page.
		if page.LastEvaluatedKey == nil {
			return found, nil
		}
		startKey = page.LastEvaluatedKey
	}
}
