package restclient

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/url"
	"reflect"
	"strconv"
	"strings"

	"code.justin.tv/release/trace/rpc/alvin/internal/httpproto"
	"code.justin.tv/release/trace/rpc/alvin/internal/httpproto/loadroutes"
	"code.justin.tv/release/trace/rpc/alvin/internal/httproute"
	"github.com/golang/protobuf/jsonpb"
	"github.com/golang/protobuf/proto"
)

// RoundTripperFromProtos builds an http.RoundTripper intended for use with
// generated Twirp clients. Instead of sending Twirp-format requests as-is, it
// transcodes each one into a JSON/HTTP request according to a set of mapping
// rules. Generated Twirp client code can therefore talk to a variety of
// non-Twirp APIs, which may be unwilling or unable to provide native Twirp
// endpoints.
//
// The mapping rules are annotations in .proto files. The "google.api.http"
// method annotation describes the HTTP method and URI path used for an RPC
// method. The JSON format of requests, and the expected format of responses,
// can be controlled with the "json_name" field annotation and with the
// special types defined in google/protobuf/struct.proto and
// google/protobuf/wrappers.proto.
//
// The HTTP mapping rules are described in google/api/http.proto, available
// at:
//     https://github.com/googleapis/googleapis/blob/master/google/api/http.proto
//     https://git-aws.internal.justin.tv/release/trace/blob/master/rpc/alvin/internal/httpproto/google_api/http.proto
//
// HTML documentation for the HTTP mapping rules:
//     https://cloud.google.com/service-management/reference/rpc/google.api#httprule
//
// Further details on the JSON mapping rules:
//     https://developers.google.com/protocol-buffers/docs/proto3#json
//
// The returned RoundTripper handles every RPC method of every service defined
// in protoFilenames. Each name is relative to the include path used when
// protoc was invoked with the --go_out flag. It usually appears as the first
// argument to proto.RegisterFile in the init function near the bottom of the
// generated .pb.go file.
//
// An error is returned in any of these cases:
//   - a file cannot be loaded;
//   - the same method is defined more than once;
//   - a method cannot be turned into a route.
//
// If base is nil, http.DefaultTransport is used to send requests over the
// network.
func RoundTripperFromProtos(base http.RoundTripper, protoFilenames []string) (http.RoundTripper, error) {
	rt := &roundTripper{
		base:   base,
		routes: make(map[string]*routeSpec),
	}
	if rt.base == nil {
		rt.base = http.DefaultTransport
	}

	for _, filename := range protoFilenames {
		methods, loadErr := loadroutes.FromProtoFileName(filename)
		if loadErr != nil {
			return nil, loadErr
		}

		for _, m := range methods {
			name := m.FullName
			if _, dup := rt.routes[name]; dup {
				return nil, fmt.Errorf("restclient: duplicate entry for method %q in file %q", name, filename)
			}

			spec, specErr := newRouteSpec(m)
			if specErr != nil {
				return nil, specErr
			}
			rt.routes[name] = spec
		}
	}

	return rt, nil
}

// newRouteSpec turns a method description loaded from a .proto file into a
// routeSpec. That is the form used later when building REST requests and
// decoding their responses.
//
// Only the method's first HTTP mapping is used. An error is returned in any of
// these cases:
//   - the method has no mapping;
//   - the mapping's path pattern fails to parse;
//   - the input or output message type is not registered with the proto
//     package.
func newRouteSpec(method *httpproto.MethodSpec) (*routeSpec, error) {
	routes := method.Routes
	if len(routes) < 1 {
		return nil, fmt.Errorf("restclient: method %q has no HTTP mappings", method.FullName)
	}
	first := routes[0]

	tmpl, err := httproute.ParseTemplate(first.PathPattern)
	if err != nil {
		return nil, err
	}

	// Type names in the method description may carry a leading "." (fully
	// qualified form), which the proto type registry lookup does not use.
	lookup := func(name string) reflect.Type {
		return proto.MessageType(strings.TrimPrefix(name, "."))
	}

	in := lookup(method.InputType)
	if in == nil {
		return nil, fmt.Errorf("restclient: unknown protobuf type %q for method %q", method.InputType, method.FullName)
	}
	out := lookup(method.OutputType)
	if out == nil {
		return nil, fmt.Errorf("restclient: unknown protobuf type %q for method %q", method.OutputType, method.FullName)
	}

	return &routeSpec{
		httpMethod:  first.HTTPMethod,
		uriTemplate: tmpl,
		bodyField:   first.BodyField,
		input:       in,
		output:      out,
	}, nil
}

