package processor

import (
	"context"
	"fmt"

	"google.golang.org/protobuf/types/known/timestamppb"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/ctxlog"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/sandbox/common/go/models"

	taskletv2 "a.yandex-team.ru/tasklet/api/v2"
	"a.yandex-team.ru/tasklet/experimental/internal/consts"
	"a.yandex-team.ru/tasklet/experimental/internal/handler/xmodels"
	"a.yandex-team.ru/tasklet/experimental/internal/state"
	"a.yandex-team.ru/tasklet/experimental/internal/storage/common"
	"a.yandex-team.ru/tasklet/experimental/internal/yandex/sandbox"
)

var (
	// errUnexpectedError wraps failures that should not happen in normal
	// operation, e.g. a missing YT operation id / Sandbox task id or a
	// non-permanent Sandbox error.
	errUnexpectedError = xerrors.NewSentinel("Unexpected error")

	// errStatusUpdateFailed wraps errors returned by storage while persisting
	// execution status updates (see applyStatusUpdates).
	errStatusUpdateFailed = xerrors.NewSentinel("Status update failed")
)

// executionHandler holds the per-execution state shared by the step* methods
// below. Every step returns early once done is set.
type executionHandler struct {
	// e is the execution being processed; it is replaced with the copy
	// returned by storage after each successful status update.
	e *taskletv2.Execution
	// b is the build used by the execution, loaded and validated lazily by
	// mustGetBuild; nil until then.
	b *taskletv2.Build
	// p is the owning processor (storage, logger, YT and Sandbox clients).
	p *Processor
	// done stops further processing; set by setErr or by a terminal step.
	done bool
	// err is the first error recorded via setErr.
	err error
	// NOTE(review): externalSession is not referenced in this part of the
	// file — confirm whether it is still used.
	externalSession sandbox.SandboxExternalSession
}

// newExecutionHandler creates a handler for execution e owned by processor p.
// The build is not loaded yet (b stays nil) and the handler starts out not
// done; both come from the struct's zero values.
func newExecutionHandler(e *taskletv2.Execution, p *Processor) *executionHandler {
	h := &executionHandler{p: p}
	h.e = e
	return h
}

// id returns the execution identifier taken from the execution metadata.
func (eh *executionHandler) id() consts.ExecutionID {
	meta := eh.e.Meta
	return consts.ExecutionID(meta.Id)
}

// AbortRequested reports whether the execution status carries an abortion
// request flagged as aborted. Missing status parts read as false.
func (eh *executionHandler) AbortRequested() bool {
	request := eh.e.Status.GetAbortRequest()
	return request.GetAborted()
}

// processorStatus returns the processor part of the execution status. It may
// be nil, so callers should prefer its nil-safe Get* accessors.
func (eh *executionHandler) processorStatus() *taskletv2.ExecutionProcessorStatus {
	status := eh.e.Status
	return status.GetProcessor()
}

// setErr records err as the handler's error and marks the handler done, so
// the following steps become no-ops.
//
// Only a call made while the handler is not yet done has any effect. Any later
// error is logged at Info level and dropped. This includes errors arriving
// after a step set done without an error.
func (eh *executionHandler) setErr(err error) {
	if !eh.done {
		if previous := eh.err; previous != nil {
			eh.p.logger.Errorf("Error already set. Old: %+v, New: %+v", previous, err)
		}
		eh.err, eh.done = err, true
		return
	}
	eh.p.logger.Infof("Drop non-first error: %v", err)
}

