package handler

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	acmodel "a.yandex-team.ru/tasklet/experimental/internal/access/model"
	"github.com/gofrs/uuid"
	"google.golang.org/genproto/googleapis/rpc/errdetails"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/structpb"
	"google.golang.org/protobuf/types/known/timestamppb"

	"a.yandex-team.ru/library/go/core/log/ctxlog"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/tasklet/api/v2"
	"a.yandex-team.ru/tasklet/experimental/internal/consts"
	"a.yandex-team.ru/tasklet/experimental/internal/handler/xmodels"
	"a.yandex-team.ru/tasklet/experimental/internal/prototools"
	"a.yandex-team.ru/tasklet/experimental/internal/requestctx"
	"a.yandex-team.ru/tasklet/experimental/internal/storage/common"
)

// Execute starts a new execution of a tasklet.
//
// The tasklet is looked up by request.Tasklet/request.Namespace and the caller
// must hold acmodel.BaseTaskletRun on it. The build to run is the one the
// label points to. The label is request.Label, or the tasklet's tracking label
// when the request does not name one (codes.FailedPrecondition if neither is
// set). The request input is parsed against the build's IO schema before the
// execution is stored; a parse failure is reported as codes.FailedPrecondition.
//
// Callers whose login starts with "robot-" must set requirements.account_id;
// otherwise codes.InvalidArgument is returned with a BadRequest detail.
func (t *APIHandler) Execute(ctx context.Context, request *taskletv2.ExecuteRequest) (
	*taskletv2.ExecuteResponse,
	error,
) {
	userAuth, err := t.requireUserAuth(ctx)
	if err != nil {
		return nil, err
	}

	if strings.HasPrefix(userAuth.Login(), "robot-") && request.GetRequirements().GetAccountId() == "" {
		st := status.New(codes.InvalidArgument, "Robot user requires sandbox quota for processing. None provided")
		v := &errdetails.BadRequest_FieldViolation{
			Field:       "requirements.account_id",
			Description: "must not be empty",
		}
		// Attaching details is best-effort. WithDetails returns a nil *Status on
		// failure, and a nil *Status converts to a nil error, so assigning its
		// result unconditionally could turn this rejection into a silent success.
		if detailed, detailsErr := st.WithDetails(
			&errdetails.BadRequest{
				FieldViolations: []*errdetails.BadRequest_FieldViolation{v},
			},
		); detailsErr == nil {
			st = detailed
		}
		return nil, st.Err()
	}

	if st := xmodels.ValidateExecuteRequest(request); st != nil {
		return nil, st.Err()
	}

	tasklet, err := t.findTasklet(ctx, request.Tasklet, request.Namespace, false)
	if err != nil {
		return nil, err
	}
	ns, err := t.findNamespace(ctx, tasklet.Meta.Namespace, false)
	if err != nil {
		return nil, err
	}
	if err = t.permissionsChecker.CheckPermissions(
		ctx,
		userAuth.Login(),
		acmodel.BaseTaskletRun,
		&acmodel.AccessData{Namespace: ns, Tasklet: tasklet},
	); err != nil {
		return nil, err
	}

	// Fall back to the tasklet's tracking label when the request names none.
	labelName := request.Label
	if labelName == "" {
		labelName = tasklet.GetSpec().GetTrackingLabel()
		if labelName == "" {
			return nil, status.Error(codes.FailedPrecondition, "Label not specified")
		}
	}

	label, err := t.findLabel(ctx, tasklet, labelName)
	if err != nil {
		return nil, err
	}
	targetBuildID := label.GetSpec().GetBuildId()
	ctxlog.Infof(ctx, t.Log, "Starting tasklet execution. TaskletId: %q, BuildId: %q", tasklet.Meta.Id, targetBuildID)

	execution := &taskletv2.Execution{
		Meta: &taskletv2.ExecutionMeta{
			Id:        uuid.Must(uuid.NewV4()).String(),
			TaskletId: tasklet.Meta.Id,
			BuildId:   targetBuildID,
			CreatedAt: timestamppb.New(time.Now()),
		},
		Spec: &taskletv2.ExecutionSpec{
			Author:          userAuth.Login(),
			ReferencedLabel: labelName,
			Requirements: &taskletv2.ExecutionRequirements{
				AccountId: request.GetRequirements().GetAccountId(),
			},
			Input: request.GetInput(),
		},
		Status: &taskletv2.ExecutionStatus{
			Status: taskletv2.EExecutionStatus_E_EXECUTION_STATUS_EXECUTING,
		},
	}

	// A true "yt_spawn" request feature is recorded as a status annotation.
	// A missing or non-bool feature fails the type assertion and is ignored.
	if doSpawn, ok := requestctx.GetFeature(ctx, "yt_spawn").(bool); ok && doSpawn {
		execution.Status.Annotations = &structpb.Struct{
			Fields: map[string]*structpb.Value{
				"yt_spawn": structpb.NewBoolValue(true),
			},
		}
	}

	if st := xmodels.ValidateExecution(execution); st != nil {
		return nil, st.Err()
	}

	// NB: parse the input against the build's IO schema so malformed input is
	// rejected before the execution is stored.
	// FIXME: cache schemas & proto resolvers
	{
		build, err := t.db.GetBuild(ctx, targetBuildID)
		if err != nil {
			return nil, err
		}
		// ioSchema is nil when the build has no simple-proto schema, so only
		// nil-safe getters are used on it below.
		ioSchema := build.GetSpec().GetSchema().GetSimpleProto()
		schemaHash := ioSchema.GetSchemaHash()
		schema, err := t.db.GetSchema(ctx, schemaHash)
		if err != nil {
			if errors.Is(err, common.ErrObjectNotFound) {
				ctxlog.Errorf(
					ctx,
					t.Log,
					"IO schema not registered. BuildID: %q, SchemaID: %q",
					targetBuildID,
					schemaHash,
				)
				return nil, status.Errorf(
					codes.Internal,
					"IO schema not registered. SchemaID: %q",
					schemaHash,
				)
			}
			return nil, err
		}

		// TODO: check required secrets against the parsed input message.
		if _, err := prototools.ParseMessage(schema.Fds, ioSchema.GetInputMessage(), execution.Spec.Input); err != nil {
			return nil, status.Errorf(codes.FailedPrecondition, "Failed to unmarshall input: %v", err)
		}
	}

	createdEx, err := t.db.AddExecution(ctx, requestctx.GetRequestID(ctx), execution)
	if err != nil {
		return nil, err
	}
	return &taskletv2.ExecuteResponse{Execution: createdEx}, nil
}