// routeSpec holds the pre-processed HTTP mapping of a single RPC method. It is
// built by newRouteSpec from the method's first HTTP mapping (Routes[0]).
type routeSpec struct {
	// httpMethod is the HTTP method used for the REST request.
	httpMethod  string
	// uriTemplate is the parsed URI path pattern. Its variables are filled
	// from fields of the input message.
	uriTemplate *httproute.Template
	// bodyField selects what goes in the JSON request body: "*" for the whole
	// input message, a field name for just that field, or "" for no body.
	bodyField   string

	// input and output are the registered Go types of the request and
	// response messages, as returned by proto.MessageType. newPointer
	// allocates fresh values from them.
	input  reflect.Type
	output reflect.Type
}

// twirpRPCMethod extracts the fully-qualified RPC method name from a Twirp
// URI path. It strips the "/v2" or "/twirp" routing prefix and keeps the
// leading slash of the remainder, so "/twirp/pkg.Svc/Method" becomes
// "/pkg.Svc/Method". Paths without one of those prefixes (followed by a
// slash) yield the empty string.
func twirpRPCMethod(path string) string {
	for _, routePrefix := range [...]string{"/v2", "/twirp"} {
		if strings.HasPrefix(path, routePrefix+"/") {
			return path[len(routePrefix):]
		}
	}
	return ""
}

// roundTripper is the http.RoundTripper returned by RoundTripperFromProtos.
// It transcodes Twirp requests into JSON/HTTP requests.
type roundTripper struct {
	// base sends the transcoded REST requests. RoundTripperFromProtos
	// substitutes http.DefaultTransport when the caller passes nil.
	base   http.RoundTripper
	// routes is keyed by each method's FullName. RoundTrip looks entries up
	// using the name that twirpRPCMethod derives from the request path.
	routes map[string]*routeSpec
}

// Compile-time check that *roundTripper implements http.RoundTripper.
var _ http.RoundTripper = (*roundTripper)(nil)

// RoundTrip converts a Twirp-format HTTP request into a REST-inspired
// JSON/HTTP request, as described by the mapping rules provided when the
// roundTripper was created. It then converts the REST response back into the
// form the Twirp client expects.
//
// Some failures prevent a REST response from being obtained at all: an
// unknown route, an undecodable request, or a transport error. These are
// returned as errors. Once the REST API has answered, the result is always
// reported as an *http.Response. Non-2xx statuses and response bodies that
// cannot be decoded become Twirp JSON error responses, so the caller can
// inspect them in a higher-level error. The returned response's Request field
// is set to req in every case.
func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	// Quite a few steps are involved here:
	//
	//   1. Decode the initial RPC from the wire.
	//     A. Determine the logical RPC method from the request.
	//       i. For Twirp requests, remove the "/v2" or "/twirp" path prefix.
	//     B. Determine the expected body type for that method.
	//     C. Deserialize the body into the input protobuf message.
	//
	//   2. Issue a REST-style request.
	//     A. Populate the URI path template using the input message.
	//     B. Populate the request body using the input message.
	//     C. Convert any remaining fields into URI query parameters.
	//     D. Read the response body as JSON into an output struct value.
	//
	//   3. Craft a convincing *http.Response value for the caller.
	//     A. Copy over headers and trailers.
	//     B. Marshal the output message according to expected content-type.
	//       i. Twirp expects this to match the request's content-type.
	//     C. Set the body, content-type, content-length.

	route, input, err := rt.decodeRequest(req)
	if err != nil {
		return nil, err
	}

	// A few additional details on how the input message is handled during the
	// conversion to a REST-style request:
	//   - Fields referenced in the URI path template must be primitive fields.
	//   - No field on the path to those fields may be a "repeated" field. The
	//     same applies to the body field, if present.
	//   - Fields referenced in the URI path are cleared, and so is the field
	//     referenced for the body. The input message is therefore modified.
	//
	// When converting the remaining fields to URI query parameters, proto3
	// fields with their default (zero) value are not serialized. Messages
	// from google/protobuf/wrappers.proto allow sending zero values in the
	// query string. Again, "repeated" fields are forbidden except on
	// primitive types or wrappers.proto types.

	output := newPointer(route.output)
	restResp, err := rt.invokeREST(req.Context(), route, input, output, req.URL, req.Header)

	// A *bodyErr means a response arrived from the network but we could not
	// make sense of its body. Keep the raw bytes so the error response can
	// show them. Any other error means there is no usable response, so report
	// it directly. A flag records the failure rather than a nil check on
	// invalidBody, so a failure whose body slice happens to be nil is still
	// reported as an error.
	var (
		invalidBody []byte
		bodyInvalid bool
	)
	if err != nil {
		e, ok := err.(*bodyErr)
		if !ok || restResp == nil {
			return nil, err
		}
		invalidBody, bodyInvalid = e.body, true
	}

	var resp *http.Response
	if sc := restResp.StatusCode; sc < 200 || 299 < sc || bodyInvalid {
		resp = twirpErrorResponse(restResp, invalidBody)
	} else if resp, err = createResponse(output, req.Header.Get("Content-Type"), restResp); err != nil {
		return nil, err
	}
	// Success and error responses alike report the request they answer.
	resp.Request = req
	return resp, nil
}

