package grpcgateway

import (
	"a.yandex-team.ru/library/go/core/xerrors"
	"github.com/golang/protobuf/proto"
	"github.com/grpc-ecosystem/grpc-gateway/runtime"
	"io"
	"io/ioutil"
	"reflect"
)

// BinaryProtoMarshaller is similar to `github.com/grpc-ecosystem/grpc-gateway/runtime/ProtoMarshaller`
// but is able to instantiate nils similar to how `JSONPb` marshaller does it via `decodeNonProtoField` calls
type BinaryProtoMarshaller struct{}

// typeProtoMessage is the reflect.Type of the proto.Message interface itself
// (obtained via a pointer to it, since TypeOf on a nil interface yields nil).
var typeProtoMessage = reflect.TypeOf(new(proto.Message)).Elem()

// ContentType reports the MIME type used for binary proto payloads; it is
// always "application/octet-stream".
func (*BinaryProtoMarshaller) ContentType() string {
	const mimeType = "application/octet-stream"
	return mimeType
}

// Marshal serializes "value" into the binary proto wire format.
// "value" must implement proto.Message; any other value is rejected with an
// error and a nil byte slice.
func (*BinaryProtoMarshaller) Marshal(value interface{}) ([]byte, error) {
	if message, isMessage := value.(proto.Message); isMessage {
		return proto.Marshal(message)
	}
	return nil, xerrors.New("unable to marshal non proto field")
}

// Unmarshal decodes binary proto "data" into "value".
// A proto.Message is unmarshaled directly; any other value is handed to
// decodeNonProtoField, which follows pointers until it reaches a proto.Message.
func (*BinaryProtoMarshaller) Unmarshal(data []byte, value interface{}) error {
	if message, isMessage := value.(proto.Message); isMessage {
		return proto.Unmarshal(data, message)
	}
	return decodeNonProtoField(data, value)
}

// NewDecoder returns a Decoder which reads a single binary proto message from
// "reader".
//
// Binary proto has no message framing, so the first Decode call reads the whole
// reader (until EOF or a read error) and unmarshals it into "value". Once the
// reader has been consumed successfully, every later Decode call returns io.EOF.
// Without this, later calls would read an empty buffer and "successfully" decode
// an empty message forever, so callers that loop on Decode until io.EOF would
// never terminate.
//
// The returned Decoder is stateful and not safe for concurrent use.
func (marshaller *BinaryProtoMarshaller) NewDecoder(reader io.Reader) runtime.Decoder {
	drained := false
	return runtime.DecoderFunc(func(value interface{}) error {
		if drained {
			return io.EOF
		}
		buffer, err := ioutil.ReadAll(reader)
		if err != nil {
			// Leave drained unset so the caller sees the read error; the reader
			// was not fully consumed.
			return err
		}
		drained = true
		return marshaller.Unmarshal(buffer, value)
	})
}

// NewEncoder returns an Encoder that serializes each value with Marshal and
// writes the resulting bytes to "writer". The first error from either
// marshaling or writing is returned; nothing is written if marshaling fails.
func (marshaller *BinaryProtoMarshaller) NewEncoder(writer io.Writer) runtime.Encoder {
	encode := func(value interface{}) error {
		payload, err := marshaller.Marshal(value)
		if err == nil {
			_, err = writer.Write(payload)
		}
		return err
	}
	return runtime.EncoderFunc(encode)
}

// decodeNonProtoField unmarshals binary proto "data" into "v", where "v" is a
// pointer that leads (through one or more levels of indirection) to a type
// implementing proto.Message. Nil pointers found below the top level are
// allocated on the way down, mirroring the helper of the same name used by the
// JSONPb marshaller.
//
// It returns an error if v is not a pointer, if v is a nil pointer, or if
// following pointers never reaches a proto.Message.
func decodeNonProtoField(data []byte, v interface{}) error {
	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Ptr {
		return xerrors.Errorf("%T is not a pointer", v)
	}
	// rv came straight from an interface, so it is not addressable: a nil
	// top-level pointer cannot be replaced with a fresh allocation, and calling
	// rv.Set on it would panic.
	if rv.IsNil() {
		return xerrors.Errorf("%T is a nil pointer", v)
	}
	for rv.Kind() == reflect.Ptr {
		if rv.IsNil() {
			// Only reachable after at least one Elem() call, so rv is
			// addressable and settable here.
			rv.Set(reflect.New(rv.Type().Elem()))
		}
		if rv.Type().Implements(typeProtoMessage) {
			return proto.Unmarshal(data, rv.Interface().(proto.Message))
		}
		rv = rv.Elem()
	}
	return xerrors.Errorf("unable to unmarshal non proto field")
}
