package proto

import (
	"errors"
	"reflect"
	"strings"

	"github.com/golang/protobuf/proto"
	"google.golang.org/genproto/protobuf/field_mask"
)

// Sentinel errors reported by MaskedMerge. They are returned unwrapped, so
// callers can compare against them directly.
var (
	// ErrTypeMismatch signals that dst and src passed to MaskedMerge are not
	// the same message type.
	ErrTypeMismatch = errors.New("can't merge protobuffers of different types")

	// ErrNonLastRepeatedField signals a FieldMask path that descends past a
	// repeated field; repeated fields may only appear as the final path
	// component.
	ErrNonLastRepeatedField = errors.New("repeated fields are only allowed at the end of a mask path")

	// ErrNonMessageSubField signals a FieldMask path that names a sub-field of
	// a field which is not itself a message.
	ErrNonMessageSubField = errors.New("can't specify a sub-field for a non-message field")
)

// copyProtoProperty copies the field addressed by the FieldMask path fieldName
// from src into dst. The field set is looked up from src's type and both
// messages are dereferenced as struct pointers, so dst is expected to have the
// same type as src (MaskedMerge checks this for the top-level call).
//
// fieldName is a dot-separated path. Each component is compared against the
// field's original .proto name (proto.Properties.OrigName), not its Go name.
// A component that matches no field is silently ignored and nil is returned.
//
// Only the final component is copied, and it is copied with a shallow reflect
// Set: pointer, slice and map fields end up sharing memory between dst and src.
// Every earlier component must name a singular message field. Otherwise
// ErrNonLastRepeatedField or ErrNonMessageSubField is returned.
func copyProtoProperty(dst, src proto.Message, fieldName string) error {
	parts := strings.SplitN(fieldName, ".", 2)
	props := proto.GetProperties(reflect.TypeOf(src).Elem())
	for _, prop := range props.Prop {
		if prop.OrigName != parts[0] {
			continue
		}
		srcField := reflect.ValueOf(src).Elem().FieldByName(prop.Name)
		dstField := reflect.ValueOf(dst).Elem().FieldByName(prop.Name)
		if len(parts) == 1 {
			dstField.Set(srcField)
			return nil
		}
		if prop.Repeated {
			return ErrNonLastRepeatedField
		}
		return copyNestedProtoProperty(dstField, srcField, parts[1])
	}
	return nil
}

// copyNestedProtoProperty descends into the message-typed fields dstField and
// srcField and copies the remaining path rest between them.
//
// Unset (nil pointer) sub-messages are handled instead of being dereferenced:
//   - A nil src sub-message is read as an empty message.
//   - A nil dst sub-message is allocated, but it is attached to dst only when
//     src's sub-message is set. Merging from an unset src into an unset dst
//     therefore leaves dst untouched.
func copyNestedProtoProperty(dstField, srcField reflect.Value, rest string) error {
	if _, ok := srcField.Interface().(proto.Message); !ok {
		return ErrNonMessageSubField
	}
	if _, ok := dstField.Interface().(proto.Message); !ok {
		return ErrNonMessageSubField
	}

	srcSet := !(srcField.Kind() == reflect.Ptr && srcField.IsNil())
	srcMsg := srcField
	if !srcSet {
		// Reading from a fresh zero message makes the copied sub-field take its
		// zero value, matching how a top-level zero field is copied.
		srcMsg = reflect.New(srcField.Type().Elem())
	}

	dstMsg := dstField
	attach := false
	if dstField.Kind() == reflect.Ptr && dstField.IsNil() {
		// Recurse into a scratch message even when it will be discarded, so the
		// rest of the path is still validated and errors are still reported.
		dstMsg = reflect.New(dstField.Type().Elem())
		attach = srcSet
	}

	if err := copyProtoProperty(dstMsg.Interface().(proto.Message), srcMsg.Interface().(proto.Message), rest); err != nil {
		return err
	}
	if attach {
		dstField.Set(dstMsg)
	}
	return nil
}

// MaskedMerge merges the src protobuffer into dst, writing only the fields
// named by the paths of mask.
// See: https://godoc.org/google.golang.org/genproto/protobuf/field_mask
//
// dst and src must point to the same message type. If they do not,
// ErrTypeMismatch is returned before anything is copied. Paths are applied in
// mask order, and processing stops at the first path that fails. When an error
// is returned, dst may therefore already contain the fields from earlier paths.
func MaskedMerge(dst, src proto.Message, mask *field_mask.FieldMask) error {
	if reflect.TypeOf(src).Elem() != reflect.TypeOf(dst).Elem() {
		return ErrTypeMismatch
	}

	paths := mask.GetPaths()
	for i := range paths {
		if err := copyProtoProperty(dst, src, paths[i]); err != nil {
			return err
		}
	}
	return nil
}
