package xmodels

import (
	"testing"
	"time"

	"github.com/c2h5oh/datasize"
	"github.com/gofrs/uuid"
	"github.com/stretchr/testify/require"
	"google.golang.org/genproto/googleapis/rpc/errdetails"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/timestamppb"

	"a.yandex-team.ru/tasklet/api/v2"
	"a.yandex-team.ru/tasklet/experimental/internal/consts"
	"a.yandex-team.ru/tasklet/experimental/internal/grpcvalid"
)

// TestValidateBuild runs validateBuild on a fully valid build (in both request
// and response scope) and on builds with a single field broken. It asserts that
// the set of field paths reported as violations matches the expected set
// exactly. Only FieldViolation.Field is compared; descriptions are ignored.
func TestValidateBuild(t *testing.T) {
	validMeta := &taskletv2.BuildMeta{
		Id:        uuid.Must(uuid.NewV4()).String(),
		CreatedAt: timestamppb.New(time.Now()),
		Tasklet:   "task_let",
		Namespace: "name_space",
		TaskletId: uuid.Must(uuid.NewV4()).String(),
		Revision:  43,
	}

	// patchMeta returns a deep copy of validMeta with src merged on top.
	// NOTE: proto.Merge only copies populated fields, so setting a scalar to
	// its zero value in src does NOT clear that field in the result.
	patchMeta := func(src *taskletv2.BuildMeta) *taskletv2.BuildMeta {
		rv := proto.Clone(validMeta).(*taskletv2.BuildMeta)
		proto.Merge(rv, src)
		return rv
	}

	validSpec := &taskletv2.BuildSpec{
		Description: "blah",
		ComputeResources: &taskletv2.ComputeResources{
			VcpuLimit:   1543,
			MemoryLimit: 128 * datasize.MB.Bytes(),
		},
		LaunchSpec: &taskletv2.LaunchSpec{
			Type: "binary",
		},
		Payload: &taskletv2.BuildSpec_Payload{
			SandboxResourceId: 1543,
		},
		Workspace: &taskletv2.BuildSpec_Workspace{
			StorageClass: taskletv2.EStorageClass_E_STORAGE_CLASS_HDD,
			StorageSize:  1024 * 1024 * 100,
		},
		Schema: &taskletv2.IOSchema{
			SimpleProto: &taskletv2.IOSimpleSchemaProto{
				SchemaHash:    "booo!",
				InputMessage:  "google.protobuf.Empty",
				OutputMessage: "google.protobuf.Empty",
			},
		},
	}

	// patchSpec is the BuildSpec counterpart of patchMeta. The same caveat
	// applies: zero-valued scalars in src leave validSpec's values intact.
	patchSpec := func(src *taskletv2.BuildSpec) *taskletv2.BuildSpec {
		rv := proto.Clone(validSpec).(*taskletv2.BuildSpec)
		proto.Merge(rv, src)
		return rv
	}

	// fieldSet collects the distinct field paths of a list of violations.
	fieldSet := func(fvs grpcvalid.FieldViolations) map[string]struct{} {
		rv := make(map[string]struct{}, len(fvs))
		for _, fv := range fvs {
			rv[fv.Field] = struct{}{}
		}
		return rv
	}

	tests := []struct {
		name  string
		build *taskletv2.Build
		scope grpcvalid.RequestScope
		want  grpcvalid.FieldViolations
	}{
		{
			"ok_request",
			&taskletv2.Build{
				Meta: validMeta,
				Spec: validSpec,
			},
			grpcvalid.ScopeRequest,
			nil,
		},
		{
			"ok_response",
			&taskletv2.Build{
				Meta: validMeta,
				Spec: validSpec,
			},
			grpcvalid.ScopeResponse,
			nil,
		},
		{
			"fail_tasklet",
			&taskletv2.Build{
				Meta: patchMeta(&taskletv2.BuildMeta{Tasklet: "z"}),
				Spec: validSpec,
			},
			grpcvalid.ScopeRequest,
			grpcvalid.FieldViolations{
				&errdetails.BadRequest_FieldViolation{Field: ".meta.tasklet"},
			},
		},
		{
			"fail_cpulim",
			&taskletv2.Build{
				Meta: validMeta,
				Spec: patchSpec(
					&taskletv2.BuildSpec{
						ComputeResources: &taskletv2.ComputeResources{VcpuLimit: 1},
					},
				),
			},
			grpcvalid.ScopeRequest,
			grpcvalid.FieldViolations{
				&errdetails.BadRequest_FieldViolation{Field: ".spec.compute_resources.vcpu_limit"},
			},
		},
		{
			"fail_jdk_path",
			&taskletv2.Build{
				Meta: validMeta,
				Spec: patchSpec(
					&taskletv2.BuildSpec{
						LaunchSpec: &taskletv2.LaunchSpec{
							Type: consts.LaunchTypeJava17,
							Jdk:  &taskletv2.LaunchSpec_JDKOptions{MainClass: ""},
						},
					},
				),
			},
			grpcvalid.ScopeRequest,
			grpcvalid.FieldViolations{
				&errdetails.BadRequest_FieldViolation{Field: ".spec.launch_spec.jdk.main_class"},
			},
		},
		{
			"fail_schema",
			&taskletv2.Build{
				Meta: validMeta,
				Spec: patchSpec(
					&taskletv2.BuildSpec{
						// Only InputMessage is overridden: SchemaHash and
						// OutputMessage keep validSpec's values because
						// proto.Merge skips zero-valued scalars.
						Schema: &taskletv2.IOSchema{
							SimpleProto: &taskletv2.IOSimpleSchemaProto{
								InputMessage: "###INVALID!!!!@##!!!@@@",
							},
						},
					},
				),
			},
			grpcvalid.ScopeRequest,
			grpcvalid.FieldViolations{
				&errdetails.BadRequest_FieldViolation{Field: ".spec.schema.simple_proto.input_message"},
			},
		},
	}
	for _, tt := range tests {
		tt := tt // capture the range variable for the parallel subtest (required before Go 1.22)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := validateBuild("", tt.build, tt.scope)
			require.Equal(
				t, fieldSet(tt.want), fieldSet(got),
				"violated fields mismatch. Want: %+v. Got: %+v", tt.want, got,
			)
		})
	}
}
