package client

import (
	swat_temporal_client_proto "a.yandex-team.ru/infra/temporal/swat/client/proto"
	"fmt"
	commonpb "go.temporal.io/api/common/v1"
	"go.temporal.io/sdk/converter"
	"google.golang.org/protobuf/types/known/anypb"

	"google.golang.org/protobuf/proto"
	"reflect"
)

// SwatProtoJSONPayloadConverter is a payload converter built on top of the
// standard converter.ProtoJSONPayloadConverter that additionally supports
// slices of proto messages: ToPayload packs a non-empty slice of proto.Message
// values into a SliceOfMessages container, and FromPayload unpacks such a
// container back into a slice of message pointers.
type SwatProtoJSONPayloadConverter struct {
	// Introduced to work around https://community.temporal.io/t/custom-serialization-in-go/97
	// See https://st.yandex-team.ru/AWACS-868 for a bit more details
	//
	// Embedded so that every method not defined in this file is promoted from
	// the standard converter unchanged.
	*converter.ProtoJSONPayloadConverter
}

// castToNonEmptyProtoMessagesSlice reports whether value is a non-empty slice
// whose every element holds a proto.Message and, if so, returns those elements
// as a []proto.Message.
//
// Elements are checked individually with a dynamic type assertion, so
// []*SomeMessage, []proto.Message and []interface{} holding messages all
// qualify. It returns (nil, false) for nil, non-slice values, empty slices, and
// slices containing any element that is not a proto.Message.
func castToNonEmptyProtoMessagesSlice(value interface{}) ([]proto.Message, bool) {
	// reflect.ValueOf(nil) is the zero Value whose Kind is Invalid, so an
	// untyped nil is rejected here. The previous reflect.TypeOf(nil).Kind()
	// panicked, because TypeOf(nil) returns a nil Type.
	s := reflect.ValueOf(value)
	if s.Kind() != reflect.Slice || s.Len() == 0 {
		return nil, false
	}

	rv := make([]proto.Message, 0, s.Len())
	for i := 0; i < s.Len(); i++ {
		msg, ok := s.Index(i).Interface().(proto.Message)
		if !ok {
			return nil, false
		}
		rv = append(rv, msg)
	}
	return rv, true
}

// ToPayload encodes value into a payload.
//
// A non-empty slice whose every element is a proto.Message is first packed into
// a SliceOfMessages container (one Any per element, in order), because the
// embedded ProtoJSONPayloadConverter only encodes a single proto.Message. Any
// other value is handed to the embedded converter unchanged.
//
// It returns an error wrapping the anypb.New failure if an element cannot be
// packed, or whatever error the embedded converter returns.
func (c *SwatProtoJSONPayloadConverter) ToPayload(value interface{}) (*commonpb.Payload, error) {
	if messages, ok := castToNonEmptyProtoMessagesSlice(value); ok {
		container := &swat_temporal_client_proto.SliceOfMessages{
			Messages: make([]*anypb.Any, 0, len(messages)),
		}
		for i, m := range messages {
			packed, err := anypb.New(m)
			if err != nil {
				return nil, fmt.Errorf("pack slice element %d into any: %w", i, err)
			}
			container.Messages = append(container.Messages, packed)
		}
		value = container
	}

	return c.ProtoJSONPayloadConverter.ToPayload(value)
}

// FromPayload decodes payload into the value pointed to by valuePtr.
//
// If valuePtr points to a slice and the payload parses as a SliceOfMessages
// container (the shape ToPayload produces for slices of proto messages), each
// packed Any is unpacked into a newly allocated element of the slice's item
// type, and the resulting slice replaces *valuePtr. Every other destination,
// and any payload that does not parse as a container, is delegated to the
// embedded ProtoJSONPayloadConverter.
//
// For a container payload it returns an error wrapping:
//   - converter.ErrValuePtrIsNotPointer when the slice item type is not a pointer;
//   - converter.ErrUnableToDecode when the item type does not implement proto.Message;
//   - the Any unmarshal error when a packed element cannot be unpacked into the item type.
func (c *SwatProtoJSONPayloadConverter) FromPayload(payload *commonpb.Payload, valuePtr interface{}) error {
	ptrValue := reflect.ValueOf(valuePtr)
	if ptrValue.Kind() != reflect.Ptr || ptrValue.Elem().Kind() != reflect.Slice {
		// Only a slice can receive a SliceOfMessages, so do not probe for one
		// here. A proto message whose fields are all defaults serializes to
		// "{}", which also parses as an empty SliceOfMessages. Probing would
		// make such a payload fail instead of decoding into the message.
		return c.ProtoJSONPayloadConverter.FromPayload(payload, valuePtr)
	}

	container := &swat_temporal_client_proto.SliceOfMessages{}
	if err := c.ProtoJSONPayloadConverter.FromPayload(payload, container); err != nil {
		// Not a container: let the standard converter decode (or reject) it.
		return c.ProtoJSONPayloadConverter.FromPayload(payload, valuePtr)
	}

	sliceValue := ptrValue.Elem()
	sliceType := sliceValue.Type()
	itemType := sliceType.Elem()
	if itemType.Kind() != reflect.Ptr {
		return fmt.Errorf("type: %s: %w", itemType, converter.ErrValuePtrIsNotPointer)
	}
	// Check up front so a non-message item type returns an error instead of
	// panicking in the type assertion below.
	protoMessageType := reflect.TypeOf((*proto.Message)(nil)).Elem()
	if !itemType.Implements(protoMessageType) {
		return fmt.Errorf("type: %s: %w", itemType, converter.ErrUnableToDecode)
	}

	result := reflect.MakeSlice(sliceType, 0, len(container.Messages))
	for _, packed := range container.Messages {
		item := reflect.New(itemType.Elem())
		if err := packed.UnmarshalTo(item.Interface().(proto.Message)); err != nil {
			return fmt.Errorf("unmarshal any to %s: %w", itemType, err)
		}
		result = reflect.Append(result, item)
	}
	sliceValue.Set(result)
	return nil
}

// NewSwatProtoJSONPayloadConverter returns a SwatProtoJSONPayloadConverter
// that wraps a freshly created standard ProtoJSONPayloadConverter.
func NewSwatProtoJSONPayloadConverter() *SwatProtoJSONPayloadConverter {
	base := converter.NewProtoJSONPayloadConverter()
	return &SwatProtoJSONPayloadConverter{ProtoJSONPayloadConverter: base}
}

// GetDefaultSwatDataConverter returns a composite DataConverter intended to
// mirror the SDK's converter.GetDefaultDataConverter(). The difference is that
// the proto-JSON slot is filled by SwatProtoJSONPayloadConverter, which also
// handles slices of proto messages, instead of the standard
// converter.ProtoJSONPayloadConverter.
//
// Payload converters are listed in priority order:
//   - nil values;
//   - raw byte slices;
//   - proto messages and slices of them, as JSON;
//   - everything else, as plain JSON.
//
// NOTE(review): depending on the SDK version, the upstream default may also
// register converter.NewProtoPayloadConverter() (binary protobuf). It is not
// registered here, so payloads with that encoding presumably cannot be decoded
// by this converter. Confirm the omission is intended.
func GetDefaultSwatDataConverter() converter.DataConverter {
	return converter.NewCompositeDataConverter(
		converter.NewNilPayloadConverter(),
		converter.NewByteSlicePayloadConverter(),
		NewSwatProtoJSONPayloadConverter(),
		converter.NewJSONPayloadConverter(),
	)
}