// newPointer allocates a new zero value of the type described by t and
// returns a pointer to it, boxed in an interface{}. The Type values stored by
// this package are often pointer types already (such as *SomeMessage). For
// that reason, one level of pointer indirection is stripped before
// allocating: both T and *T produce a fresh *T.
func newPointer(t reflect.Type) interface{} {
	elem := t
	if t.Kind() == reflect.Ptr {
		elem = t.Elem()
	}
	ptr := reflect.New(elem)
	return ptr.Interface()
}

// bodyErr reports a REST response body that the caller should see in raw
// form, for example because it could not be read or decoded. body holds
// whatever bytes were read, and err holds the underlying failure.
type bodyErr struct {
	err  error
	body []byte
}

// Error describes the underlying failure. The body bytes are not included.
func (e *bodyErr) Error() string {
	return fmt.Sprintf("restclient: internal error, unexpected response body format: %v", e.err)
}

// Cause returns the underlying error. It follows the "causer" convention used
// by github.com/pkg/errors.
func (e *bodyErr) Cause() error {
	return e.err
}

// Unwrap returns the underlying error, so that the standard library's
// errors.Is, errors.As and errors.Unwrap (Go 1.13+) can see through a
// *bodyErr.
func (e *bodyErr) Unwrap() error {
	return e.err
}

// readBody reads all of body and decodes it into msg, which must be a
// protobuf message. contentType selects the decoding:
//   - "application/json": unknown fields are ignored;
//   - "application/protobuf": binary protobuf.
//
// A nil body is not an error and leaves msg untouched. A google.protobuf.Empty
// msg accepts any body. readBody does not close body.
//
// A failure to read or decode the body is reported as a *bodyErr that carries
// the bytes read so far.
func readBody(msg interface{}, contentType string, body io.ReadCloser) error {
	if body == nil {
		return nil
	}

	// Drain the body up front, even when it will be ignored below. This
	// allows reuse of HTTP/1.1 Keep-Alive connections.
	data, readErr := ioutil.ReadAll(body)
	if readErr != nil {
		return &bodyErr{body: data, err: readErr}
	}

	pb, isProto := msg.(proto.Message)
	if !isProto {
		return fmt.Errorf("restclient: non-protobuf message type %T", msg)
	}

	// The body is probably "" or "{}". It doesn't seem worthwhile to check
	// whether it is something completely different.
	if proto.MessageName(pb) == "google.protobuf.Empty" {
		return nil
	}

	var decodeErr error
	switch contentType {
	case "application/json":
		u := jsonpb.Unmarshaler{AllowUnknownFields: true}
		decodeErr = u.Unmarshal(bytes.NewReader(data), pb)
	case "application/protobuf":
		decodeErr = proto.Unmarshal(data, pb)
	default:
		return fmt.Errorf("restclient: unknown content type %q", contentType)
	}
	if decodeErr != nil {
		return &bodyErr{body: data, err: decodeErr}
	}
	return nil
}