// mustGetBuild loads the execution's build from storage and validates it,
// caching the result in eh.b. Repeated calls are no-ops once the build is
// loaded.
//
// Despite its name it does not panic: failures are recorded via setErr, so
// callers must check eh.done afterwards.
func (eh *executionHandler) mustGetBuild(ctx context.Context) {
	if eh.b != nil {
		// Already loaded and validated by an earlier step.
		return
	}

	buildID := eh.e.Meta.BuildId
	build, err := eh.p.storage.GetBuild(ctx, buildID)
	if err != nil {
		eh.setErr(xerrors.Errorf("GetBuild failed: %w", err))
		return
	}

	validation := xmodels.ValidateBuild(build)
	if validation.Err() == nil {
		eh.b = build
		return
	}
	eh.setErr(xerrors.Errorf(
		"Execution %q uses invalid build %q. Details: %+v",
		eh.e.Meta.Id,
		buildID,
		validation.Details(),
	))
}

// applyStatusUpdates persists the given status mutations through storage.
// On success eh.e is replaced with the execution returned by storage, so
// later steps see the new status. On failure the error is logged and recorded
// via setErr, wrapped in errStatusUpdateFailed.
func (eh *executionHandler) applyStatusUpdates(ctx context.Context, updates ...common.ExecutionStatusUpdateFunc) {
	updated, err := eh.p.storage.UpdateExecutionStatus(ctx, eh.id(), updates...)
	if err == nil {
		eh.e = updated
		return
	}
	ctxlog.Error(ctx, eh.p.logger, "Status update failed", log.Error(err))
	eh.setErr(errStatusUpdateFailed.Wrap(err))
}

// stepSpawnYT spawns the YT operation for the execution unless an operation
// id is already recorded in the processor status.
//
// The step works as follows:
//   - Load and validate the build.
//   - Acquire an external session.
//   - Register the execution in the YT driver.
//   - Spawn the operation.
//
// The processor status is then replaced with one holding the new operation id,
// or the spawn error as its message. An operation id that is already stored is
// preserved and is never replaced by a different one.
//
// After a successful spawn the handler is marked done. The operation was just
// started, so there is nothing to check yet.
func (eh *executionHandler) stepSpawnYT(ctx context.Context) {
	if eh.done {
		return
	}

	if eh.processorStatus().GetYtOperationId() != "" {
		return
	}

	ctxlog.Info(ctx, eh.p.logger, "YT spawn started")
	defer func() {
		if eh.err != nil {
			ctxlog.Error(ctx, eh.p.logger, "YT spawn failed", log.Error(eh.err))
		} else {
			ctxlog.Info(ctx, eh.p.logger, "YT spawn success")
		}
	}()

	eh.mustGetBuild(ctx)
	if eh.done {
		return
	}

	ctxlog.Info(ctx, eh.p.logger, "Acquiring external session")
	secureSession, err := eh.p.TryAcquireExternalSession(ctx, eh.id(), eh.e.Spec.Author)
	if err != nil {
		eh.setErr(err)
		return
	}
	ctxlog.Info(ctx, eh.p.logger, "Registering execution in driver")
	if err := eh.p.ytc.RegisterExecution(ctx, eh.b, eh.e, secureSession); err != nil {
		eh.setErr(err)
		return
	}

	ctxlog.Info(ctx, eh.p.logger, "Spawning yt operation")
	opID, spawnErr := eh.p.ytc.Spawn(ctx, eh.id())
	message := "OK"
	if spawnErr != nil {
		eh.setErr(spawnErr)
		message = spawnErr.Error()
	} else {
		ctxlog.Infof(ctx, eh.p.logger, "Spawned yt operation. OpID: %q", opID)
	}
	newProcessorStatus := &taskletv2.ExecutionProcessorStatus{
		YtOperationId: opID,
		UpdatedAt:     timestamppb.Now(),
		Message:       message,
	}

	update := func(status *taskletv2.ExecutionStatus) error {
		oldProcessorStatus := status.Processor
		if oldProcessorStatus != nil {
			oldOperationID := oldProcessorStatus.YtOperationId
			if oldOperationID != "" && newProcessorStatus.YtOperationId == "" {
				// Preserve the stored operation ID when this attempt produced none.
				newProcessorStatus.YtOperationId = oldOperationID
			}
			if oldOperationID != "" && newProcessorStatus.YtOperationId != oldOperationID {
				return xerrors.Errorf(
					"Attempt to change YT operation id. ExecutionID: %q, OldOperationID: %v, NewOperationID: %v",
					eh.id(),
					oldOperationID,
					newProcessorStatus.YtOperationId,
				)
			}
		}
		status.Processor = newProcessorStatus
		return nil
	}

	eh.applyStatusUpdates(ctx, update)

	// Stop execution handling. no need to check just spawned operation
	if spawnErr == nil {
		eh.done = true
	}
}

