/* Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. */

package rpcv1

import (
	"CoralGoCodec/codec"
	"CoralGoModel/model"
	"CoralRPCGoSupport/internal/roundtrip"
	cjson "CoralRPCGoSupport/rpc/encoding/json"
	"GoLog/log"
	"aaa"
	"io"
	"io/ioutil"
	"net/http"
	"reflect"
	"strings"

	"github.com/pkg/errors"
	"golang.a2z.com/cloudauth"
)

// HTTP header names read or written by the RPCv1 codec. headerAmznAuthorization
// is the AAA signature header; UnmarshalRequest uses its presence to decide
// whether a request is authenticated with AAA rather than CloudAuth.
const (
	headerAccept            = "Accept"
	headerAmznAuthorization = aaa.AaaAuthHeader
	headerAmznTarget        = "X-Amz-Target"
	headerContentEncoding   = "Content-Encoding"
	headerContentType       = "Content-Type"
)

// Header values used by the RPCv1 codec.
//
// accept and contentTypeFull are sent on outgoing requests (RoundTrip).
// contentType is matched as a substring of incoming Content-Type headers and is
// set on responses. contentEncoding is both checked on incoming requests and
// sent on requests and responses.
// NOTE(review): the trailing ';' in contentTypeFull is unusual for a MIME type;
// confirm peers depend on it before changing.
const (
	accept          = "application/json, text/javascript"
	contentType     = "application/json"
	contentTypeFull = "application/json; charset=UTF-8;"
	contentEncoding = "amz-1.0"
)

// Marshal supports codec.Codec. It encodes obj using the RPCv1 JSON encoder
// and returns the encoded bytes along with any encoding error.
func (c RPCv1) Marshal(obj interface{}) ([]byte, error) {
	encoded, err := cjson.Marshal(obj)
	return encoded, err
}

// Unmarshal supports codec.Codec. It decodes the JSON in d into obj via
// cjson.Unmarshal, passing an empty service name (presumably meaning no
// service-specific assembly lookup; confirm against cjson).
func (c RPCv1) Unmarshal(d []byte, obj interface{}) error {
	const noService = ""
	return cjson.Unmarshal(d, obj, noService)
}

// UnmarshalWithService decodes an already-parsed JSON object m into obj,
// resolving shapes against the assemblies of the named service.
func (c RPCv1) UnmarshalWithService(m map[string]interface{}, obj interface{}, service string) error {
	err := cjson.UnmarshalMap(m, obj, service)
	return err
}

// IsSupported supports codec.Codec. It reports whether the request described by
// g looks like an RPCv1 request, for example:
//
//	POST http://0.0.0.0:8000/
//
//	HTTP/1.1
//	Content-Type: application/json; charset=UTF-8
//	Content-Encoding: amz-1.0
//	X-Amz-Date: Thu, 01 Mar 2012 22:07:13 GMT
//	X-Amz-Target: com.amazon.coral.demo.WeatherService.GetWeather
//
//	{"__type": "com.amazon.coral.demo.WeatherService#GetWeatherInput", "location": "foo"}
//
// All of the following must hold:
//   - Content-Type contains "application/json";
//   - Content-Encoding is either absent or exactly "amz-1.0";
//   - X-Amz-Target has the form "<assembly>.<service>.<operation>".
func (c RPCv1) IsSupported(g codec.Getter) bool {
	ct := g.Get(headerContentType)
	if !strings.Contains(ct, contentType) {
		log.Trace("RPCv1: Content type doesn't match", ct, contentType)
		return false
	}

	// Content-Encoding is optional, but when a client does send it, it must be
	// our encoding. An absent header is accepted.
	if enc := g.Get(headerContentEncoding); enc != "" && enc != contentEncoding {
		log.Trace("RPCv1: contentEncoding doesn't match", enc, contentEncoding)
		return false
	}

	// Only the presence of all three target parts matters here; the values are
	// resolved again by UnmarshalRequest.
	if assembly, service, op := GetServiceAndOp(g.Get(headerAmznTarget)); assembly == "" || service == "" || op == "" {
		log.Trace("RPCv1: Unable to determine assembly, service or op")
		return false
	}

	return true
}

