package zephyr

import (
	"context"
	"strconv"
	"sync/atomic"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/dynamodb"
	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
	"github.com/pkg/errors"
	"golang.org/x/sync/errgroup"
)

// DynamoDB table names read by the Client methods in this file.
const (
	// sessionTable is queried by channel_id and segment_start_time.
	sessionTable = "Sessions"

	// Per-session breakdown tables, each queried by segment_id.
	videoPlayGeographicalTable = "VideoPlayGeoPerSession"
	videoPlayPlatformTable     = "VideoPlayPlatformPerSession"
	videoPlayReferrerTable     = "VideoPlayReferrerPerSession"
	videoPlayUniqueViewsTable  = "VideoPlayUniquePerSession"
)

// GeoMap represents a channel's geographic view breakdown for a single
// session. GetVideoGeoMap returns one GeoMap per session it is given.
type GeoMap struct {
	// Time is the session's segment start time (copied from the session's
	// SegmentStartTime.Converted value).
	Time      DynamoTimestamp
	// ChannelID is the session's ChannelID string parsed as a base-10 int64.
	ChannelID int64
	// Geo is the CountryBreakdown of the first VideoPlayGeoPerSession item
	// found for the session's segment ID; nil when no item (or no breakdown)
	// was found. Presumably country -> view count — TODO confirm schema.
	Geo       map[string]int64
}

// PlatformMap represents a channel's platform view breakdown for a single
// session. GetVideoPlatformsMap returns one PlatformMap per session it is given.
type PlatformMap struct {
	// Time is the session's segment start time (copied from the session's
	// SegmentStartTime.Converted value).
	Time      DynamoTimestamp
	// ChannelID is the session's ChannelID string parsed as a base-10 int64.
	ChannelID int64
	// Platform is the result of consolidatePlatforms applied to the
	// PlatformBreakdown of the first VideoPlayPlatformPerSession item found for
	// the session's segment ID. Presumably platform -> view count — TODO confirm.
	Platform  map[string]int64
}

// ReferralMap represents a channel's referral view breakdown for a single
// session. GetVideoReferralMap returns one ReferralMap per session it is given.
type ReferralMap struct {
	// Time is the session's segment start time (copied from the session's
	// SegmentStartTime.Converted value).
	Time      DynamoTimestamp
	// ChannelID is the session's ChannelID string parsed as a base-10 int64.
	ChannelID int64
	// Internal and External are the InternalBreakdown and ExternalBreakdown of
	// the first VideoPlayReferrerPerSession item found for the session's
	// segment ID; nil when no item was found. Presumably referrer -> view
	// count — TODO confirm schema.
	Internal  map[string]int64
	External  map[string]int64
}

