package awsutils

import (
	"context"

	"github.com/aws/aws-sdk-go/aws/request"
	"golang.org/x/net/trace"
)

// InstallXTraceHandlers registers golang.org/x/net/trace callbacks on the
// given AWS SDK handler lists so that every HTTP request sent to an AWS API
// is recorded as its own trace.
//
// The resulting logs and duration histograms are served at /debug/requests
// on the default ServeMux.
//
// Some service clients (Kinesis among them) install handlers of their own
// while the per-service client is being instantiated. When talking to such a
// service, pass the Handlers of the resulting client to this function
// instead of the Handlers of the session.
func InstallXTraceHandlers(hs *request.Handlers) {
	// Bracket the handlers currently in the Send list: the trace is opened
	// ahead of them and the send outcome is recorded behind them.
	hs.Send.PushFront(awsCallSendHeaders)
	hs.Send.PushBack(awsCallReadHeaders)

	// awsCallReadBody finishes the trace, so it is appended to both of the
	// response-unmarshalling lists.
	for _, phase := range []*request.HandlerList{&hs.UnmarshalError, &hs.Unmarshal} {
		phase.PushBack(awsCallReadBody)
	}
}

// spanKey is the context key under which awsCallSendHeaders stores the
// trace.Trace of a request and under which the read handlers look it up
// again. It points at a non-zero-size value, so it cannot compare equal to a
// key allocated anywhere else.
var spanKey = new(int)

// awsCallSendHeaders opens a new trace for the outgoing AWS API call and
// attaches it to the context of the HTTP request, both under the
// x/net/trace context key and under spanKey. The trace family has the form
// "aws.<service name>.<operation name>".
//
// A nil request, or one without an HTTP request or operation, is left
// untouched.
func awsCallSendHeaders(req *request.Request) {
	if req == nil {
		return
	}
	if req.HTTPRequest == nil || req.Operation == nil {
		return
	}

	family := "aws." + req.ClientInfo.ServiceName + "." + req.Operation.Name
	span := trace.New(family, "")

	// Layer the trace onto the request's existing context twice: once for
	// x/net/trace consumers and once under this package's own key.
	traced := trace.NewContext(req.HTTPRequest.Context(), span)
	traced = context.WithValue(traced, spanKey, span)
	req.HTTPRequest = req.HTTPRequest.WithContext(traced)
}

// awsCallReadHeaders records the outcome of sending the request on the trace
// stored under spanKey in the context of the HTTP request.
//
// When a response is present its status code is logged. When the request
// carries an error, the trace is marked as failed, the error is logged and
// the trace is finished. The trace is then removed from the context under
// both the x/net/trace key and spanKey, so that no later lookup can return,
// and operate on, a trace that has already been finished.
//
// A nil request, one without an HTTP request, or one whose context holds no
// trace under spanKey is left untouched.
func awsCallReadHeaders(req *request.Request) {
	if req == nil || req.HTTPRequest == nil {
		return
	}
	ctx := req.HTTPRequest.Context()

	tr, ok := ctx.Value(spanKey).(trace.Trace)
	if !ok {
		return
	}

	if req.HTTPResponse != nil {
		tr.LazyPrintf("status=%d", req.HTTPResponse.StatusCode)
	}

	if req.Error == nil {
		// No error so far: the trace is left open and is not finished here.
		return
	}

	tr.SetError()
	tr.LazyPrintf("error: %s", req.Error)
	tr.Finish()

	// Shadow both context entries with nil. Clearing only the x/net/trace
	// key would leave the finished trace reachable through spanKey, which is
	// the key the handlers in this file use for their lookups; a nil value
	// makes the type assertion above fail on any later lookup.
	ctx = trace.NewContext(ctx, nil)
	ctx = context.WithValue(ctx, spanKey, nil)
	req.HTTPRequest = req.HTTPRequest.WithContext(ctx)
}

// awsCallReadBody finishes the trace stored under spanKey in the context of
// the HTTP request.
//
// If the request carries an error, the trace is first marked as failed and
// the error is logged. After finishing, the trace is removed from the
// context under both the x/net/trace key and spanKey, so that no later
// lookup can return, and operate on, a trace that has already been finished.
//
// A nil request, one without an HTTP request, or one whose context holds no
// trace under spanKey is left untouched.
func awsCallReadBody(req *request.Request) {
	if req == nil || req.HTTPRequest == nil {
		return
	}
	ctx := req.HTTPRequest.Context()

	tr, ok := ctx.Value(spanKey).(trace.Trace)
	if !ok {
		return
	}

	if req.Error != nil {
		tr.SetError()
		tr.LazyPrintf("error: %s", req.Error)
	}
	tr.Finish()

	// Shadow both context entries with nil. Clearing only the x/net/trace
	// key would leave the finished trace reachable through spanKey, which is
	// the key the handlers in this file use for their lookups; a nil value
	// makes the type assertion above fail on any later lookup.
	ctx = trace.NewContext(ctx, nil)
	ctx = context.WithValue(ctx, spanKey, nil)
	req.HTTPRequest = req.HTTPRequest.WithContext(ctx)
}