// stepAwaitYT checks the YT operation recorded in the processor status and
// mirrors its state into the execution status.
//
// It does nothing when the handler is already done or the execution is
// FINISHED.
//
// When the operation has finished, the external session is closed first. After
// that, the execution is set to FINISHED. If no processing result was reported,
// a server error is recorded: GENERIC with the operation's error summary, or
// NO_RESPONSE otherwise.
//
// Errors are recorded via setErr and written into the processor status
// message.
func (eh *executionHandler) stepAwaitYT(ctx context.Context) {
	if eh.done {
		return
	}

	// Nothing to do, possibly pending archivation or session close
	if eh.e.Status.GetStatus() == taskletv2.EExecutionStatus_E_EXECUTION_STATUS_FINISHED {
		return
	}

	operationID := eh.processorStatus().GetYtOperationId()
	if operationID == "" {
		err := errUnexpectedError.Wrap(xerrors.New("empty operation id"))
		eh.setErr(err)
		update := func(st *taskletv2.ExecutionStatus) error {
			// The getter above also yields "" for a nil processor status,
			// so it may not exist yet; create it instead of dereferencing nil.
			if st.Processor == nil {
				st.Processor = &taskletv2.ExecutionProcessorStatus{}
			}
			st.Processor.Message = err.Error()
			st.Processor.UpdatedAt = timestamppb.Now()
			return nil
		}
		eh.applyStatusUpdates(ctx, update)
		return
	}

	ctxlog.Info(ctx, eh.p.logger, "YT check started")
	defer func() {
		if eh.err != nil {
			ctxlog.Error(ctx, eh.p.logger, "YT check failed", log.Error(eh.err))
		} else {
			ctxlog.Info(ctx, eh.p.logger, "YT check success")
		}
	}()

	operationResult, checkError := eh.p.ytc.CheckOperationStatus(ctx, operationID)
	if checkError != nil {
		eh.setErr(checkError)

		update := func(st *taskletv2.ExecutionStatus) error {
			st.Processor.Message = fmt.Sprintf("Error: %+v", checkError)
			st.Processor.UpdatedAt = timestamppb.Now()
			return nil
		}
		eh.applyStatusUpdates(ctx, update)
		return
	}

	if operationResult.Finished {
		ctxlog.Info(ctx, eh.p.logger, "Operation finished. Closing external session")
		// Close external session early
		if errClose := eh.p.TryCloseExternalSession(ctx, eh.id(), 0); errClose != nil {
			err := xerrors.Errorf("External session close failed: %w", errClose)
			eh.setErr(err)

			update := func(st *taskletv2.ExecutionStatus) error {
				st.Processor.Message = fmt.Sprintf("Error: %+v", err)
				st.Processor.UpdatedAt = timestamppb.Now()
				return nil
			}
			eh.applyStatusUpdates(ctx, update)
			return
		}
	}

	updateOps := []common.ExecutionStatusUpdateFunc{
		func(st *taskletv2.ExecutionStatus) error {
			st.Processor.Message = fmt.Sprintf("Operation state: %s", operationResult.State)
			st.Processor.UpdatedAt = timestamppb.Now()
			return nil
		},
	}

	if operationResult.Finished {
		updateOps = append(
			updateOps,
			func(st *taskletv2.ExecutionStatus) error {
				st.Status = taskletv2.EExecutionStatus_E_EXECUTION_STATUS_FINISHED
				return nil
			},

			func(st *taskletv2.ExecutionStatus) error {
				if st.ProcessingResult != nil {
					// The executor already reported a result.
					return nil
				}

				errorMsg := &taskletv2.ServerError{}
				if operationResult.IsError {
					errorMsg.Code = taskletv2.ErrorCodes_ERROR_CODE_GENERIC
					errorMsg.Description = operationResult.ErrorSummary
				} else {
					errorMsg.Code = taskletv2.ErrorCodes_ERROR_CODE_NO_RESPONSE
					errorMsg.Description = "executor failed to report result"
				}

				st.ProcessingResult = &taskletv2.ProcessingResult{
					Kind: &taskletv2.ProcessingResult_ServerError{
						ServerError: errorMsg,
					},
				}
				if st.Stats == nil {
					st.Stats = &taskletv2.ExecutionStats{FinishedAt: timestamppb.Now()}
				}
				return nil
			},
		)
	}

	eh.applyStatusUpdates(ctx, updateOps...)
}