var (
	// errAlreadyAborted is returned by AbortExecution's status-update callback
	// when the stored status already carries an abort request; AbortExecution
	// maps it to codes.FailedPrecondition.
	errAlreadyAborted = xerrors.NewSentinel("already aborted")
)

// AbortExecution records an abort request on an execution that has not
// finished yet.
//
// The caller must hold acmodel.BaseTaskletWrite on the execution's tasklet.
// The execution's author is additionally granted the TaskletWrite role for
// this check. The abort is stored as an AbortRequest in the execution status.
//
// Errors:
//   - codes.NotFound: the execution does not exist.
//   - codes.FailedPrecondition: the execution has already finished (or was
//     archived meanwhile) or has already been aborted.
//   - codes.PermissionDenied: the ACL check fails.
func (t *APIHandler) AbortExecution(
	ctx context.Context,
	request *taskletv2.AbortExecutionRequest,
) (*taskletv2.AbortExecutionResponse, error) {
	userAuth, err := t.requireUserAuth(ctx)
	if err != nil {
		return nil, err
	}

	if st := xmodels.ValidateAbortExecutionRequest(request); st != nil {
		return nil, st.Err()
	}
	execution, err := t.db.GetExecution(ctx, request.GetId())
	if err != nil {
		if errors.Is(err, common.ErrObjectNotFound) {
			return nil, status.Errorf(codes.NotFound, "Execution does not exist. ID: %q", request.GetId())
		}
		return nil, err
	}

	// Both rejections are about the execution's state, not the request itself,
	// hence FailedPrecondition. This matches the errors returned after the
	// update below.
	if execution.GetStatus().GetStatus() == taskletv2.EExecutionStatus_E_EXECUTION_STATUS_FINISHED {
		return nil, status.Error(codes.FailedPrecondition, "Execution already processed")
	}
	if execution.GetStatus().GetAbortRequest().GetAborted() {
		return nil, status.Error(codes.FailedPrecondition, "Execution already aborted")
	}

	tasklet, err := t.db.GetTaskletByID(ctx, consts.TaskletID(execution.Meta.TaskletId))
	if err != nil {
		return nil, err
	}
	ns, err := t.findNamespace(ctx, tasklet.Meta.Namespace, false)
	if err != nil {
		return nil, err
	}

	// NB: ACL check. The execution's author is added as a user subject with the
	// TaskletWrite role on top of the namespace/tasklet permissions (presumably
	// so authors can abort their own runs).
	ad := acmodel.AccessData{Namespace: ns, Tasklet: tasklet}
	inherited := ad.GetPermissions()
	// Capping capacity forces append to copy instead of possibly writing into a
	// backing array shared with ns/tasklet.
	permissions := append(
		inherited[:len(inherited):len(inherited)],
		&taskletv2.PermissionsSubject{
			Source: taskletv2.PermissionsSubject_E_SOURCE_USER,
			Name:   execution.Spec.Author,
			Roles:  []string{string(acmodel.TaskletWrite)},
		},
	)
	permitted, err := t.permissionsChecker.CheckCommonPermission(
		ctx,
		userAuth.Login(),
		acmodel.BaseTaskletWrite,
		permissions,
	)
	if err != nil {
		return nil, err
	}
	if !permitted {
		return nil, status.Errorf(
			codes.PermissionDenied,
			"User does not have permission %v on execution",
			acmodel.BaseTaskletWrite,
		)
	}

	// NB: do abort. The update callback re-checks the state it is given,
	// because the execution may have finished or been aborted since the
	// GetExecution call above.
	errAlreadyFinished := errors.New("already finished")
	abortRequest := &taskletv2.AbortRequest{
		Aborted:   true,
		Author:    userAuth.Login(),
		Reason:    request.Reason,
		AbortedAt: timestamppb.Now(),
	}
	update := func(st *taskletv2.ExecutionStatus) error {
		switch {
		case st.GetStatus() == taskletv2.EExecutionStatus_E_EXECUTION_STATUS_FINISHED:
			return errAlreadyFinished
		case st.GetAbortRequest().GetAborted():
			return errAlreadyAborted
		}
		st.AbortRequest = abortRequest
		return nil
	}

	if _, err := t.db.UpdateExecutionStatus(ctx, consts.ExecutionID(request.Id), update); err != nil {
		switch {
		case xerrors.Is(err, common.ErrObjectNotFound), errors.Is(err, errAlreadyFinished):
			// NB: NotFound here means the execution got archived; its existence
			// was checked above.
			return nil, status.Error(codes.FailedPrecondition, "Execution already processed")
		case xerrors.Is(err, errAlreadyAborted):
			return nil, status.Error(codes.FailedPrecondition, "Execution already aborted")
		default:
			return nil, err
		}
	}

	return &taskletv2.AbortExecutionResponse{}, nil
}