// marshalBody encodes the protobuf message msg. contentType selects the
// format: "application/json" (jsonpb with default options) or
// "application/protobuf" (binary). Any other content type is an error. So is a
// msg that is not a proto.Message.
func marshalBody(msg interface{}, contentType string) ([]byte, error) {
	pb, isProto := msg.(proto.Message)
	if !isProto {
		return nil, fmt.Errorf("restclient: non-protobuf message type %T", msg)
	}

	if contentType == "application/protobuf" {
		return proto.Marshal(pb)
	}
	if contentType != "application/json" {
		return nil, fmt.Errorf("restclient: unknown content type %q", contentType)
	}

	var out bytes.Buffer
	m := jsonpb.Marshaler{}
	if err := m.Marshal(&out, pb); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}

// decodeRequest reads a Twirp-formatted request. It returns the routeSpec of
// the RPC method the request targets, together with a freshly allocated input
// message decoded from the request body using the request's Content-Type.
// Only POST requests whose path names a known method are accepted. The request
// body is closed in every case.
func (rt *roundTripper) decodeRequest(req *http.Request) (*routeSpec, interface{}, error) {
	// RoundTrip must close the request body, even when it returns an error.
	if body := req.Body; body != nil {
		defer body.Close()
	}

	if req.Method != "POST" {
		return nil, nil, fmt.Errorf("restclient: unexpected method %q", req.Method)
	}

	name := twirpRPCMethod(req.URL.Path)
	if name == "" {
		return nil, nil, fmt.Errorf("restclient: no route for uri path %q", req.URL.Path)
	}

	spec := rt.routes[name]
	if spec == nil {
		return nil, nil, fmt.Errorf("restclient: no route for method %q", name)
	}

	msg := newPointer(spec.input)
	if err := readBody(msg, req.Header.Get("Content-Type"), req.Body); err != nil {
		return nil, nil, err
	}
	return spec, msg, nil
}

// encodeURIPath executes the URI path template tmpl. Each template variable
// is resolved by looking up the field of that name on the input message.
//
// Referenced fields must hold primitive values. The following are rejected
// with an error:
//   - message-typed (struct) fields;
//   - unset fields, i.e. nil pointers;
//   - any failure reported by lookupField;
//   - any failure reported by the template itself.
//
// On success, every field referenced by the template is reset to its zero
// value, which modifies the input message. Fields which were not referenced
// are not modified. On error, no field is cleared.
func encodeURIPath(input interface{}, tmpl *httproute.Template) (string, error) {
	fields := make(map[string]reflect.Value)

	pb, _ := input.(proto.Message)

	// The Execute callback can only signal success or failure, so remember
	// the first detailed error here and report it after Execute returns.
	var innerErr error
	fail := func(err error) (string, bool) {
		if innerErr == nil {
			innerErr = err
		}
		return "", false
	}

	path, err := tmpl.Execute(func(key string) (string, bool) {
		val, err := lookupField(pb, key)
		if err != nil {
			return fail(err)
		}
		fields[key] = val

		// This is where we'd handle well-known types, but it's not yet clear
		// why someone would need support for them in a URI path. The purpose
		// of the wrapper types is to allow proto3 messages to have unset
		// fields. Unset fields don't have a place in the URI path, since the
		// templates serialize all referenced fields regardless of their
		// value. The structural values don't have a well-defined
		// single-string format. That leaves Duration and (RFC3339)
		// Timestamp, which seem to be unlikely candidates for use in URI
		// paths.
		//
		// As with most decisions in this package: if you have a
		// counterexample, please open an issue describing the use case.

		if val.Kind() == reflect.Ptr {
			if val.IsNil() {
				// Elem of a nil pointer is the invalid zero reflect.Value,
				// which has no meaningful string form. Reject it rather than
				// handing it to stringFormat.
				return fail(fmt.Errorf("restclient: unset field %q not valid for lookup", key))
			}
			val = val.Elem()
		}
		if val.Kind() == reflect.Struct {
			return fail(fmt.Errorf("restclient: non-primitive field of type %s not valid for lookup", val.Type().Name()))
		}
		return stringFormat(val), true
	})
	if innerErr != nil {
		return "", innerErr
	}
	if err != nil {
		return "", err
	}

	for _, val := range fields {
		val.Set(reflect.Zero(val.Type()))
	}

	return path, nil
}

