package svc

import (
	"context"
	"fmt"
	"time"

	"github.com/golang/protobuf/jsonpb"
	"github.com/twitchtv/twirp"

	control "code.justin.tv/event-engineering/carrot-analytics/control/rpc"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/dynamodb"
	"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/golang/protobuf/ptypes"
	"github.com/google/uuid"
	"github.com/sirupsen/logrus"
)

// queryTTL is how long query records are retained: one week. Query results in
// S3 are meant to be deleted on the same schedule (that cleanup is not
// performed in this file).
const queryTTL = 7 * 24 * time.Hour

// Client defines the functions that will be available in this service; it is
// essentially a straight implementation of the twirp service.
type Client interface {
	// EnqueueQuery records a new query and submits it for asynchronous
	// execution, returning the generated query ID.
	EnqueueQuery(ctx context.Context, request *control.EnqueueQueryRequest) (*control.EnqueueQueryResponse, error)
	// GetQueryResult returns the stored result of a successfully completed
	// query. Otherwise it returns a twirp NotFound error whose "Code" metadata
	// explains why (e.g. "AWAITING" while the query is queued or executing).
	GetQueryResult(ctx context.Context, request *control.GetQueryResultRequest) (*control.GetQueryResultResponse, error)
	// ListQueries returns summaries of the queries submitted by a requester.
	ListQueries(ctx context.Context, request *control.ListQueriesRequest) (*control.ListQueriesResponse, error)
}

// client is the AWS-backed implementation of Client returned by New.
type client struct {
	// AWS service clients; New builds all three from a single session.
	ddb *dynamodb.DynamoDB
	sqs *sqs.SQS
	s3  *s3.S3

	// Identifiers of the AWS resources this client operates on.
	executorSQSQueueURL string // SQS queue that query payloads are sent to
	queriesTableName    string // DynamoDB table holding queryRecord items
	resultsBucketName   string // S3 bucket that query results are read from

	logger logrus.FieldLogger
}

// New returns a new Carrot Analytics Control client. The DynamoDB, SQS and S3
// service clients are all created from sess; the string arguments identify the
// executor queue, the queries table and the results bucket respectively.
func New(sess *session.Session, executorSQSQueueURL, queriesTableName, resultsBucketName string, logger logrus.FieldLogger) Client {
	impl := &client{
		executorSQSQueueURL: executorSQSQueueURL,
		queriesTableName:    queriesTableName,
		resultsBucketName:   resultsBucketName,
		logger:              logger,
	}

	impl.ddb = dynamodb.New(sess)
	impl.sqs = sqs.New(sess)
	impl.s3 = s3.New(sess)

	return impl
}

// queryRecord is the item written to the queries DynamoDB table for each
// enqueued query. There are no dynamodbav tags, so the json tags are what
// dynamodbattribute uses as attribute names here (matching the "query_id" key
// and "requested_by" key condition used by the lookups in this file).
type queryRecord struct {
	// QueryID is the UUID generated when the query is enqueued; it is the table key used by GetItem.
	QueryID     string              `json:"query_id"`
	// Query is the jsonpb-encoded request query; the same string is sent as the SQS message body.
	Query       string              `json:"query"`
	// RequestedBy identifies the requester; queries are listed by it via the "queries-requested_by" index.
	RequestedBy string              `json:"requested_by"`
	// RequestedAt is the time the query was enqueued.
	RequestedAt time.Time           `json:"requested_at"`
	// ResultPath is the key of the result object in the results S3 bucket, formatted "<requested_by>/<query_id>".
	ResultPath  string              `json:"result_path"`
	// Status is QUEUED at creation and FAILURE if the SQS send fails; other transitions are not made in this file.
	Status      control.QueryStatus `json:"status"`
	// TTL is the expiry as Unix seconds (enqueue time + queryTTL); presumably the table's TTL attribute — confirm.
	TTL         int64               `json:"ttl"`
	// Label is the caller-supplied label copied from the enqueue request.
	Label       string              `json:"label"`
}