// postProcessExecution restores the deprecated status.error and status.result
// fields from status.processing_result, mutating e in place.
//
// Executions without a status, without a processing result, or whose
// processing result has no kind set are left untouched. It panics on a
// processing-result kind it does not know, since that means a new variant was
// added to the API without being mapped here.
func postProcessExecution(e *taskletv2.Execution) {
	result := e.GetStatus().GetProcessingResult()
	if result == nil {
		return
	}
	switch v := result.Kind.(type) {
	case nil:
		// Empty oneof: there is nothing to restore.
		return
	case *taskletv2.ProcessingResult_ServerError:
		e.Status.Error = &taskletv2.ExecutionError{
			Description: fmt.Sprintf("Server error: %v", v.ServerError.GetDescription()),
		}
	case *taskletv2.ProcessingResult_Output:
		e.Status.Result = proto.Clone(v.Output).(*taskletv2.ExecutionOutput)
	case *taskletv2.ProcessingResult_UserError:
		e.Status.Error = &taskletv2.ExecutionError{
			Description: fmt.Sprintf("User error: %v", v.UserError.GetDescription()),
		}
	default:
		panic(fmt.Sprintf("Bad result kind: %T", result.Kind))
	}
}

// GetExecution returns a single execution by ID, with the deprecated
// status.error/status.result fields filled in by postProcessExecution.
// A missing execution is reported as codes.NotFound.
//
// NOTE(review): unlike Execute, AbortExecution and ListExecutionsByTasklet,
// this handler neither requires user auth nor checks ACLs (see FIXME below).
func (t *APIHandler) GetExecution(
	ctx context.Context,
	request *taskletv2.GetExecutionRequest,
) (*taskletv2.GetExecutionResponse, error) {
	if invalid := xmodels.ValidateGetExecutionRequest(request); invalid != nil {
		return nil, invalid.Err()
	}

	found, err := t.db.GetExecution(ctx, request.GetId())
	switch {
	case err == nil:
	case errors.Is(err, common.ErrObjectNotFound):
		return nil, status.Errorf(codes.NotFound, "Execution does not exist. ID: %q", request.GetId())
	default:
		return nil, err
	}

	// FIXME: check ACL
	postProcessExecution(found)
	return &taskletv2.GetExecutionResponse{Execution: found}, nil
}