// encodeBody splits the input message between a JSON request body and a set of
// URI query parameters, according to bodyField:
//   - "*": the whole message is serialized as JSON and no query values are
//     returned.
//   - a field name: that field alone becomes the JSON body, and the remaining
//     fields are returned as url.Values. A nil message field is sent as the
//     JSON literal null.
//   - "": there is no body (nil is returned), and every field is returned as
//     url.Values.
//
// When a single field is used as the body and is non-nil, that field is
// cleared in input. The input message is therefore modified.
func encodeBody(input interface{}, bodyField string) ([]byte, url.Values, error) {
	msg, _ := input.(proto.Message)

	switch bodyField {
	case "*":
		whole, err := marshalBody(msg, "application/json")
		if err != nil {
			return nil, nil, err
		}
		return whole, nil, nil
	case "":
		query, err := leafFieldValues(msg)
		if err != nil {
			return nil, nil, err
		}
		return nil, query, nil
	}

	field, err := lookupField(msg, bodyField)
	if err != nil {
		return nil, nil, err
	}

	var body []byte
	if field.Kind() == reflect.Ptr && field.IsNil() {
		// jsonpb.Marshaler.Marshal (via marshalBody) panics on a typed nil
		// protobuf message. A nil message is serialized as JSON null instead.
		body = []byte(`null`)
	} else {
		if body, err = marshalBody(field.Interface(), "application/json"); err != nil {
			return nil, nil, err
		}
		field.Set(reflect.Zero(field.Type()))
	}

	query, err := leafFieldValues(msg)
	if err != nil {
		return nil, nil, err
	}
	return body, query, nil
}

// invokeREST makes an HTTP request to the JSON/HTTP API described by route.
// The request goes to the scheme, host and user info of target. input is the
// logical RPC request. Its fields fill the URI path template first, then the
// JSON request body (per route.bodyField). Any remaining fields become the URI
// query string. Fields used for the path or the body may be cleared, so this
// function may modify input.
//
// The REST request carries a copy of header, with two exceptions:
//   - Content-Type is "application/json" when there is a body, and absent
//     otherwise.
//   - Accept is always "application/json", because only JSON responses are
//     decoded. A Twirp client's own Accept value may name a different wire
//     format, such as application/protobuf.
//
// For a 2xx response, the body is decoded as JSON into output. For any other
// status the body is not decoded. It is returned verbatim inside a *bodyErr so
// the caller can include it in an error report. A body that cannot be read or
// decoded is also reported as a *bodyErr. In every *bodyErr case, the
// *http.Response is returned alongside the error. The response Body is
// consumed and closed before this function returns.
func (rt *roundTripper) invokeREST(
	ctx context.Context, route *routeSpec, input, output interface{},
	target *url.URL, header http.Header) (*http.Response, error) {

	restPath, err := encodeURIPath(input, route.uriTemplate)
	if err != nil {
		return nil, err
	}

	restBody, query, err := encodeBody(input, route.bodyField)
	if err != nil {
		return nil, err
	}

	// Copy the URL, change the path to the executed template
	restURL := &url.URL{
		Scheme:   target.Scheme,
		Host:     target.Host,
		User:     target.User,
		Path:     restPath,
		RawQuery: query.Encode(),
	}

	restReq, err := http.NewRequest(route.httpMethod, restURL.String(), bytes.NewReader(restBody))
	if err != nil {
		return nil, err
	}
	restReq = restReq.WithContext(ctx)
	for k, vs := range header {
		for _, v := range vs {
			restReq.Header.Add(k, v)
		}
	}
	restReq.Header.Del("Content-Type")
	if restBody != nil {
		restReq.Header.Set("Content-Type", "application/json")
	}
	restReq.Header.Set("Accept", "application/json")

	restResp, err := rt.base.RoundTrip(restReq)
	if err != nil {
		return nil, err
	}
	defer restResp.Body.Close()

	if sc := restResp.StatusCode; sc < 200 || 299 < sc {
		// Error bodies rarely match the output message. Unknown JSON fields
		// are allowed, though, so they often "decode" successfully anyway,
		// and their content would be lost. Hand back the raw bytes instead so
		// the caller can report them.
		buf, readErr := ioutil.ReadAll(restResp.Body)
		if readErr == nil {
			readErr = fmt.Errorf("restclient: unexpected http status %d", sc)
		}
		return restResp, &bodyErr{body: buf, err: readErr}
	}

	// Having a JSON response body is part of what makes an API a good fit for
	// use with this package.
	//
	// TODO: support arbitrary bodies via a special HttpBody protobuf type.
	if err := readBody(output, "application/json", restResp.Body); err != nil {
		return restResp, err
	}

	return restResp, nil
}