// GetServiceAndOp splits an X-Amz-Target style value of the form
// "<assembly>.<service>.<operation>" into its three parts.
//
// The operation is the text after the last period and the service is the text
// between the last two periods. The assembly is everything before that and may
// itself contain periods, e.g. "com.amazon.coral.demo.WeatherService.GetWeather"
// yields ("com.amazon.coral.demo", "WeatherService", "GetWeather").
//
// Unless all three parts are non-empty, three empty strings are returned. This
// covers input with fewer than two periods and input with a leading, trailing,
// or doubled period at either split point.
func GetServiceAndOp(val string) (assembly, service, operation string) {
	i := strings.LastIndex(val, ".")
	if i <= 0 || i >= len(val)-1 {
		return "", "", ""
	}
	rest, op := val[:i], val[i+1:]

	j := strings.LastIndex(rest, ".")
	if j <= 0 || j >= len(rest)-1 {
		return "", "", ""
	}
	return rest[:j], rest[j+1:], op
}

// cloudAuthContext is stored in codec.Request.AuthCtx by authorizeRequest when a
// request is authorized via CloudAuth. It carries the bearer challenge returned
// by a CloudAuth deny or challenge decision, so that MarshalResponse can answer
// with a 401 and a WWW-Authenticate header.
type cloudAuthContext struct {
	// bearerChallenge is the WWW-Authenticate value to send with a 401 response.
	// An empty value means no challenge is sent.
	bearerChallenge string
}

// UnmarshalRequest satisfies codec.Server. It performs these steps in order:
//   - authenticates r with AAA and/or ARPS, when those are configured;
//   - resolves the assembly, service and operation from the X-Amz-Target header;
//   - authorizes the request;
//   - decodes the JSON body into a new instance of the operation's input shape.
//
// An instance of the operation's output shape is also allocated when the
// operation defines one.
//
// On error the returned *codec.Request is normally nil. The one exception is a
// CloudAuth deny or challenge decision. In that case the request is returned
// alongside the error, carrying a *cloudAuthContext. If its bearer challenge is
// non-empty, MarshalResponse sends it back as a 401 with a WWW-Authenticate
// header.
func (c RPCv1) UnmarshalRequest(r *http.Request) (*codec.Request, error) {
	var sctx *aaa.ServiceContext
	var err error

	// AAA and CloudAuth may both be enabled for a service at the same time. A
	// request is treated as AAA-signed when it carries the AAA authorization
	// header (headerAmznAuthorization). Otherwise it is left to CloudAuth, whose
	// requests carry the standard "Authorization" header instead.
	hasAAAHeader := r.Header.Get(headerAmznAuthorization) != ""
	if c.AAA != nil && hasAAAHeader {
		if sctx, err = c.AAA.DecodeRequest(r); err != nil {
			return nil, err
		}
		// NOTE(review): this traces the full decoded ServiceContext; confirm it
		// holds nothing that should not reach logs.
		log.Tracef("Decoded ServiceContext %#v", *sctx)
	}

	// TODO: Authentication and authorization for sigv4 must be split and put in interfaces.
	if c.ARPSAuthorizer != nil {
		auth, err := c.ARPSAuthorizer.Authenticate(r)
		if err != nil {
			return nil, errors.Errorf("request is not authorized. error message is %v", err)
		}
		// TODO: This message must be more informative (contain the user principal
		// instead of just the account ID) once authentication and authorization
		// for sigv4 are split.
		log.Tracef("Request from account %s is authorized", auth.AccountId())
	}

	// The target must name a service and an operation for the request to be dispatched.
	asmName, serviceName, opName := GetServiceAndOp(r.Header.Get(headerAmznTarget))
	if serviceName == "" || opName == "" {
		return nil, errors.New("Request is not supported by the RPCv1 codec")
	}

	defer r.Body.Close()
	reqBody, err := ioutil.ReadAll(r.Body)
	if err != nil {
		return nil, err
	}

	asm := model.LookupService(serviceName).Assembly(asmName)
	op, err := asm.Op(opName)
	if err != nil {
		return nil, errors.Wrap(err, "Unable to retrieve operation "+opName)
	}

	// Make sure the AAA context names the service and operation from the target.
	if sctx != nil && sctx.Service == "" {
		sctx.Service = serviceName
		sctx.Operation = opName
	}

	// Build the base request. Input and Output are populated below, only after
	// authorization succeeds.
	cr := &codec.Request{
		Service:   codec.ShapeRef{AsmName: asmName, ShapeName: serviceName},
		Operation: codec.ShapeRef{AsmName: asmName, ShapeName: opName},
		AuthCtx:   sctx,
	}

	// Authorize before dispatching to the service. The CloudAuth path needs the
	// request to carry a bearer challenge back to the response. So
	// authorizeRequest may return a non-nil request together with an error, and
	// both are passed through unchanged.
	cr, err = c.authorizeRequest(cr, r, hasAAAHeader)
	if err != nil {
		return cr, err
	}

	if input := op.Input(); input != nil {
		m, mapErr := cjson.ToMap(reqBody)
		if mapErr != nil {
			return nil, mapErr
		}
		// cjson.DetermineShape chooses the concrete shape to decode into, given
		// the decoded body and the operation's declared input type.
		targetType := reflect.TypeOf(input.New()).Elem()
		shape := cjson.DetermineShape(reflect.ValueOf(m), targetType, cr.Service.ShapeName)
		cr.Input = shape.New()
		if err := c.UnmarshalWithService(m, cr.Input, cr.Service.ShapeName); err != nil {
			return nil, err
		}
	}
	if output := op.Output(); output != nil {
		cr.Output = output.New()
	}
	return cr, nil
}

