package awsmetric

import (
	"context"

	"code.justin.tv/video/metrics-middleware/v2/operation"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/aws/session"
)

// Client instruments AWS SDK sessions so that the API requests made through
// them are reported as operations started via Starter.
type Client struct {
	// Starter starts the operations recorded for requests made through
	// sessions returned by AddToSession.
	Starter *operation.Starter
}

// AddToSession returns an instrumented copy of sess. The operation-reporting
// handlers are installed on the copy produced by sess.Copy, not on sess, so
// callers that keep using sess directly are not instrumented.
func (c *Client) AddToSession(sess *session.Session) *session.Session {
	instrumented := sess.Copy()
	c.installStatusHandlers(&instrumented.Handlers)
	return instrumented
}

// installStatusHandlers registers handlers on hs so that every attempt of an
// AWS request is reported as one operation started via c.Starter.
//
// Lifecycle of a request's operations:
//   - Sign: a new operation is started for the attempt and stored in the HTTP
//     request's context. aws-sdk-go v1's Request.Send runs the Sign handlers
//     at the start of every attempt, including retries, so this is the only
//     place an operation is started.
//   - Retry: the failed attempt's operation is ended with the attempt's error.
//     No operation is started here; if the SDK retries, the next Sign does it.
//     (Starting one here as well would leak it: the following Sign would
//     shadow it and it would never be ended.)
//   - Complete: the final attempt's operation, if still open, is ended.
//
// Ending an operation clears it from the request context, so an operation is
// never ended twice (e.g. Retry followed by Complete when the SDK gives up).
func (c *Client) installStatusHandlers(hs *request.Handlers) {
	// we make a new context key here so no other code (or instance of a call
	// to this method) can access or stomp on our *Operation.
	key := new(int)

	atStart := func(req *request.Request) {
		if req == nil || req.HTTPRequest == nil || req.Operation == nil {
			return
		}

		ctx, op := c.Starter.StartOp(req.HTTPRequest.Context(), getRequestOp(req))
		ctx = context.WithValue(ctx, key, op)
		req.HTTPRequest = req.HTTPRequest.WithContext(ctx)
	}

	atEnd := func(req *request.Request) {
		if req == nil || req.HTTPRequest == nil || req.Operation == nil {
			return
		}

		ctx := req.HTTPRequest.Context()
		op, _ := ctx.Value(key).(*operation.Op)
		if op == nil {
			// No operation is open: none was started, or it was already ended.
			return
		}

		var status operation.Status
		if req.Error != nil {
			status.Code = 2 // "unknown"
			status.Message = req.Error.Error()
			if aerr, ok := req.Error.(awserr.Error); ok {
				// TODO: start mapping AWS errors to their numeric equivalents
				status.Message = aerr.Code()
			}
		}
		op.SetStatus(status)
		op.End()

		// Shadow the ended op with a typed nil so a later handler on this
		// request (Complete after a final Retry) does not End it again.
		ctx = context.WithValue(ctx, key, (*operation.Op)(nil))
		req.HTTPRequest = req.HTTPRequest.WithContext(ctx)
	}

	hs.Sign.PushBack(atStart)
	hs.Retry.PushBack(atEnd)
	hs.Complete.PushBack(atEnd)
}

// getRequestOp builds the operation name for req: a client-kind operation
// grouped by the request's service name and identified by its API operation
// name. Both req and req.Operation are dereferenced, so callers must ensure
// they are non-nil.
func getRequestOp(req *request.Request) operation.Name {
	return operation.Name{
		Group:  req.ClientInfo.ServiceName,
		Method: req.Operation.Name,
		Kind:   operation.KindClient,
	}
}