// EnqueueQuery stores a new query record in DynamoDB with status QUEUED, then
// sends the jsonpb-encoded query to the executor SQS queue with "query_id" and
// "result_path" message attributes. It returns the generated query ID.
//
// A missing query or requester yields a twirp InvalidArgument error. If the SQS
// send fails, the record is re-written with status FAILURE on a best-effort
// basis (failures there are only logged) and the SQS error is returned.
func (c *client) EnqueueQuery(ctx context.Context, request *control.EnqueueQueryRequest) (*control.EnqueueQueryResponse, error) {
	// TODO: It might be a good idea to check recent queries for any matching this one so that we dont rerun queries unnecessarily
	// Also we can see if there are any queued queries with the same parameters from the same user and just return that report ID
	// This will prevent things like double click spam

	if request == nil || request.Query == nil {
		return nil, twirp.RequiredArgumentError("query")
	}
	// A record without a requester could never be returned by ListQueries,
	// which looks queries up by requester.
	if request.RequestedBy == "" {
		return nil, twirp.RequiredArgumentError("requested_by")
	}

	queryID := uuid.New().String()
	now := time.Now()

	// Marshal before persisting anything so an unencodable query fails early.
	marshaler := jsonpb.Marshaler{}
	payload, err := marshaler.MarshalToString(request.Query)
	if err != nil {
		return nil, err
	}

	// GetQueryResult later reads the result from this key in the results bucket.
	resultPath := fmt.Sprintf("%s/%s", request.RequestedBy, queryID)

	rec := &queryRecord{
		QueryID:     queryID,
		Query:       payload,
		RequestedBy: request.RequestedBy,
		RequestedAt: now,
		ResultPath:  resultPath,
		Status:      control.QueryStatus_QUEUED,
		TTL:         now.Add(queryTTL).Unix(),
		Label:       request.Label,
	}

	data, err := dynamodbattribute.MarshalMap(rec)
	if err != nil {
		return nil, err
	}

	if _, err := c.ddb.PutItemWithContext(ctx, &dynamodb.PutItemInput{
		TableName: &c.queriesTableName,
		Item:      data,
	}); err != nil {
		return nil, err
	}

	// Send query details to the executor queue.
	_, sqsErr := c.sqs.SendMessageWithContext(ctx, &sqs.SendMessageInput{
		QueueUrl: &c.executorSQSQueueURL,
		MessageAttributes: map[string]*sqs.MessageAttributeValue{
			"query_id": {
				DataType:    aws.String("String"),
				StringValue: &queryID,
			},
			"result_path": {
				DataType:    aws.String("String"),
				StringValue: &resultPath,
			},
		},
		MessageBody: aws.String(payload),
	})
	if sqsErr != nil {
		// Best effort: mark the record FAILURE so it is not reported as awaiting
		// forever. A fresh context is used because the request context may be
		// the reason the send failed (e.g. cancellation).
		rec.Status = control.QueryStatus_FAILURE

		if failedData, err := dynamodbattribute.MarshalMap(rec); err != nil {
			c.logger.WithError(err).Warn("Failed to marshal query update after SQS send failure")
		} else if _, err := c.ddb.PutItemWithContext(context.Background(), &dynamodb.PutItemInput{
			TableName: &c.queriesTableName,
			Item:      failedData,
		}); err != nil {
			c.logger.WithError(err).Warn("Failed to update query in dynamo after SQS send failure")
		}

		return nil, sqsErr
	}

	return &control.EnqueueQueryResponse{
		QueryId: queryID,
	}, nil
}

// ListQueries returns summaries of the queries submitted by request.RequestedBy,
// read from the "queries-requested_by" index with ScanIndexForward=false, i.e.
// descending by the index's sort key (presumably newest first — confirm the
// index's sort key).
//
// Only the first page of DynamoDB results is returned; truncation is logged.
// An empty requester yields a twirp InvalidArgument error.
func (c *client) ListQueries(ctx context.Context, request *control.ListQueriesRequest) (*control.ListQueriesResponse, error) {
	// An empty string cannot be used as a key condition value, so reject it
	// here rather than surfacing a raw DynamoDB validation error.
	if request.RequestedBy == "" {
		return nil, twirp.RequiredArgumentError("requested_by")
	}

	// We're just going to grab one page for now, we'll build pagination into this... at some point
	resp, err := c.ddb.QueryWithContext(ctx, &dynamodb.QueryInput{
		TableName:              &c.queriesTableName,
		IndexName:              aws.String("queries-requested_by"),
		KeyConditionExpression: aws.String("requested_by = :requested_by"),
		ScanIndexForward:       aws.Bool(false),
		ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
			":requested_by": {
				S: &request.RequestedBy,
			},
		},
	})
	if err != nil {
		return nil, err
	}

	if len(resp.LastEvaluatedKey) > 0 {
		c.logger.WithField("Endpoint", "ListQueries").Debugf("Query list for %v truncated to the first page", request.RequestedBy)
	}

	var records []*queryRecord
	if err := dynamodbattribute.UnmarshalListOfMaps(resp.Items, &records); err != nil {
		return nil, err
	}

	summaries := make([]*control.QuerySummary, 0, len(records))
	for _, qr := range records {
		requestedAt, err := ptypes.TimestampProto(qr.RequestedAt)
		if err != nil {
			return nil, err
		}

		summaries = append(summaries, &control.QuerySummary{
			QueryId:     qr.QueryID,
			Status:      qr.Status,
			Ttl:         qr.TTL,
			RequestedAt: requestedAt,
			Label:       qr.Label,
		})
	}

	return &control.ListQueriesResponse{
		Queries: summaries,
	}, nil
}