// GetSessionsByTime returns the sessions for channelID that overlap the range
// from startTime to endTime. Sessions are ordered by ascending
// segment_start_time and are gathered across every page of the DynamoDB query.
//
// The key condition matches sessions whose segment_start_time lies between two
// days before startTime's day and endTime's day (both truncated with
// Truncate(day)). This means sessions that began shortly before the range are
// still considered. The filter then keeps only sessions whose segment_end_time
// is at or after startTime, i.e. sessions that had not already ended when the
// range began.
//
// NOTE(review): endTime is truncated to the start of its day, so sessions
// starting later on endTime's day may be excluded, depending on the precision
// of dynamoDBTimeFormat — confirm this is intended.
func (c *Client) GetSessionsByTime(ctx context.Context, channelID int64, startTime time.Time, endTime time.Time) ([]DynamoSession, error) {
	table := sessionTable

	var exclusiveStartKey map[string]*dynamodb.AttributeValue
	keyCondition := aws.String("channel_id = :channelID AND #ST BETWEEN :startTime AND :endTime")
	// Compare end times against the caller's real start time. Reusing the
	// two-day-widened :startTime here made the filter accept every session
	// whose end time is not before its own start time. That let through
	// sessions that finished before the requested range began.
	filterExpression := aws.String("#ET >= :rangeStart")
	attributeValues := map[string]*dynamodb.AttributeValue{
		":channelID": {
			S: aws.String(strconv.FormatInt(channelID, 10)),
		},
		// Look back two days so that sessions which started before startTime's
		// day but were still running are matched by the key condition.
		// NOTE(review): assumes no session spans much more than two days.
		":startTime": {
			S: aws.String(startTime.Truncate(day).AddDate(0, 0, -2).Format(dynamoDBTimeFormat)),
		},
		":endTime": {
			S: aws.String(endTime.Truncate(day).Format(dynamoDBTimeFormat)),
		},
		":rangeStart": {
			S: aws.String(startTime.Format(dynamoDBTimeFormat)),
		},
	}
	attributeNames := map[string]*string{
		"#ST": aws.String("segment_start_time"),
		"#ET": aws.String("segment_end_time"),
	}

	// Non-nil so callers always receive an empty slice rather than nil.
	total := []DynamoSession{}
	for {
		var page []DynamoSession
		output, err := c.dynamo.QueryWithContext(ctx, &dynamodb.QueryInput{
			TableName:                 aws.String(table),
			ScanIndexForward:          aws.Bool(true),
			KeyConditionExpression:    keyCondition,
			FilterExpression:          filterExpression,
			ExpressionAttributeValues: attributeValues,
			ExpressionAttributeNames:  attributeNames,
			ExclusiveStartKey:         exclusiveStartKey,
		})
		if err != nil {
			return nil, errors.Wrapf(err, "failed to query %s", table)
		}

		if err := dynamodbattribute.UnmarshalListOfMaps(output.Items, &page); err != nil {
			return nil, errors.Wrap(err, "failed to unmarshal list of sessions")
		}
		total = append(total, page...)

		// A nil LastEvaluatedKey means DynamoDB has returned the final page.
		if output.LastEvaluatedKey == nil {
			break
		}
		exclusiveStartKey = output.LastEvaluatedKey
	}

	return total, nil
}

// GetVideoPlatformsMap returns one PlatformMap per session, in the same order
// as sessions. Each entry carries three things:
//   - the session's parsed channel ID,
//   - its segment start time, and
//   - consolidatePlatforms applied to the PlatformBreakdown of the first
//     VideoPlayPlatformPerSession item returned for the session's segment ID.
//
// A session with no stored item gets consolidatePlatforms applied to a
// zero-value breakdown.
//
// Every channel ID is parsed before any query is issued. A malformed ID
// therefore fails fast without leaving queries running. Queries then run
// concurrently, one goroutine per session. The first query or unmarshal error
// cancels the remaining queries and is returned with a nil slice.
func (c *Client) GetVideoPlatformsMap(ctx context.Context, sessions []DynamoSession) ([]PlatformMap, error) {
	table := videoPlayPlatformTable
	keyCondition := aws.String("segment_id = :segmentID")

	// Validate all channel IDs (and fill in the per-session fields that need
	// no query) before starting any goroutine. Returning from inside the
	// launch loop would abandon queries that are already in flight, without
	// waiting for them or cancelling their context.
	result := make([]PlatformMap, len(sessions))
	for idx, session := range sessions {
		idInt, err := strconv.ParseInt(session.ChannelID, 10, 64)
		if err != nil {
			return nil, errors.Wrap(err, "failed to parse channel id")
		}
		result[idx] = PlatformMap{
			ChannelID: idInt,
			Time: DynamoTimestamp{
				Converted: session.SegmentStartTime.Converted,
			},
		}
	}

	g, ctx := errgroup.WithContext(ctx)
	for idx, session := range sessions {
		idx, session := idx, session // per-iteration copies for the closure (pre-Go 1.22)

		g.Go(func() error {
			var unmarshalled dynamoVideoPlayPlatform

			expressionAttributeValues := map[string]*dynamodb.AttributeValue{
				":segmentID": {
					S: aws.String(session.SegmentID),
				},
			}

			output, err := c.dynamo.QueryWithContext(ctx, &dynamodb.QueryInput{
				TableName:                 aws.String(table),
				ScanIndexForward:          aws.Bool(true),
				Limit:                     aws.Int64(1),
				KeyConditionExpression:    keyCondition,
				ExpressionAttributeValues: expressionAttributeValues,
			})
			if err != nil {
				return errors.Wrapf(err, "failed to query %s", table)
			}

			if len(output.Items) > 0 {
				if err := dynamodbattribute.UnmarshalMap(output.Items[0], &unmarshalled); err != nil {
					return errors.Wrap(err, "failed to unmarshal video play platform")
				}
			}

			// Each goroutine writes only its own index, so no locking is needed.
			result[idx].Platform = consolidatePlatforms(unmarshalled.PlatformBreakdown)
			return nil
		})
	}

	if err := g.Wait(); err != nil {
		return nil, err
	}

	return result, nil
}