// stepCreateSandbox creates the Sandbox task for the execution. It runs only
// when no task id is recorded yet and no abortion was requested.
//
// On success the processor status is replaced with one carrying the new task
// id. A task id that is already stored is preserved and is never replaced by a
// different one.
//
// On failure the error is recorded via setErr and written into the processor
// status message.
func (eh *executionHandler) stepCreateSandbox(ctx context.Context) {
	if eh.done {
		return
	}

	if eh.processorStatus().GetSandboxTaskId() != 0 {
		return
	}

	if eh.AbortRequested() {
		// NB: Sandbox task may be created by now in DRAFT state. No need to persist it?
		return
	}

	ctxlog.Info(ctx, eh.p.logger, "Sandbox create started")
	defer func() {
		if eh.err != nil {
			ctxlog.Error(ctx, eh.p.logger, "Sandbox create failed", log.Error(eh.err))
		} else {
			ctxlog.Info(ctx, eh.p.logger, "Sandbox create success")
		}
	}()

	eh.mustGetBuild(ctx)
	if eh.done {
		return
	}

	sbTaskID, err := eh.p.TryCreateSandboxTask(ctx, eh.e, eh.b)
	if err != nil {
		err = errUnexpectedError.Wrap(err)
		eh.setErr(err)
		update := func(st *taskletv2.ExecutionStatus) error {
			// No task id is recorded yet, so the processor status may be
			// absent as well; create it instead of dereferencing nil.
			if st.Processor == nil {
				st.Processor = &taskletv2.ExecutionProcessorStatus{}
			}
			st.Processor.Message = err.Error()
			st.Processor.UpdatedAt = timestamppb.Now()
			return nil
		}
		eh.applyStatusUpdates(ctx, update)
		return
	}

	newProcessorStatus := &taskletv2.ExecutionProcessorStatus{
		Message:       "OK: Created sandbox task",
		UpdatedAt:     timestamppb.Now(),
		SandboxTaskId: sbTaskID.ToInt(),
	}

	update := func(status *taskletv2.ExecutionStatus) error {
		oldProcessorStatus := status.Processor
		if oldProcessorStatus != nil {
			oldTaskID := oldProcessorStatus.SandboxTaskId
			if oldTaskID != 0 && newProcessorStatus.SandboxTaskId == 0 {
				newProcessorStatus.SandboxTaskId = oldTaskID
			}
			if oldTaskID != 0 && newProcessorStatus.SandboxTaskId != oldTaskID {
				return xerrors.Errorf(
					"Attempt to change sandbox task id. ExecutionID: %q, OldTaskID: %v, NewTaskID: %v",
					eh.id(),
					oldTaskID,
					newProcessorStatus.SandboxTaskId,
				)
			}
		}
		status.Processor = newProcessorStatus
		return nil
	}
	eh.applyStatusUpdates(ctx, update)
}