// createResponse returns a *http.Response for use with a Twirp client. It
// serializes the provided output message for use in the response Body, and
// copies many other fields from restResp. It can encode messages in JSON or
// the binary protobuf encoding, depending on the specified contentType.
func createResponse(output interface{}, contentType string, restResp *http.Response) (*http.Response, error) {
	// The call was basically a success. Return a successful response with the
	// data from the output value.
	resp := baseResponse(restResp)

	// Twirp expects the response Content-Type to match that of the request it
	// issued.
	resp.Header.Set("Content-Type", contentType)
	body, err := marshalBody(output, contentType)
	if err != nil {
		return nil, err
	}
	resp.Body = ioutil.NopCloser(bytes.NewReader(body))
	resp.ContentLength = int64(len(body))

	return resp, nil
}

// twirpErrorResponse builds the Twirp-style JSON error response returned to
// the caller when the REST API answered with a non-2xx status or with a body
// that could not be decoded. Protocol details, headers and trailers are
// copied from src.
//
// src's HTTP status is mapped to a Twirp error code. A status with no mapping
// becomes code "internal" with status 500. The error message quotes up to
// 30 bytes of invalidBody, and the "prefix" metadata quotes up to 300 bytes.
// The caller must fill in Request.
func twirpErrorResponse(src *http.Response, invalidBody []byte) *http.Response {
	resp := baseResponse(src)

	type twerrJSON struct {
		Code string            `json:"code"`
		Msg  string            `json:"msg"`
		Meta map[string]string `json:"meta,omitempty"`
	}
	var twerr twerrJSON

	twerr.Code = map[int]string{
		400: "invalid_argument", // or "out_of_range"
		401: "unauthenticated",
		403: "permission_denied", // or "resource_exhausted"
		404: "not_found",         // or "bad_route"
		408: "deadline_exceeded", // or "canceled"
		409: "already_exists",    // or "conflict"
		412: "failed_precondition",
		429: "resource_exhausted",
		500: "internal", // or "unknown" or "data_loss"
		501: "unimplemented",
		503: "unavailable",
	}[src.StatusCode]

	resp.StatusCode = src.StatusCode
	if twerr.Code == "" {
		twerr.Code = "internal"
		resp.StatusCode = 500
	}
	// Status conventionally includes the numeric code (e.g. "404 Not Found").
	// This matches the "200 OK" that baseResponse sets.
	resp.Status = strconv.Itoa(resp.StatusCode) + " " + http.StatusText(resp.StatusCode)

	twerr.Msg = fmt.Sprintf("status code %d, content length %d, starting with %q",
		src.StatusCode, len(invalidBody), prefix(invalidBody, 30))
	twerr.Meta = map[string]string{
		"prefix": strconv.Quote(string(prefix(invalidBody, 300))),
	}

	// Twirp error responses are always JSON (instead of protobuf)
	resp.Header.Set("Content-Type", "application/json")

	// twerrJSON contains only strings and a map of strings, so json.Marshal
	// cannot fail here.
	body, _ := json.Marshal(&twerr)
	resp.Body = ioutil.NopCloser(bytes.NewReader(body))
	resp.ContentLength = int64(len(body))

	return resp
}

// prefix returns at most the first n bytes of p. The result shares p's
// backing array. If p is shorter than n, p itself is returned.
func prefix(p []byte, n int) []byte {
	if len(p) < n {
		return p
	}
	return p[:n]
}

// baseResponse returns a new "200 OK" *http.Response. It copies the following
// from src: the protocol version, the Close flag, the TLS state, and the
// headers and trailers. Header and trailer keys are canonicalized, and their
// value slices are copied, so later edits to the result do not affect src.
//
// ContentLength starts at -1 (unknown). The caller must fill in the Request
// and Body fields.
func baseResponse(src *http.Response) *http.Response {
	clone := func(h http.Header) http.Header {
		out := make(http.Header)
		for key, vals := range h {
			out[http.CanonicalHeaderKey(key)] = append([]string(nil), vals...)
		}
		return out
	}

	return &http.Response{
		Status:        "200 OK",
		StatusCode:    200,
		Proto:         src.Proto,
		ProtoMajor:    src.ProtoMajor,
		ProtoMinor:    src.ProtoMinor,
		Header:        clone(src.Header),
		Trailer:       clone(src.Trailer),
		ContentLength: -1,
		Close:         src.Close,
		TLS:           src.TLS,
	}
}