// GetVideoGeoMap returns one GeoMap per session, in the same order as
// sessions. Each entry carries three things:
//   - the session's parsed channel ID,
//   - its segment start time, and
//   - the CountryBreakdown of the first VideoPlayGeoPerSession item returned
//     for the session's segment ID (nil when there is no item).
//
// Every channel ID is parsed before any query is issued. A malformed ID
// therefore fails fast without leaving queries running. Queries then run
// concurrently, one goroutine per session. The first query or unmarshal error
// cancels the remaining queries and is returned with a nil slice.
func (c *Client) GetVideoGeoMap(ctx context.Context, sessions []DynamoSession) ([]GeoMap, error) {
	table := videoPlayGeographicalTable
	keyCondition := aws.String("segment_id = :segmentID")

	// Validate all channel IDs (and fill in the per-session fields that need
	// no query) before starting any goroutine. Returning from inside the
	// launch loop would abandon queries that are already in flight, without
	// waiting for them or cancelling their context.
	result := make([]GeoMap, len(sessions))
	for idx, session := range sessions {
		idInt, err := strconv.ParseInt(session.ChannelID, 10, 64)
		if err != nil {
			return nil, errors.Wrap(err, "failed to parse channel id")
		}
		result[idx] = GeoMap{
			ChannelID: idInt,
			Time: DynamoTimestamp{
				Converted: session.SegmentStartTime.Converted,
			},
		}
	}

	g, ctx := errgroup.WithContext(ctx)
	for idx, session := range sessions {
		idx, session := idx, session // per-iteration copies for the closure (pre-Go 1.22)

		g.Go(func() error {
			var unmarshalled dynamoVideoPlayCountry
			expressionAttributeValues := map[string]*dynamodb.AttributeValue{
				":segmentID": {
					S: aws.String(session.SegmentID),
				},
			}

			output, err := c.dynamo.QueryWithContext(ctx, &dynamodb.QueryInput{
				TableName:                 aws.String(table),
				ScanIndexForward:          aws.Bool(true),
				Limit:                     aws.Int64(1),
				KeyConditionExpression:    keyCondition,
				ExpressionAttributeValues: expressionAttributeValues,
			})
			if err != nil {
				return errors.Wrapf(err, "failed to query %s", table)
			}

			if len(output.Items) > 0 {
				if err := dynamodbattribute.UnmarshalMap(output.Items[0], &unmarshalled); err != nil {
					return errors.Wrap(err, "failed to unmarshal video play country")
				}
			}

			// Each goroutine writes only its own index, so no locking is needed.
			result[idx].Geo = unmarshalled.CountryBreakdown
			return nil
		})
	}

	if err := g.Wait(); err != nil {
		return nil, err
	}

	return result, nil
}

