package loadroutes

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io/ioutil"

	"code.justin.tv/common/alvin/internal/httpproto"
	"code.justin.tv/common/alvin/internal/httpproto/google_api"
	"github.com/golang/protobuf/proto"
	"github.com/golang/protobuf/protoc-gen-go/descriptor"
	"github.com/pkg/errors"
)

// FromServiceDescriptor returns the list of method specifications encoded by
// the provided ServiceDescriptor, including instructions for how to convert
// logical RPC calls to and from HTTP interactions.
//
// protoPackage is joined with the service and method names to form each
// spec's FullName as "/<protoPackage>.<Service>/<Method>". Methods that are
// nil, have no options, or for which getHttpRule returns an error are skipped
// rather than failing the whole service. An error is returned only when
// serviceDesc is nil.
func FromServiceDescriptor(protoPackage string, serviceDesc *descriptor.ServiceDescriptorProto) ([]*httpproto.MethodSpec, error) {
	if serviceDesc == nil {
		return nil, fmt.Errorf("no service descriptor specified")
	}

	svcname := protoPackage + "." + serviceDesc.GetName()
	var specs []*httpproto.MethodSpec
	for _, methodDesc := range serviceDesc.GetMethod() {
		// Check for nil before deriving anything (such as the full name)
		// from the method descriptor.
		if methodDesc == nil {
			continue
		}
		opts := methodDesc.GetOptions()
		if opts == nil {
			continue
		}

		// NOTE(review): any getHttpRule error is treated as "no HTTP
		// binding" and the method is skipped; confirm that malformed
		// annotations should not be reported to the caller instead.
		httpRule, err := getHttpRule(opts)
		if err != nil {
			continue
		}

		specs = append(specs, &httpproto.MethodSpec{
			FullName:   "/" + svcname + "/" + methodDesc.GetName(),
			InputType:  methodDesc.GetInputType(),
			OutputType: methodDesc.GetOutputType(),
			Routes:     routeSpecs(httpRule),
		})
	}

	return specs, nil
}

// routeSpecs converts an HttpRule into route specifications. The top-level
// rule is processed first, followed by each of its AdditionalBindings in
// order. Only that single level of nesting is examined: bindings nested
// inside an additional binding are not visited. Nil rules, and rules whose
// pattern yields an empty path, produce no route.
func routeSpecs(rules *google_api.HttpRule) []*httpproto.RouteSpec {
	var out []*httpproto.RouteSpec

	addRoute := func(r *google_api.HttpRule) {
		if r == nil {
			return
		}
		method, pattern := readPattern(r)
		if pattern == "" {
			return
		}
		out = append(out, &httpproto.RouteSpec{
			BodyField:   r.Body,
			HTTPMethod:  method,
			PathPattern: pattern,
		})
	}

	addRoute(rules)
	for _, binding := range rules.GetAdditionalBindings() {
		addRoute(binding)
	}
	return out
}

// readPattern extracts the HTTP verb and URI template from a single rule's
// pattern; any additional bindings on the rule are ignored. The standard
// pattern kinds map to "GET", "PUT", "POST", "DELETE" and "PATCH", while a
// custom pattern supplies its own verb via Kind. Both results are empty when
// the rule has no pattern, an unrecognized pattern type, or a nil
// pattern/custom value.
func readPattern(rule *google_api.HttpRule) (kind, path string) {
	switch p := rule.GetPattern().(type) {
	case *google_api.HttpRule_Get:
		if p != nil {
			return "GET", p.Get
		}
	case *google_api.HttpRule_Put:
		if p != nil {
			return "PUT", p.Put
		}
	case *google_api.HttpRule_Post:
		if p != nil {
			return "POST", p.Post
		}
	case *google_api.HttpRule_Delete:
		if p != nil {
			return "DELETE", p.Delete
		}
	case *google_api.HttpRule_Patch:
		if p != nil {
			return "PATCH", p.Patch
		}
	case *google_api.HttpRule_Custom:
		if p != nil && p.Custom != nil {
			return p.Custom.Kind, p.Custom.Path
		}
	}
	return "", ""
}

// FromProtoFileName reads method descriptions from a protobuf file
// descriptor. It loads the descriptor from the global registry in the
// github.com/golang/protobuf/proto package, using filename as the lookup
// key, and returns the specifications for every service the file declares.
func FromProtoFileName(filename string) ([]*httpproto.MethodSpec, error) {
	load := func() (*descriptor.FileDescriptorProto, error) {
		return getProtoFileDescriptorFromFileName(filename)
	}
	return getMethodSpecs(load)
}