// stepStartSandbox starts the previously created Sandbox task.
//
// It does nothing when:
//   - the handler is already done,
//   - the execution is FINISHED,
//   - abortion was requested, or
//   - the task is already marked as started.
//
// Before starting the task, an external session bound to it is created.
//
// A permanent start failure (sandbox.ErrSandboxTaskStart) finalizes the
// execution via registerSandboxExecutionResult. Any other failure is recorded
// via setErr and written into the processor status message. On success the
// processor status is marked SandboxStarted.
func (eh *executionHandler) stepStartSandbox(ctx context.Context) {
	if eh.done {
		return
	}
	// Nothing to do, possibly pending archivation or session close
	if eh.e.Status.GetStatus() == taskletv2.EExecutionStatus_E_EXECUTION_STATUS_FINISHED {
		return
	}

	if eh.AbortRequested() {
		return
	}

	// Use nil-safe getters: the processor status may be absent entirely.
	sbTaskID := sandbox.SandboxTaskID(eh.processorStatus().GetSandboxTaskId())
	if sbTaskID == 0 {
		err := errUnexpectedError.Wrap(xerrors.New("empty task id"))
		eh.setErr(err)
		update := func(st *taskletv2.ExecutionStatus) error {
			if st.Processor == nil {
				st.Processor = &taskletv2.ExecutionProcessorStatus{}
			}
			st.Processor.Message = err.Error()
			st.Processor.UpdatedAt = timestamppb.Now()
			return nil
		}
		eh.applyStatusUpdates(ctx, update)
		return
	}

	if eh.processorStatus().GetSandboxStarted() {
		return
	}

	ctxlog.Info(ctx, eh.p.logger, "Sandbox start started")
	defer func() {
		if eh.err != nil {
			ctxlog.Error(ctx, eh.p.logger, "Sandbox start failed", log.Error(eh.err))
		} else {
			ctxlog.Info(ctx, eh.p.logger, "Sandbox start success")
		}
	}()

	err := eh.p.TryCreateExternalSession(ctx, eh.id(), eh.e.Spec.Author, sbTaskID)
	if err != nil {
		eh.setErr(err)
		return
	}

	errStart := eh.p.TryStartSandboxTask(ctx, sbTaskID)
	if errStart != nil {
		if xerrors.Is(errStart, sandbox.ErrSandboxTaskStart) {
			// NB: handle permanent start error
			registerErr := eh.registerSandboxExecutionResult(ctx, errStart.Error())
			if registerErr != nil {
				eh.setErr(registerErr)
			}
		} else {
			outerErr := errUnexpectedError.Wrap(errStart)
			eh.setErr(outerErr)
			update := func(st *taskletv2.ExecutionStatus) error {
				st.Processor.Message = outerErr.Error()
				st.Processor.UpdatedAt = timestamppb.Now()
				return nil
			}
			eh.applyStatusUpdates(ctx, update)
		}

		return
	}

	update := func(status *taskletv2.ExecutionStatus) error {
		processor := status.Processor
		// GetSandboxTaskId is nil-safe. A missing processor status is therefore
		// reported as a missing task id instead of being dereferenced below.
		if processor.GetSandboxTaskId() == 0 {
			return xerrors.Errorf(
				"Missing sandbox task id. ExecutionID: %q, ExpectedTaskID: %v",
				eh.id(),
				sbTaskID.ToInt(),
			)
		}
		processor.Message = "OK: Started sandbox task"
		processor.UpdatedAt = timestamppb.Now()
		processor.SandboxStarted = true
		return nil
	}
	eh.applyStatusUpdates(ctx, update)
}