// MarshalResponse satisfies codec.Server. It writes r.Output to w as an RPCv1
// JSON response, with Content-Type "application/json" and Content-Encoding
// "amz-1.0". The body is AAA-encoded when r.AuthCtx holds an
// *aaa.ServiceContext and the codec has AAA configured.
//
// Special cases:
//   - a missing request, service or operation yields a 500;
//   - a CloudAuth bearer challenge yields a 401 with a WWW-Authenticate header;
//   - a marshalling or AAA encoding failure yields a 500;
//   - a nil body yields 204 No Content.
func (c RPCv1) MarshalResponse(w http.ResponseWriter, r *codec.Request) {
	if r == nil || r.Operation.ShapeName == "" || r.Service.ShapeName == "" {
		// NOTE(review): the http.Error below assumes GoLog's log.Fatal does not
		// terminate the process; confirm.
		log.Fatal("MarshalResponse called without proper input.")
		http.Error(w, "Unable to process request", http.StatusInternalServerError)
		return
	}

	headers := w.Header()
	if caCtx, ok := r.AuthCtx.(*cloudAuthContext); ok && caCtx.bearerChallenge != "" {
		// Returning a 401 with a WWW-Authenticate header represents the Bearer Challenge.
		// This process is outlined here https://w.amazon.com/bin/view/Dev.CDO/UnifiedAuth/CloudAuth/Design/#HFlow
		headers.Set("WWW-Authenticate", caCtx.bearerChallenge)
		w.WriteHeader(http.StatusUnauthorized)
		return
	}

	// Serialize r.Output into bytes if available.
	var body []byte
	if r.Output != nil {
		var err error
		body, err = c.Marshal(r.Output)
		if err != nil {
			http.Error(w, "Unable to process response", http.StatusInternalServerError)
			return
		}
	}

	// Set the appropriate response headers for the codec.
	headers.Set(headerContentType, contentType)
	headers.Set(headerContentEncoding, contentEncoding)

	if sctx, ok := r.AuthCtx.(*aaa.ServiceContext); ok && c.AAA != nil {
		encoded, err := c.AAA.EncodeResponse(sctx, headers, body)
		if err != nil {
			log.Fatalf("Error encoding response using AAA: %+v", err)
			// http.Error writes a plain-text body and resets Content-Type, but it
			// leaves Content-Encoding alone; drop the RPCv1 encoding so the error
			// body is not mislabelled.
			headers.Del(headerContentEncoding)
			http.Error(w, "Unable to encode response", http.StatusInternalServerError)
			return
		}
		body = encoded
		log.Trace("Output has been AAA encoded")
	}

	// Send the response. With no body, answer 204 No Content.
	if body == nil {
		w.WriteHeader(http.StatusNoContent)
		return
	}
	// The status line is committed by Write, so a write error could not be
	// reported to the client anyway.
	w.Write(body)
}

// MarshalRequest satisfies roundtrip.Codec. It encodes the request's Input
// with the codec's Marshal and returns the resulting bytes.
func (c RPCv1) MarshalRequest(r *codec.Request) ([]byte, error) {
	input := r.Input
	return c.Marshal(input)
}

