package prototools

import (
	"encoding/json"
	"fmt"
	"io/ioutil"
	"reflect"
	"testing"

	"github.com/jhump/protoreflect/desc"
	"github.com/jhump/protoreflect/desc/builder"
	"github.com/jhump/protoreflect/dynamic"
	"github.com/stretchr/testify/require"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protodesc"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/types/descriptorpb"
	"google.golang.org/protobuf/types/dynamicpb"
	"google.golang.org/protobuf/types/known/anypb"
	"google.golang.org/protobuf/types/known/emptypb"
	"google.golang.org/protobuf/types/known/structpb"
	"google.golang.org/protobuf/types/known/timestamppb"

	"a.yandex-team.ru/library/go/test/yatest"
	taskletv2 "a.yandex-team.ru/tasklet/api/v2"
	testutils "a.yandex-team.ru/tasklet/experimental/internal/test_utils"
)

// Test_validateMessage_Integration runs validateInterfaceMessageDescriptor
// against the descriptor set shipped as the build artifact
// tasklet/registry/common/tasklet-registry-common.protodesc. It is skipped or
// gated by testutils.EnsureArcadiaTest outside the Arcadia test environment.
func Test_validateMessage_Integration(t *testing.T) {
	testutils.EnsureArcadiaTest(t)

	descPath, err := yatest.BinaryPath("tasklet/registry/common/tasklet-registry-common.protodesc")
	require.NoError(t, err)
	raw, err := ioutil.ReadFile(descPath)
	require.NoError(t, err)

	fileSet := &descriptorpb.FileDescriptorSet{}
	require.NoError(t, proto.Unmarshal(raw, fileSet))
	files, err := protodesc.NewFiles(fileSet)
	require.NoError(t, err)

	// A nil wantErr means the name must validate cleanly; otherwise the
	// returned error must wrap wantErr.
	cases := []struct {
		name    protoreflect.FullName
		wantErr error
	}{
		{name: "google.protobuf.Empty"},
		{name: "google.protobuf.Timestamp"},
		{name: "google.protobuf.Struct"},
		{name: "not.existing.invalid.Message", wantErr: ErrNotFound},
		{name: "tasklet_examples.YamlToJsonInput"},
		{name: "tasklet_examples.FakeYaMakeInput"},
		// These names presumably refer to fields of FakeYaMakeInput rather
		// than to messages, hence ErrNotMessage.
		{name: "tasklet_examples.FakeYaMakeInput.branch", wantErr: ErrNotMessage},
		{name: "tasklet_examples.FakeYaMakeInput.arc_token", wantErr: ErrNotMessage},
	}

	for _, tc := range cases {
		t.Run(string(tc.name), func(t *testing.T) {
			err := validateInterfaceMessageDescriptor(files, tc.name)
			if tc.wantErr == nil {
				require.NoError(t, err)
				return
			}
			require.ErrorIs(t, err, tc.wantErr)
		})
	}
}