// registerSandboxExecutionResult finalizes a Sandbox-backed execution.
//
// It closes the external session for the task. It then applies one status
// update that:
//   - sets the processor message to status,
//   - moves the execution to FINISHED, and
//   - records a server error if the executor reported no result: ABORTED with
//     abortion details when abortion was requested, NO_RESPONSE otherwise.
//
// A zero task id is accepted only when abortion was requested, because the
// task may never have been created.
//
// The returned error covers only that precondition and the session close. A
// failed status update is recorded via setErr instead.
func (eh *executionHandler) registerSandboxExecutionResult(ctx context.Context, status string) error {
	// Nil-safe getter: the processor status may be absent after an early abort.
	sbTaskID := sandbox.SandboxTaskID(eh.processorStatus().GetSandboxTaskId())
	if sbTaskID == 0 && !eh.AbortRequested() {
		return xerrors.New("unexpected: empty sandbox task id")
	}

	// Session may be already closed while handling abortion request
	if errClose := eh.p.TryCloseExternalSession(ctx, eh.id(), sbTaskID); errClose != nil {
		return errClose
	}

	updateOps := []common.ExecutionStatusUpdateFunc{
		func(st *taskletv2.ExecutionStatus) error {
			// Abortion may happen before any processor status was written.
			if st.Processor == nil {
				st.Processor = &taskletv2.ExecutionProcessorStatus{}
			}
			st.Processor.Message = status
			st.Processor.UpdatedAt = timestamppb.Now()
			return nil
		},
		func(st *taskletv2.ExecutionStatus) error {
			if st.Status != taskletv2.EExecutionStatus_E_EXECUTION_STATUS_EXECUTING {
				// FIXME: return error?
				ctxlog.Errorf(ctx, eh.p.logger, "Unexpected execution status. Status: %q", st.Status.String())
			}
			st.Status = taskletv2.EExecutionStatus_E_EXECUTION_STATUS_FINISHED
			return nil
		},
		func(st *taskletv2.ExecutionStatus) error {
			if st.ProcessingResult != nil {
				// Got response from executor
				return nil
			}

			errorMsg := &taskletv2.ServerError{}
			if eh.AbortRequested() {
				errorMsg.Code = taskletv2.ErrorCodes_ERROR_CODE_ABORTED
				errorMsg.Description = "Execution aborted on user request"
				errorMsg.AbortInfo = &taskletv2.AbortionInfo{
					Author:      eh.e.Status.GetAbortRequest().GetAuthor(),
					Reason:      eh.e.Status.GetAbortRequest().GetReason(),
					RequestedAt: timestamppb.New(eh.e.Status.GetAbortRequest().GetAbortedAt().AsTime()),
					CompletedAt: timestamppb.Now(),
				}
			} else {
				errorMsg.Code = taskletv2.ErrorCodes_ERROR_CODE_NO_RESPONSE
				errorMsg.Description = "No response from tasklet executor due to executor or task crash"
			}

			st.ProcessingResult = &taskletv2.ProcessingResult{
				Kind: &taskletv2.ProcessingResult_ServerError{
					ServerError: errorMsg,
				},
			}
			if st.Stats == nil {
				st.Stats = &taskletv2.ExecutionStats{FinishedAt: timestamppb.Now()}
			}
			return nil
		},
	}
	eh.applyStatusUpdates(ctx, updateOps...)
	return nil
}