// UnmarshalResponse satisfies roundtrip.Codec. It parses respBody into a JSON
// object and decodes that into obj using the named service's assemblies.
//
// A parse failure is wrapped in a roundtrip error carrying rtCtx. An error from
// the decode step is returned unwrapped, because it may originate in the model
// and callers may need to inspect it directly.
func (c RPCv1) UnmarshalResponse(rtCtx roundtrip.Context, respBody []byte, obj interface{}, service string) error {
	decoded, convErr := cjson.ToMap(respBody)
	if convErr == nil {
		return c.UnmarshalWithService(decoded, obj, service)
	}
	return roundtrip.NewError("failed to convert response to a map", convErr, rtCtx)
}

// RoundTrip supports codec.RoundTripper. It stamps the RPCv1 request headers
// onto r, creating r.RequestHeaders if it is nil:
//   - Accept;
//   - Content-Type;
//   - Content-Encoding;
//   - an X-Amz-Target of the form "<assembly>.<service>.<operation>".
//
// It then hands the request to the codec's tripper for c.Path.
func (c RPCv1) RoundTrip(r *codec.Request, rw io.ReadWriter) error {
	hdrs := r.RequestHeaders
	if hdrs == nil {
		hdrs = make(map[string]string)
		r.RequestHeaders = hdrs
	}

	target := strings.Join([]string{r.Service.AsmName, r.Service.ShapeName, r.Operation.ShapeName}, ".")

	hdrs[headerAccept] = accept
	hdrs[headerContentType] = contentTypeFull
	hdrs[headerContentEncoding] = contentEncoding
	hdrs[headerAmznTarget] = target
	return c.tripper.Go(c.Path, r, rw)
}

// authorizeRequest authorizes cr with AAA or CloudAuth, depending on how the
// codec is configured and on hasAAAHeader. If neither AAA nor CloudAuth is
// configured, cr is returned unchanged and no check is made. Otherwise:
//   - AAA configured and the request carries the AAA header: the request is
//     authorized with AAA using the *aaa.ServiceContext held in cr.AuthCtx.
//   - Otherwise, if CloudAuth is configured: cr.AuthCtx is replaced with a
//     *cloudAuthContext and the request is authorized with CloudAuth. On a deny
//     or challenge result, cr is returned together with the error, so that
//     MarshalResponse can send the bearer challenge.
//   - Otherwise (AAA configured, header missing, no CloudAuth): rejected.
//
// Every other failure returns a nil request.
func (c RPCv1) authorizeRequest(cr *codec.Request, r *http.Request, hasAAAHeader bool) (*codec.Request, error) {
	switch {
	case c.AAA == nil && c.cloudAuth == nil:
		// No authorization mechanism is configured.
		return cr, nil

	case c.AAA != nil && hasAAAHeader:
		// The AAA header signals an AAA-signed request. Guard the assertion so
		// a missing or foreign AuthCtx is rejected instead of panicking.
		sctx, ok := cr.AuthCtx.(*aaa.ServiceContext)
		if !ok || sctx == nil {
			return nil, errors.New("Request is not authorized.  AAA service context is missing")
		}
		auth, err := c.AAA.AuthorizeRequest(sctx)
		if err != nil {
			return nil, err
		}
		if !auth.Authorized {
			return nil, errors.New("Request is not authorized.  Error message is " + auth.ErrorMessage)
		}
		log.Trace("Request is authorized with code", auth.AuthorizationCode)
		return cr, nil

	case c.cloudAuth != nil:
		caCtx := &cloudAuthContext{}
		cr.AuthCtx = caCtx
		auth, err := c.cloudAuth.AuthorizeRequest(r, cr.Service.ShapeName, cr.Operation.ShapeName)
		if err != nil {
			return nil, err
		}
		// NOTE(review): only Deny and Challenge are rejected; any other Result
		// value is treated as allowed. Confirm cloudauth has no further
		// non-allow results (fail-open risk).
		if auth.Result == cloudauth.ResultDeny || auth.Result == cloudauth.ResultChallenge {
			caCtx.bearerChallenge = auth.BearerChallenge
			return cr, errors.New("Request is not authorized")
		}
		log.Trace("RPCv1: Request is authorized via CloudAuth")
		return cr, nil

	default:
		// AAA is configured but the request lacks the AAA header, and there is no
		// CloudAuth to fall back to.
		return nil, errors.New("Request is not authorized.")
	}
}