// GetVideoReferralMap returns one ReferralMap per session, in the same order
// as sessions. Each entry carries three things:
//   - the session's parsed channel ID,
//   - its segment start time, and
//   - the InternalBreakdown and ExternalBreakdown of the first
//     VideoPlayReferrerPerSession item returned for the session's segment ID
//     (both nil when there is no item).
//
// Every channel ID is parsed before any query is issued. A malformed ID
// therefore fails fast without leaving queries running. Queries then run
// concurrently, one goroutine per session. The first query or unmarshal error
// cancels the remaining queries and is returned with a nil slice.
func (c *Client) GetVideoReferralMap(ctx context.Context, sessions []DynamoSession) ([]ReferralMap, error) {
	table := videoPlayReferrerTable
	keyCondition := aws.String("segment_id = :segmentID")

	// Validate all channel IDs (and fill in the per-session fields that need
	// no query) before starting any goroutine. Returning from inside the
	// launch loop would abandon queries that are already in flight, without
	// waiting for them or cancelling their context.
	result := make([]ReferralMap, len(sessions))
	for idx, session := range sessions {
		idInt, err := strconv.ParseInt(session.ChannelID, 10, 64)
		if err != nil {
			return nil, errors.Wrap(err, "failed to parse channel id")
		}
		result[idx] = ReferralMap{
			ChannelID: idInt,
			Time: DynamoTimestamp{
				Converted: session.SegmentStartTime.Converted,
			},
		}
	}

	g, ctx := errgroup.WithContext(ctx)
	for idx, session := range sessions {
		idx, session := idx, session // per-iteration copies for the closure (pre-Go 1.22)

		g.Go(func() error {
			var unmarshalled dynamoVideoPlayReferrer
			expressionAttributeValues := map[string]*dynamodb.AttributeValue{
				":segmentID": {
					S: aws.String(session.SegmentID),
				},
			}

			output, err := c.dynamo.QueryWithContext(ctx, &dynamodb.QueryInput{
				TableName:                 aws.String(table),
				ScanIndexForward:          aws.Bool(true),
				Limit:                     aws.Int64(1),
				KeyConditionExpression:    keyCondition,
				ExpressionAttributeValues: expressionAttributeValues,
			})
			if err != nil {
				return errors.Wrapf(err, "failed to query %s", table)
			}

			if len(output.Items) > 0 {
				if err := dynamodbattribute.UnmarshalMap(output.Items[0], &unmarshalled); err != nil {
					return errors.Wrap(err, "failed to unmarshal video play referrer")
				}
			}

			// Each goroutine writes only its own index, so no locking is needed.
			result[idx].Internal = unmarshalled.InternalBreakdown
			result[idx].External = unmarshalled.ExternalBreakdown
			return nil
		})
	}

	if err := g.Wait(); err != nil {
		return nil, err
	}

	return result, nil
}

// GetTotalViews sums the TotalViews field of the first VideoPlayUniquePerSession
// item found for each session's segment ID. Sessions with no stored item add
// nothing to the total.
//
// Lookups run concurrently, one goroutine per session. The first query or
// unmarshal error cancels the outstanding lookups and is returned together
// with a zero total.
func (c *Client) GetTotalViews(ctx context.Context, sessions []DynamoSession) (int64, error) {
	var total int64

	group, groupCtx := errgroup.WithContext(ctx)
	condition := aws.String("segment_id = :segmentID")

	for i := range sessions {
		// Capture just the segment ID; it is all the goroutine needs.
		segmentID := sessions[i].SegmentID

		group.Go(func() error {
			values := map[string]*dynamodb.AttributeValue{
				":segmentID": {S: aws.String(segmentID)},
			}

			out, err := c.dynamo.QueryWithContext(groupCtx, &dynamodb.QueryInput{
				TableName:                 aws.String(videoPlayUniqueViewsTable),
				ScanIndexForward:          aws.Bool(true),
				Limit:                     aws.Int64(1),
				KeyConditionExpression:    condition,
				ExpressionAttributeValues: values,
			})
			if err != nil {
				return errors.Wrapf(err, "failed to query %s", videoPlayUniqueViewsTable)
			}
			if len(out.Items) == 0 {
				return nil
			}

			var views dynamoVideoPlayViews
			if err := dynamodbattribute.UnmarshalMap(out.Items[0], &views); err != nil {
				return errors.Wrap(err, "failed to unmarshal video play views")
			}

			// Goroutines add to a shared counter, so the update must be atomic.
			atomic.AddInt64(&total, views.TotalViews)
			return nil
		})
	}

	if err := group.Wait(); err != nil {
		return 0, err
	}
	return atomic.LoadInt64(&total), nil
}