// GetQueryResult loads the query record for request.QueryId and, if the query
// succeeded, returns the result stored in S3 at the record's ResultPath
// (decoded with jsonpb, ignoring unknown fields).
//
// Errors (twirp NotFound unless noted, with a "Code" metadata value):
//   - NOT_FOUND: empty ID, or no record exists for the ID
//   - INVALID_DATA: the record or the stored result could not be decoded
//   - FAILED_TO_RETRIEVE: the result object could not be read from S3
//   - AWAITING: the query is still QUEUED or EXECUTING
//   - FAILED: any other status
//   - twirp Internal: the DynamoDB lookup itself failed
func (c *client) GetQueryResult(ctx context.Context, request *control.GetQueryResultRequest) (*control.GetQueryResultResponse, error) {
	logger := c.logger.WithField("Endpoint", "GetQueryResult")

	// An empty ID can never match a record; skip the DynamoDB round trip.
	if request.QueryId == "" {
		return nil, twirp.NotFoundError("Query not found").WithMeta("Code", "NOT_FOUND")
	}

	keyAttr, err := dynamodbattribute.Marshal(request.QueryId)
	if err != nil {
		return nil, err
	}

	resp, err := c.ddb.GetItemWithContext(ctx, &dynamodb.GetItemInput{
		TableName: &c.queriesTableName,
		Key: map[string]*dynamodb.AttributeValue{
			"query_id": keyAttr,
		},
	})
	if err != nil {
		// A failed lookup (throttling, network, permissions) is not the same as
		// a missing query, so don't report it as NotFound.
		logger.WithError(err).Warnf("Could not look up query with id %v", request.QueryId)
		return nil, twirp.InternalErrorWith(err)
	}

	// GetItem succeeds with no Item when nothing matches the key. Without this
	// check the zero-value record would be reported using the enum's zero
	// status instead of NOT_FOUND.
	if len(resp.Item) == 0 {
		logger.Debugf("Could not find query with id %v", request.QueryId)
		return nil, twirp.NotFoundError("Query not found").WithMeta("Code", "NOT_FOUND")
	}

	var qr queryRecord
	if err := dynamodbattribute.UnmarshalMap(resp.Item, &qr); err != nil {
		logger.WithError(err).Debugf("Could not unmarshall query record with id %v", request.QueryId)
		return nil, twirp.NotFoundError("Query not found").WithMeta("Code", "INVALID_DATA")
	}

	switch qr.Status {
	case control.QueryStatus_SUCCESS:
		result, err := c.s3.GetObjectWithContext(ctx, &s3.GetObjectInput{
			Bucket: &c.resultsBucketName,
			Key:    &qr.ResultPath,
		})
		if err != nil {
			logger.WithError(err).Debugf("Could not retrieve query data, id %v", request.QueryId)
			return nil, twirp.NotFoundError("Query not found").WithMeta("Code", "FAILED_TO_RETRIEVE")
		}
		// The body is a streamed response and must be closed to release the
		// underlying connection.
		defer result.Body.Close()

		var queryResult control.GetQueryResultResponse
		unmarshaler := jsonpb.Unmarshaler{
			AllowUnknownFields: true,
		}
		if err := unmarshaler.Unmarshal(result.Body, &queryResult); err != nil {
			logger.WithError(err).Debugf("Could not unmarshall query data, id %v", request.QueryId)
			return nil, twirp.NotFoundError("Query not found").WithMeta("Code", "INVALID_DATA")
		}

		return &queryResult, nil

	case control.QueryStatus_QUEUED, control.QueryStatus_EXECUTING:
		return nil, twirp.NotFoundError("Query not found").WithMeta("Code", "AWAITING")

	default:
		return nil, twirp.NotFoundError("Query not found").WithMeta("Code", "FAILED")
	}
}
