/* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. */

// Package roundtrip centralizes the round-trip logic that is common
// across the various codecs.
package roundtrip

import (
	"CoralGoCodec/codec"
	"CoralRPCGoSupport/rpc"
	"aaa"
	"authv4"
	"authv4/arps"
	"bufio"
	"bytes"
	"io"
	"io/ioutil"
	"net/http"
	"time"
)

// HTTP header names that Tripper.Go sets on outgoing requests.
const (
	headerAuthorization = "Authorization"
	headerAmznDate      = "X-Amz-Date"
	headerAmznRequestId = "X-Amzn-RequestId"
	headerAmznToken     = "X-Amz-Security-Token"
)

// Codec marshals requests and unmarshals responses for a Tripper.
//
// It is modeled on codec.Codec, but its unmarshal method also receives the
// round-trip Context and the service name, which improves type detection.
type Codec interface {
	// MarshalRequest encodes r into the bytes that Tripper sends as the
	// HTTP request body.
	MarshalRequest(r *codec.Request) ([]byte, error)

	// UnmarshalResponse decodes the response body data into obj, using
	// assemblies from the named service. rtCtx describes the round trip
	// that produced data, including the HTTP response status code.
	UnmarshalResponse(rtCtx Context, data []byte, obj interface{}, service string) error
}

// BearerTokenVendor supplies tokens for HTTP's Bearer authentication
// scheme.
//
// Whenever Vend yields a non-empty token, Tripper sets this header on the
// outgoing request:
//
//	Authorization: Bearer <token>
//
// An empty token means no bearer Authorization header is added.
type BearerTokenVendor interface {
	// Vend returns the token to present to hostname, or a non-nil err.
	// hostname is the host's FQDN, optionally followed by a colon ':'
	// and a port number.
	Vend(hostname string) (token string, err error)
}

// Tripper bundles everything Go needs to perform a RoundTrip request.
// Only Codec is required; every other field is optional and may be left
// at its zero value.
type Tripper struct {
	// Codec marshals the request and unmarshals the response. Required.
	Codec Codec

	// Host, when non-empty, replaces both the request Host and the URL host.
	Host string

	// RequestIdGenerator, when non-nil, produces the request id that is
	// sent in the X-Amzn-RequestId header.
	RequestIdGenerator rpc.RequestIdGenerator

	// AAAClient, when non-nil, encodes the outgoing request and decodes
	// the incoming response.
	AAAClient aaa.Client

	// ARPSAuthorizer is optional. NOTE(review): Go does not currently
	// consult this field.
	ARPSAuthorizer *arps.ARPSAuthorizer

	// BasicAuth, when non-nil, supplies HTTP Basic credentials.
	BasicAuth *rpc.BasicAuth

	// BearerTokenVendor, when non-nil, is asked for a Bearer token for the
	// request host.
	BearerTokenVendor BearerTokenVendor

	// SecurityToken, when non-empty, is sent as X-Amz-Security-Token.
	// See https://w.amazon.com/index.php/Coral/Specifications/HttpSecurityToken
	SecurityToken string

	// SignerV4, when non-nil, signs the request; otherwise Go only adds an
	// X-Amz-Date header.
	SignerV4 *authv4.Signer
}

// Context records the state of a single round trip as populated by
// Tripper.Go. It is passed to NewError when a step fails and to
// Codec.UnmarshalResponse when a response body is unmarshalled.
type Context struct {
	// Request is the request being sent.
	Request *codec.Request

	// RequestID is the value sent in the X-Amzn-RequestId header. It is
	// empty when no RequestIdGenerator is configured or it has not run yet.
	RequestID string

	// ResponseStatusCode is the HTTP status code of the response, or 0 if
	// no response has been read yet.
	ResponseStatusCode int
}