// ListExecutionsByTasklet lists executions of the tasklet identified by
// request.Tasklet/request.Namespace. The caller must be authenticated and hold
// acmodel.BaseTaskletRead on the tasklet. request.Token is forwarded to the
// storage layer, and the token it returns is passed back in the response.
func (t *APIHandler) ListExecutionsByTasklet(
	ctx context.Context,
	request *taskletv2.ListExecutionsByTaskletRequest,
) (*taskletv2.ListExecutionsByTaskletResponse, error) {
	userAuth, err := t.requireUserAuth(ctx)
	if err != nil {
		return nil, err
	}
	if invalid := xmodels.ValidateListExecutionsByTaskletRequest(request); invalid != nil {
		return nil, invalid.Err()
	}

	tasklet, err := t.findTasklet(ctx, request.Tasklet, request.Namespace, false)
	if err != nil {
		return nil, err
	}
	ns, err := t.findNamespace(ctx, tasklet.Meta.Namespace, false)
	if err != nil {
		return nil, err
	}

	access := &acmodel.AccessData{Namespace: ns, Tasklet: tasklet}
	if err := t.permissionsChecker.CheckPermissions(ctx, userAuth.Login(), acmodel.BaseTaskletRead, access); err != nil {
		return nil, err
	}

	page, err := t.db.ListExecutionsByTasklet(ctx, tasklet.Meta.Id, request.GetToken())
	if err != nil {
		return nil, err
	}
	for i := range page.Executions {
		postProcessExecution(page.Executions[i])
	}

	return &taskletv2.ListExecutionsByTaskletResponse{
		Executions: page.Executions,
		Token:      page.Token,
	}, nil
}

// ListExecutionsByBuild lists executions created from build request.BuildId.
// The build is looked up with findBuild first, and any error from that lookup
// is returned as-is. request.Token is forwarded to the storage layer, and the
// token it returns is passed back in the response.
//
// NOTE(review): no user auth or ACL check is performed (see FIXME below).
func (t *APIHandler) ListExecutionsByBuild(
	ctx context.Context,
	request *taskletv2.ListExecutionsByBuildRequest,
) (*taskletv2.ListExecutionsByBuildResponse, error) {
	if invalid := xmodels.ValidateListExecutionsByBuildRequest(request); invalid != nil {
		return nil, invalid.Err()
	}

	buildID := request.BuildId
	if _, err := t.findBuild(ctx, buildID, false); err != nil {
		return nil, err
	}

	// FIXME: check acl

	page, err := t.db.ListExecutionsByBuild(ctx, buildID, request.GetToken())
	if err != nil {
		return nil, err
	}
	for i := range page.Executions {
		postProcessExecution(page.Executions[i])
	}

	return &taskletv2.ListExecutionsByBuildResponse{
		Executions: page.Executions,
		Token:      page.Token,
	}, nil
}