// stepHandleAbortionSandbox reacts to a user abortion request for a
// Sandbox-backed execution.
//
// If no Sandbox task exists yet, the execution is finalized immediately.
// Otherwise the task is stopped and the external session is closed. The
// execution itself is not finalized here; per the trailing NB, that is left to
// the next (awaiting) stage.
//
// Failures are recorded via setErr and written into the processor status
// message.
func (eh *executionHandler) stepHandleAbortionSandbox(ctx context.Context) {
	if eh.done {
		return
	}
	if eh.e.Status.GetStatus() == taskletv2.EExecutionStatus_E_EXECUTION_STATUS_FINISHED {
		// Nothing to do, possibly pending archivation or session close or earlier errors
		return
	}

	if !eh.AbortRequested() {
		// Nothing to do
		return
	}

	ctxlog.Info(ctx, eh.p.logger, "Sandbox abortion started")
	defer func() {
		if eh.err != nil {
			ctxlog.Error(ctx, eh.p.logger, "Sandbox abortion failed", log.Error(eh.err))
		} else {
			ctxlog.Info(ctx, eh.p.logger, "Sandbox abortion success")
		}
	}()

	// Nil-safe getter: stepCreateSandbox skips task creation once abortion is
	// requested, so the processor status may not exist at all.
	sbTaskID := sandbox.SandboxTaskID(eh.processorStatus().GetSandboxTaskId())
	if sbTaskID == 0 {
		finalizeErr := eh.registerSandboxExecutionResult(ctx, "Aborted prior to task spawn")
		if finalizeErr != nil {
			eh.setErr(finalizeErr)
		}
		return
	}

	// NB: StopTask returns nil if task is in terminal state
	if err := eh.p.TryStopSandboxTask(ctx, eh.id(), sbTaskID); err != nil {
		err = xerrors.Errorf("Failed to stop sandbox task: %w", err)
		eh.setErr(err)
		update := func(st *taskletv2.ExecutionStatus) error {
			st.Processor.Message = fmt.Sprintf("Error: %+v", err)
			st.Processor.UpdatedAt = timestamppb.Now()
			return nil
		}
		eh.applyStatusUpdates(ctx, update)
		return
	}

	if err := eh.p.TryCloseExternalSession(ctx, eh.id(), sbTaskID); err != nil {
		err = xerrors.Errorf("failed to close external session: %w", err)
		eh.setErr(err)
		update := func(st *taskletv2.ExecutionStatus) error {
			st.Processor.Message = fmt.Sprintf("Error: %+v", err)
			st.Processor.UpdatedAt = timestamppb.Now()
			return nil
		}
		eh.applyStatusUpdates(ctx, update)
		return
	}

	// NB: chain task awaiting to next stage
}

// stepAwaitSandbox checks the Sandbox task and mirrors its status into the
// execution.
//
// When the task has finished, the execution is finalized via
// registerSandboxExecutionResult. A task still in DRAFT is also treated as
// finished; this is logged as an error unless abortion was requested.
// Otherwise only the processor message is refreshed with the current Sandbox
// status.
//
// Errors are recorded via setErr and, where possible, written into the
// processor status message.
func (eh *executionHandler) stepAwaitSandbox(ctx context.Context) {
	if eh.done {
		return
	}
	if eh.e.Status.GetStatus() == taskletv2.EExecutionStatus_E_EXECUTION_STATUS_FINISHED {
		// Nothing to do, possibly pending archivation or session close or earlier errors
		return
	}

	// Nil-safe getter: the processor status may be absent entirely.
	sbTaskID := sandbox.SandboxTaskID(eh.processorStatus().GetSandboxTaskId())
	if sbTaskID == 0 {
		err := errUnexpectedError.Wrap(xerrors.New("empty task id"))
		eh.setErr(err)
		update := func(st *taskletv2.ExecutionStatus) error {
			if st.Processor == nil {
				st.Processor = &taskletv2.ExecutionProcessorStatus{}
			}
			st.Processor.Message = err.Error()
			st.Processor.UpdatedAt = timestamppb.Now()
			return nil
		}
		eh.applyStatusUpdates(ctx, update)
		return
	}
	ctxlog.Info(ctx, eh.p.logger, "Sandbox check started")
	defer func() {
		if eh.err != nil {
			ctxlog.Error(ctx, eh.p.logger, "Sandbox check failed", log.Error(eh.err))
		} else {
			ctxlog.Info(ctx, eh.p.logger, "Sandbox check success")
		}
	}()

	// Task exists, but maybe aborted. The third result is not used here.
	sbStatus, finished, _, checkError := eh.p.sbx.GetSandboxTaskStatus(ctx, sbTaskID)
	if checkError != nil {
		eh.setErr(checkError)
		update := func(st *taskletv2.ExecutionStatus) error {
			st.Processor.Message = fmt.Sprintf("Error: %+v", checkError)
			st.Processor.UpdatedAt = timestamppb.Now()
			return nil
		}
		eh.applyStatusUpdates(ctx, update)
		return
	}

	if sbStatus == models.TaskAuditItemStatusDRAFT {
		finished = true
		if !eh.AbortRequested() {
			ctxlog.Error(ctx, eh.p.logger, "Awaiting task in DRAFT status")
		}
	}

	if finished {
		finalizeErr := eh.registerSandboxExecutionResult(ctx, fmt.Sprintf("Sandbox status: %v", sbStatus))
		if finalizeErr != nil {
			eh.setErr(finalizeErr)
		}
		return
	}

	eh.applyStatusUpdates(ctx, func(st *taskletv2.ExecutionStatus) error {
		st.Processor.Message = fmt.Sprintf("Sandbox status: %v", sbStatus)
		st.Processor.UpdatedAt = timestamppb.Now()
		return nil
	})
}