func Test_validateMessage(t *testing.T) {

	fds := &descriptorpb.FileDescriptorSet{}
	for _, fileDescriptor := range []protoreflect.FileDescriptor{
		anypb.File_google_protobuf_any_proto,
		emptypb.File_google_protobuf_empty_proto,
		timestamppb.File_google_protobuf_timestamp_proto,
		descriptorpb.File_google_protobuf_descriptor_proto,
		structpb.File_google_protobuf_struct_proto,
		taskletv2.File_tasklet_api_v2_extensions_proto,
		taskletv2.File_tasklet_api_v2_well_known_structures_proto,
	} {
		fds.File = append(fds.File, protodesc.ToFileDescriptorProto(fileDescriptor))
	}

	requiredSecretFieldOptions := &descriptorpb.FieldOptions{}
	proto.SetExtension(
		requiredSecretFieldOptions, taskletv2.E_TaskletField, &taskletv2.TaskletFieldOptions{
			Secret: &taskletv2.SecretOptions{Required: true},
		},
	)

	validSecret, err := desc.LoadMessageDescriptorForMessage((*taskletv2.SecretRef)(nil))
	require.NoError(t, err)
	validSecretBuilder, err := builder.FromMessage(validSecret)
	require.NoError(t, err)

	okMessage := builder.NewMessage("OkRequest").
		AddField(builder.NewField("id", builder.FieldTypeInt64())).
		AddField(builder.NewField("name", builder.FieldTypeString())).
		AddField(builder.NewField("secret", builder.FieldTypeMessage(validSecretBuilder)))

	fileBuilder := builder.
		NewFile("foo/bar.proto").
		SetPackageName("foo.bar").
		SetProto3(true).
		AddMessage(okMessage).
		AddMessage(
			builder.NewMessage("BadFieldKindRequest").
				AddField(
					builder.NewField("id", builder.FieldTypeInt64()).
						SetOptions(requiredSecretFieldOptions),
				).
				AddField(builder.NewField("name", builder.FieldTypeString())),
		).
		AddMessage(
			builder.NewMessage("BadFieldTypeRequest").
				AddField(
					builder.NewField("id", builder.FieldTypeMessage(okMessage)).
						SetOptions(requiredSecretFieldOptions),
				).
				AddField(builder.NewField("name", builder.FieldTypeString())),
		).
		AddMessage(
			builder.NewMessage("RepeatedSecret").
				AddField(
					builder.NewField("id", builder.FieldTypeMessage(validSecretBuilder)).
						SetOptions(requiredSecretFieldOptions).
						SetRepeated(),
				).
				AddField(builder.NewField("name", builder.FieldTypeString())),
		).
		AddMessage(
			builder.NewMessage("MapSecret").
				AddField(
					builder.NewMapField("id", builder.FieldTypeInt64(), builder.FieldTypeMessage(validSecretBuilder)).
						SetOptions(requiredSecretFieldOptions),
				).
				AddField(builder.NewField("name", builder.FieldTypeString())),
		)

	file, err := fileBuilder.Build()
	require.NoError(t, err)

	// NB: DEBUG: uncomment to print generated proto file
	// buf := bytes.NewBuffer(nil)
	// require.NoError(t, (&protoprint.Printer{}).PrintProtoFile(file, buf))
	// fmt.Print(buf.String())
	fds.File = append(fds.File, file.AsFileDescriptorProto())

	resolver, err := protodesc.NewFiles(fds)
	require.NoError(t, err)

	tests := []struct {
		message protoreflect.FullName
		err     error
		pat     string
	}{
		{
			"google.protobuf.Empty",
			nil,
			"",
		},
		{
			"google.protobuf.Timestamp",
			nil,
			"",
		},
		{
			"google.protobuf.Struct",
			nil,
			"",
		},
		{
			"tasklet.api.v2.GenericBinary",
			nil,
			"",
		},
		{
			"not.existing.invalid.Message",
			ErrNotFound,
			"",
		},
		{
			"foo.bar.OkRequest.id",
			ErrNotMessage,
			"",
		},
		{
			"foo.bar.OkRequest.secret",
			ErrNotMessage,
			"",
		},
		{
			"foo.bar",
			ErrNotFound,
			"",
		},
		{
			"foo.bar.OkRequest",
			nil,
			"",
		},
		{
			"foo.bar.BadFieldKindRequest",
			ErrBadAnnotation,
			"Field: \"foo.bar.BadFieldKindRequest.id\", Kind: int64",
		},
		{
			"foo.bar.BadFieldTypeRequest",
			ErrBadAnnotation,
			"ActualType: \"foo.bar.OkRequest\"",
		},
		{
			"foo.bar.RepeatedSecret",
			ErrBadAnnotation,
			"Repeated and map",
		},
		{
			"foo.bar.MapSecret",
			ErrBadAnnotation,
			"Repeated and map",
		},
	}
	for _, tt := range tests {
		t.Run(
			string(tt.message), func(t *testing.T) {
				err := validateInterfaceMessageDescriptor(resolver, tt.message)
				if tt.err != nil {
					require.ErrorIs(t, err, tt.err)
					require.Contains(t, err.Error(), tt.pat)
				} else {
					require.NoError(t, err)
				}
			},
		)
	}

}

// TestMakeIOMessages checks that MakeIOMessages resolves the InputMessage and
// OutputMessage names of an IOSimpleSchemaProto against a descriptor set. On
// success it must return messages with the expected full names. An unknown
// input message name must produce an error.
func TestMakeIOMessages(t *testing.T) {

	fds := &descriptorpb.FileDescriptorSet{}
	for _, fileDescriptor := range []protoreflect.FileDescriptor{
		anypb.File_google_protobuf_any_proto,
		emptypb.File_google_protobuf_empty_proto,
		timestamppb.File_google_protobuf_timestamp_proto,
		descriptorpb.File_google_protobuf_descriptor_proto,
		structpb.File_google_protobuf_struct_proto,
		taskletv2.File_tasklet_api_v2_extensions_proto,
		taskletv2.File_tasklet_api_v2_well_known_structures_proto,
	} {
		fds.File = append(fds.File, protodesc.ToFileDescriptorProto(fileDescriptor))
	}

	// wantInput and wantOutput are used only for their descriptors' full names.
	tests := []struct {
		name       string
		ioSchema   *taskletv2.IOSimpleSchemaProto
		wantInput  *dynamicpb.Message
		wantOutput *dynamicpb.Message
		wantErr    bool
	}{
		{
			"simple",
			&taskletv2.IOSimpleSchemaProto{
				SchemaHash:    "",
				InputMessage:  "google.protobuf.Empty",
				OutputMessage: "tasklet.api.v2.GenericBinary",
			},
			dynamicpb.NewMessage((&emptypb.Empty{}).ProtoReflect().Descriptor()),
			dynamicpb.NewMessage((&taskletv2.GenericBinary{}).ProtoReflect().Descriptor()),
			false,
		},
		{
			"bad_message",
			&taskletv2.IOSimpleSchemaProto{
				SchemaHash:    "",
				InputMessage:  "foo.invalid.bar",
				OutputMessage: "tasklet.api.v2.GenericBinary",
			},
			nil,
			nil,
			true,
		},
	}
	for _, tt := range tests {
		t.Run(
			tt.name, func(t *testing.T) {
				gotInput, gotOutput, err := MakeIOMessages(fds, tt.ioSchema)
				if tt.wantErr {
					require.Error(t, err)
					return
				}
				require.NoError(t, err)
				// testify takes (expected, actual); keep that order so failure
				// output labels the values correctly.
				require.Equal(t, tt.wantInput.Descriptor().FullName(), gotInput.Descriptor().FullName())
				require.Equal(t, tt.wantOutput.Descriptor().FullName(), gotOutput.Descriptor().FullName())
			},
		)
	}
}