// Go performs a round trip for the given request, writing the HTTP request
// to rw and then reading the HTTP response back from rw.
//
// The steps are:
//   - marshal r with c.Codec and wrap the bytes in a POST to path
//     ("/" when path is empty);
//   - copy r.RequestHeaders onto the request and apply whichever optional
//     Tripper settings are configured (Host, SecurityToken, BasicAuth,
//     BearerTokenVendor, SignerV4, AAAClient, RequestIdGenerator);
//   - write the request to rw and parse the HTTP response from rw;
//   - unmarshal a non-empty response body into r.Output via
//     c.Codec.UnmarshalResponse.
//
// Go returns nil without unmarshalling when the response has no body, the
// body is empty, or r.Output is nil. Failures in the earlier steps are
// reported through NewError together with a Context recording how far the
// round trip got. The Codec's unmarshal error is returned unchanged.
func (c *Tripper) Go(path string, r *codec.Request, rw io.ReadWriter) error {
	// TODO: Back-port "CloudAuth resource server support"
	// https://code.amazon.com/packages/CoralRPCGoSupport/commits/03caf13b07b8b70f2ecb8f5863395acb0dc322b3

	if path == "" {
		path = "/"
	}

	rtCtx := Context{Request: r}

	b, errMarshal := c.Codec.MarshalRequest(r)
	if errMarshal != nil {
		return NewError("failed to marshal input", errMarshal, rtCtx)
	}

	body := bytes.NewBuffer(b)
	request, errRequest := http.NewRequest("POST", path, body)
	if errRequest != nil {
		return NewError("failed to create new request", errRequest, rtCtx)
	}

	if c.Host != "" {
		request.Host = c.Host
		request.URL.Host = c.Host
	}

	for k, v := range r.RequestHeaders {
		request.Header.Set(k, v)
	}

	if c.SecurityToken != "" {
		request.Header.Set(headerAmznToken, c.SecurityToken)
	}

	if c.BasicAuth != nil {
		request.SetBasicAuth(c.BasicAuth.Username, c.BasicAuth.Password)
	}

	if c.BearerTokenVendor != nil {
		// Get a Bearer token and set the appropriate Authorization header.
		token, errBearer := c.BearerTokenVendor.Vend(request.Host)
		if errBearer != nil {
			return NewError("failed to retrieve bearer token", errBearer, rtCtx)
		}
		if token != "" {
			request.Header.Set(headerAuthorization, "Bearer "+token)
		}
	}

	// Use v4 signing if we have it, otherwise just add the Amazon date header.
	if c.SignerV4 != nil {
		if errSign := c.SignerV4.Sign(request); errSign != nil {
			return NewError("failed to sign request", errSign, rtCtx)
		}
	} else {
		// v4 signing will add this header. time.RFC822 produces a
		// minute-precision value with a two-digit year, e.g.
		// "02 Jan 06 15:04 UTC".
		// NOTE(review): confirm this is the format the service expects.
		request.Header.Set(headerAmznDate, time.Now().UTC().Format(time.RFC822))
	}

	var clientCxt *aaa.ClientContext
	if c.AAAClient != nil {
		var errEncode error
		clientCxt, errEncode = c.AAAClient.EncodeRequest(r.Service.ShapeName, r.Operation.ShapeName, request)
		if errEncode != nil {
			return NewError("failed to AAA encode request", errEncode, rtCtx)
		}
	}

	// Added after signing and AAA encoding, so neither of those steps sees
	// this header.
	if c.RequestIdGenerator != nil {
		reqId, errRid := c.RequestIdGenerator()
		if errRid != nil {
			return NewError("failed to generate request id header", errRid, rtCtx)
		}
		request.Header.Set(headerAmznRequestId, reqId)
		rtCtx.RequestID = reqId
	}

	if errWrite := request.Write(rw); errWrite != nil {
		return NewError("failed to write request", errWrite, rtCtx)
	}

	resp, errResponse := http.ReadResponse(bufio.NewReader(rw), request)
	if errResponse != nil {
		return NewError("failed to read http response", errResponse, rtCtx)
	}
	// Close the response body on every return path, including the AAA
	// decode failure below, which previously returned without closing it.
	// The closure reads resp.Body at return time, so if DecodeResponse swaps
	// in a different body, that body is the one closed, as before.
	defer func() {
		if resp.Body != nil {
			// Ignored: the body was either fully read or we are already
			// returning an error.
			_ = resp.Body.Close()
		}
	}()

	rtCtx.ResponseStatusCode = resp.StatusCode

	if c.AAAClient != nil {
		if errDecode := c.AAAClient.DecodeResponse(clientCxt, resp); errDecode != nil {
			return NewError("failed to AAA decode response", errDecode, rtCtx)
		}
	}

	if resp.Body == nil {
		return nil
	}

	respBody, errBody := ioutil.ReadAll(resp.Body)
	if errBody != nil {
		return NewError("failed to read response body", errBody, rtCtx)
	}

	// Odin returns empty body with 200 sometimes which breaks Unmarshal
	if len(respBody) == 0 || r.Output == nil {
		return nil
	}

	return c.Codec.UnmarshalResponse(rtCtx, respBody, r.Output, r.Service.ShapeName)
}