// stepArchiveExecution archives a FINISHED execution in storage.
// On success the handler is marked done. On failure the error is recorded via
// setErr.
func (eh *executionHandler) stepArchiveExecution(ctx context.Context) {
	if eh.done || eh.e.Status.GetStatus() != taskletv2.EExecutionStatus_E_EXECUTION_STATUS_FINISHED {
		return
	}

	ctxlog.Info(ctx, eh.p.logger, "Archive execution started")
	defer func() {
		if eh.err == nil {
			ctxlog.Info(ctx, eh.p.logger, "Archive execution success")
			return
		}
		ctxlog.Error(ctx, eh.p.logger, "Archive execution failed", log.Error(eh.err))
	}()

	err := eh.p.storage.ArchiveExecution(ctx, eh.id())
	if err != nil {
		eh.setErr(err)
		return
	}
	// NB: expected to be terminal state
	eh.done = true
}

// stepResolveResources resolves the Sandbox resources required by the build
// environment and stores them in the execution status. It does nothing if the
// resources are already marked as resolved.
//
// The ARC client and Sandbox resource manager resources are looked up in
// state.SandboxState when the matching environment feature is enabled.
//
// The status update fails if the stored status is nil or was already marked
// resolved by the time the update is applied.
func (eh *executionHandler) stepResolveResources(ctx context.Context) {
	if eh.done || eh.e.Status.GetResources().GetResolved() {
		return
	}
	ctxlog.Info(ctx, eh.p.logger, "Resource resolve started")
	defer func() {
		if eh.err == nil {
			ctxlog.Info(ctx, eh.p.logger, "Resource resolve success")
			return
		}
		ctxlog.Error(ctx, eh.p.logger, "Resource resolve failed", log.Error(eh.err))
	}()

	eh.mustGetBuild(ctx)
	if eh.done {
		return
	}

	env := eh.b.GetSpec().GetEnvironment()
	found := make(map[string]*taskletv2.SandboxResource)

	if env.GetArcClient().GetEnabled() {
		res, err := state.SandboxState.GetResource(consts.ArcClientResourceType)
		if err != nil {
			eh.setErr(xerrors.Errorf("cannot find %q: %w", consts.ArcClientResourceType, err))
			return
		}
		found[consts.ArcClientResourceType.String()] = &taskletv2.SandboxResource{ResourceId: res.ID}
	}

	if env.GetSandboxResourceManager().GetEnabled() {
		res, err := state.SandboxState.GetResource(consts.SandboxResourceManagerType)
		if err != nil {
			eh.setErr(xerrors.Errorf("cannot find %q: %w", consts.SandboxResourceManagerType, err))
			return
		}
		found[string(consts.SandboxResourceManagerType)] = &taskletv2.SandboxResource{ResourceId: res.ID}
	}

	resources := &taskletv2.ExecutionResources{
		Resolved:  true,
		Resources: found,
	}

	eh.applyStatusUpdates(ctx, func(status *taskletv2.ExecutionStatus) error {
		switch {
		case status == nil:
			return xerrors.New("Got empty status")
		case status.GetResources().GetResolved():
			return xerrors.New("Resources already resolved")
		}
		status.Resources = resources
		return nil
	})
}