// TestParseMessage checks that ParseMessage decodes ExecutionInput.SerializedData
// as the named message type from the descriptor set. The decoded message must
// have the same JSON form as the message the payload was serialized from.
// Payloads are built with jhump/protoreflect dynamic messages over a synthetic
// lok/kek.proto file (package lol.kek). A payload of the wrong type must be
// rejected.
func TestParseMessage(t *testing.T) {

	fds := &descriptorpb.FileDescriptorSet{}
	for _, fileDescriptor := range []protoreflect.FileDescriptor{
		anypb.File_google_protobuf_any_proto,
		emptypb.File_google_protobuf_empty_proto,
		structpb.File_google_protobuf_struct_proto,
		descriptorpb.File_google_protobuf_descriptor_proto,
		taskletv2.File_tasklet_api_v2_extensions_proto,
		taskletv2.File_tasklet_api_v2_well_known_structures_proto,
	} {
		fds.File = append(fds.File, protodesc.ToFileDescriptorProto(fileDescriptor))
	}

	// Field options marking a field as a required secret.
	requiredSecretFieldOptions := &descriptorpb.FieldOptions{}
	proto.SetExtension(
		requiredSecretFieldOptions, taskletv2.E_TaskletField, &taskletv2.TaskletFieldOptions{
			Secret: &taskletv2.SecretOptions{Required: true},
		},
	)

	secretRefDesc, err := desc.LoadMessageDescriptorForMessage((*taskletv2.SecretRef)(nil))
	require.NoError(t, err)

	anyMsg, err := desc.LoadMessageDescriptorForMessage((*anypb.Any)(nil))
	require.NoError(t, err)

	nestedMessage := builder.NewMessage("Nested").
		AddField(builder.NewField("id", builder.FieldTypeInt64())).
		AddField(builder.NewField("name", builder.FieldTypeString())).
		AddField(builder.NewField("internal", builder.FieldTypeImportedMessage(anyMsg)))

	packedMessage := builder.NewMessage("Packed").
		AddField(builder.NewField("str", builder.FieldTypeString())).
		AddField(builder.NewField("bl", builder.FieldTypeBool()))

	// mustMarshal unwraps a (bytes, error) pair while building fixtures in the
	// top-level test goroutine. On error it panics with context. Subtests
	// check errors with require instead, so one bad case fails only itself.
	mustMarshal := func(b []byte, err error) []byte {
		if err != nil {
			panic(fmt.Errorf("marshalling test fixture: %w", err))
		}
		return b
	}

	fileBuilder := builder.
		NewFile("lok/kek.proto").
		SetPackageName("lol.kek").
		SetProto3(true).
		AddMessage(nestedMessage).
		AddMessage(packedMessage).
		AddMessage(
			builder.NewMessage("Simple").
				AddField(builder.NewField("id", builder.FieldTypeInt64())).
				AddField(builder.NewField("name", builder.FieldTypeString())),
		).
		AddMessage(
			builder.NewMessage("WithNested").
				AddField(builder.NewField("name", builder.FieldTypeString())).
				AddField(
					builder.NewField("nested", builder.FieldTypeMessage(nestedMessage)),
				).
				AddField(
					builder.NewField("secret", builder.FieldTypeImportedMessage(secretRefDesc)).
						SetOptions(requiredSecretFieldOptions),
				),
		)
	file, err := fileBuilder.Build()
	require.NoError(t, err)
	// NB: DEBUG: uncomment to print generated proto file
	// buf := bytes.NewBuffer(nil)
	// require.NoError(t, (&protoprint.Printer{}).PrintProtoFile(file, buf))
	// fmt.Print(buf.String())
	fds.File = append(fds.File, file.AsFileDescriptorProto())

	packedMessagePayload := &anypb.Any{}
	{
		msg := dynamic.NewMessage(file.FindMessage("lol.kek.Packed"))
		msg.SetFieldByName("str", "packed_str_str")
		msg.SetFieldByName("bl", true)
		packedMessagePayload.Value = mustMarshal(msg.Marshal())
		// packedMessagePayload.TypeUrl = "type.googleapis.com/" + msg.GetMessageDescriptor().GetFullyQualifiedName()
	}
	// packedMessagePayload is otherwise only written to; this read keeps
	// linters from flagging it while it awaits use (see TODO below).
	_, _ = proto.Marshal(packedMessagePayload)

	packedComplexMsg := &anypb.Any{}
	var jsonComplexMessage []byte
	{
		nestedMsg := dynamic.NewMessage(file.FindMessage("lol.kek.Nested"))
		nestedMsg.SetFieldByName("id", 1543)
		nestedMsg.SetFieldByName("name", "hohoho")
		// TODO: setting the Any-typed "internal" field this way does not
		// work; the cause has not been investigated yet.
		// nestedMsg.SetFieldByName("internal", packedMessagePayload)

		secretMsg := dynamic.NewMessage(secretRefDesc)
		secretMsg.SetFieldByName("id", "sec1-vvv")
		secretMsg.SetFieldByName("version", "ver1-vvv")
		secretMsg.SetFieldByName("key", "zz_token")

		rootMsg := dynamic.NewMessage(file.FindMessage("lol.kek.WithNested"))
		rootMsg.SetFieldByName("name", "eminem")
		rootMsg.SetFieldByName("nested", nestedMsg)
		rootMsg.SetFieldByName("secret", secretMsg)
		require.NoError(t, rootMsg.ValidateRecursive())
		packedComplexMsg.Value = mustMarshal(rootMsg.Marshal())
		// packedComplexMsg.TypeUrl = "type.googleapis.com/" + rootMsg.GetMessageDescriptor().GetFullyQualifiedName()

		jsonComplexMessage = mustMarshal(rootMsg.MarshalJSON())
	}

	packedSimpleMessage := &anypb.Any{}
	var jsonSimpleMessage []byte
	{
		msg := dynamic.NewMessage(file.FindMessage("lol.kek.Simple"))
		msg.SetFieldByName("id", 57)
		msg.SetFieldByName("name", "billie eilish")
		packedSimpleMessage.Value = mustMarshal(msg.Marshal())
		// packedSimpleMessage.TypeUrl = "type.googleapis.com/" + msg.GetMessageDescriptor().GetFullyQualifiedName()
		jsonSimpleMessage = mustMarshal(msg.MarshalJSON())
	}

	// NB: do not mind marshalling to Any above. Will switch to any handling in api as soon as Schema Registry is ready

	type args struct {
		inputMessageName string
		input            *taskletv2.ExecutionInput
	}
	// want is the expected JSON form of the parsed message; ignored when wantErr.
	tests := []struct {
		name    string
		args    args
		want    []byte
		wantErr bool
	}{
		{
			"Simple",
			args{
				"lol.kek.Simple",
				&taskletv2.ExecutionInput{SerializedData: packedSimpleMessage.Value},
			},
			jsonSimpleMessage,
			false,
		},
		{
			"Complex",
			args{
				"lol.kek.WithNested",
				&taskletv2.ExecutionInput{SerializedData: packedComplexMsg.Value},
			},
			jsonComplexMessage,
			false,
		},
		{
			"InvalidPayload",
			args{
				"lol.kek.WithNested",
				&taskletv2.ExecutionInput{SerializedData: packedSimpleMessage.Value},
			},
			nil,
			true,
		},
	}
	for _, tt := range tests {
		t.Run(
			tt.name, func(t *testing.T) {
				got, err := ParseMessage(fds, tt.args.inputMessageName, tt.args.input)
				if tt.wantErr {
					require.Error(t, err)
					return
				}
				require.NoError(t, err)

				// Dump the parsed fields through the test logger (shown on
				// failure or with -v) rather than printing to stdout.
				got.ProtoReflect().Range(
					func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
						t.Logf("%v: %v [%v]", fd.FullName(), v.Interface(), reflect.TypeOf(v.Interface()))
						return true
					},
				)

				// Compare as decoded JSON values so key order and whitespace
				// differences between the two JSON encoders do not matter.
				gotJSON, err := protojson.Marshal(got)
				require.NoError(t, err)
				var parsedWant interface{}
				require.NoError(t, json.Unmarshal(tt.want, &parsedWant))
				var parsedGot interface{}
				require.NoError(t, json.Unmarshal(gotJSON, &parsedGot))
				require.Equal(t, parsedWant, parsedGot)
			},
		)
	}
}