// TwirpServer is the method set this package needs from a Twirp server. It is
// redefined here because importing restclient from here would create an
// import cycle.
type TwirpServer interface {
	// ServiceDescriptor returns the gzip-compressed, serialized
	// FileDescriptorProto of the server's .proto file, together with the
	// index of this server's service within that file's service list.
	ServiceDescriptor() ([]byte, int)
}

// FromTwirpServer reads method descriptions from a Twirp server. It decodes
// the server's gzip-compressed file descriptor and returns the specifications
// for the single service at the index the server reports; other services in
// the same file are ignored.
func FromTwirpServer(twirpServer TwirpServer) ([]*httpproto.MethodSpec, error) {
	compressed, idx := twirpServer.ServiceDescriptor()
	load := func() (*descriptor.FileDescriptorProto, error) {
		return getProtoFileDescriptor(compressed)
	}
	return getServiceMethodSpecs(idx, load)
}

// getMethodSpecs obtains a protobuf file descriptor from getDescriptor and
// returns the method specifications of every service it declares, in
// declaration order. The first error from getDescriptor or from any service
// aborts the whole call.
//
// Taking the descriptor through a function rather than directly looks odd,
// but it allows better test coverage.
func getMethodSpecs(getDescriptor func() (*descriptor.FileDescriptorProto, error)) ([]*httpproto.MethodSpec, error) {
	fd, err := getDescriptor()
	if err != nil {
		return nil, err
	}

	pkg := fd.GetPackage()
	var all []*httpproto.MethodSpec
	for _, svc := range fd.GetService() {
		methods, svcErr := FromServiceDescriptor(pkg, svc)
		if svcErr != nil {
			return nil, svcErr
		}
		all = append(all, methods...)
	}
	return all, nil
}

// getServiceMethodSpecs calls the provided function to obtain a protobuf file
// descriptor, retrieves the service descriptor at the specified index, and
// returns the method specifications that only that service describes.
//
// An error is returned if getDescriptor fails or if index does not refer to
// a service in the descriptor (including negative indexes).
func getServiceMethodSpecs(index int, getDescriptor func() (*descriptor.FileDescriptorProto, error)) ([]*httpproto.MethodSpec, error) {
	fd, err := getDescriptor()
	if err != nil {
		return nil, err
	}

	services := fd.GetService()
	// Reject negative indexes too; indexing the slice with one would panic.
	if index < 0 || index >= len(services) {
		return nil, errors.Errorf("failed to load http routes from file descriptor, no service at index %d", index)
	}

	return FromServiceDescriptor(fd.GetPackage(), services[index])
}

// getProtoFileDescriptor returns the FileDescriptorProto message deserialized
// from gzd, which must be a gzip-compressed, serialized FileDescriptorProto.
// Decompression and unmarshaling failures are returned with the underlying
// cause attached.
func getProtoFileDescriptor(gzd []byte) (*descriptor.FileDescriptorProto, error) {
	rd, err := gzip.NewReader(bytes.NewReader(gzd))
	if err != nil {
		return nil, errors.Wrap(err, "descriptor could not be decompressed")
	}
	// Close's error is not checked; read failures are reported by ReadAll.
	defer rd.Close()

	desc, err := ioutil.ReadAll(rd)
	if err != nil {
		return nil, errors.Wrap(err, "descriptor could not be decompressed")
	}

	var fd descriptor.FileDescriptorProto
	if err := proto.Unmarshal(desc, &fd); err != nil {
		return nil, errors.Wrap(err, "descriptor could not be deserialized")
	}

	return &fd, nil
}

// getProtoFileDescriptorFromFileName looks up the compressed descriptor
// registered for filename and decodes it into a FileDescriptorProto. Lookup
// errors are returned unchanged.
func getProtoFileDescriptorFromFileName(filename string) (*descriptor.FileDescriptorProto, error) {
	compressed, lookupErr := getProtoFileDescriptorBytes(filename)
	if lookupErr != nil {
		return nil, lookupErr
	}
	return getProtoFileDescriptor(compressed)
}

// getProtoFileDescriptorBytes returns the serialized, compressed bytes of the
// FileDescriptorProto registered for filename via proto.FileDescriptor, or an
// error if no descriptor is registered under that name.
func getProtoFileDescriptorBytes(filename string) ([]byte, error) {
	if gzd := proto.FileDescriptor(filename); gzd != nil {
		return gzd, nil
	}
	return nil, errors.Errorf("descriptor for file %q not found", filename)
}
