package executor

import (
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"time"

	"google.golang.org/protobuf/encoding/prototext"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/durationpb"
	"google.golang.org/protobuf/types/known/timestamppb"

	"golang.org/x/exp/slices"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/ctxlog"
	"a.yandex-team.ru/library/go/core/xerrors"

	"a.yandex-team.ru/tasklet/api/priv/v1"
	"a.yandex-team.ru/tasklet/api/v2"
	"a.yandex-team.ru/tasklet/experimental/internal/consts"
)

// constructTaskletCommand builds the executable path and argument list for a tasklet run.
//
// Launch types listed in consts.LaunchJavaTypes are executed through env.JavaBinPath with
// "-cp <env.TaskletPath> <jdk main class>" prepended; every other launch type executes
// env.TaskletPath directly. In both cases the command ends with the executor address
// followed by the input, output and error file paths.
func constructTaskletCommand(env *TaskletEnvironment, launchSpec *taskletv2.LaunchSpec) (
	binPath string,
	args []string,
) {
	var prefix []string
	binPath = env.TaskletPath
	if slices.Contains(consts.LaunchJavaTypes, launchSpec.GetType()) {
		// env.TaskletPath becomes the classpath; the JVM binary is what actually gets executed.
		binPath = env.JavaBinPath
		prefix = []string{"-cp", env.TaskletPath, launchSpec.GetJdk().GetMainClass()}
	}

	positional := []string{
		env.ctxProvider.GetContext().Executor.Address,
		env.P.InputFilePath(),
		env.P.OutputFilePath(),
		env.P.ErrorFilePath(),
	}
	return binPath, append(prefix, positional...)
}

// getProcessEnvVars renders te.EnvVars as "KEY=value" entries for exec.Cmd.Env.
//
// The returned slice is never nil (even for an empty map). Per os/exec, a non-nil Env
// means the subprocess does not inherit the parent's environment. Entry order follows
// map iteration and is therefore unspecified.
func getProcessEnvVars(te *TaskletEnvironment) []string {
	pairs := make([]string, len(te.EnvVars))
	i := 0
	for name, val := range te.EnvVars {
		pairs[i] = name + "=" + val
		i++
	}
	return pairs
}

// ExecutionOutcome is what runTaskletSubprocess produces for one tasklet run and what
// buildAndRunAndReportTasklet sends to the tasklet service.
type ExecutionOutcome struct {
	// stats holds start/finish timestamps, duration and exit code of the run. When the run
	// failed with an infrastructure error, buildAndRunAndReportTasklet fills only FinishedAt.
	stats  *taskletv2.ExecutionStats
	// result is the processing result to report: output, user error or server error.
	result *taskletv2.ProcessingResult
}

// subprocessTTL bounds the wall-clock time of a single tasklet subprocess; once it
// elapses the process is killed and the run is reported as a timeout.
// TODO: Take execution timeout from Build.Spec
var subprocessTTL = 10 * time.Minute

// readSerializedProto reads filePath, decodes its binary protobuf contents into message,
// and returns the raw bytes that were decoded. On failure message may be partially filled
// and the returned slice is nil. ctx is currently unused.
func readSerializedProto(ctx context.Context, filePath string, message proto.Message) ([]byte, error) {
	raw, err := os.ReadFile(filePath)
	if err != nil {
		return nil, xerrors.Errorf("Failed read: %w", err)
	}
	if err = proto.Unmarshal(raw, message); err != nil {
		return nil, xerrors.Errorf("Failed to unmarshall: %w", err)
	}
	return raw, nil
}

// runTaskletSubprocess executes the user payload and describes how it ended.
//
// Payload-level outcomes are returned inside ExecutionOutcome.result together with a nil
// error. These include success output, user error, timeout, crash, malformed output, and
// abort of ctx. A non-nil error is returned only for tasklet infrastructure failures that
// prevent the subprocess from running: writing the input file, creating the stdout/stderr
// capture files, or starting the process.
//
// The subprocess runs in env.P.TaskletSandbox with the environment built by
// getProcessEnvVars. Its stdout/stderr are redirected to env.P.StdoutFilePath() and
// env.P.StderrFilePath(). It is killed once subprocessTTL elapses or ctx is done.
//
// After a zero exit, a non-empty error file is decoded as a taskletv2.UserError.
// Otherwise the output file is decoded into info.outputMessage and its raw bytes are
// returned as the execution output.
func runTaskletSubprocess(
	ctx context.Context,
	env *TaskletEnvironment,
	info *executionInfo,
	logger log.Logger,
) (*ExecutionOutcome, error) {
	if err := os.WriteFile(env.P.InputFilePath(), info.getSerializedInput(), 0664); err != nil {
		return nil, xerrors.Errorf("failed to write input file: %w", err)
	}

	executable, commandArgs := constructTaskletCommand(env, info.build.GetSpec().GetLaunchSpec())

	closeFileAndLogError := func(file *os.File) {
		if err := file.Close(); err != nil {
			ctxlog.Errorf(ctx, logger, "Error on closing %q: %+v", file.Name(), err)
		}
	}

	stdout, err := os.Create(env.P.StdoutFilePath())
	if err != nil {
		return nil, xerrors.Errorf("failed to create stdout file: %w", err)
	}
	defer closeFileAndLogError(stdout)

	stderr, err := os.Create(env.P.StderrFilePath())
	if err != nil {
		return nil, xerrors.Errorf("failed to create stderr file: %w", err)
	}
	defer closeFileAndLogError(stderr)

	subprocessCtx, cancel := context.WithTimeout(ctx, subprocessTTL)
	defer cancel()

	subprocess := exec.CommandContext(subprocessCtx, executable, commandArgs...)
	ctxlog.Infof(ctx, logger, "Starting subprocess %v with timeout %v", subprocess, subprocessTTL)
	subprocess.Env = getProcessEnvVars(env)
	subprocess.Dir = env.P.TaskletSandbox
	subprocess.Stdout = stdout
	subprocess.Stderr = stderr

	startTime := time.Now()
	if err = subprocess.Start(); err != nil {
		return nil, xerrors.Errorf("failed to start subprocess: %w", err)
	}
	pid := subprocess.Process.Pid
	ctxlog.Infof(ctx, logger, "Started subprocess with pid=%v", pid)

	// Wait returns once the process exits. exec.CommandContext kills the process when
	// subprocessCtx is done, i.e. on parent cancellation or after subprocessTTL.
	subprocessErr := subprocess.Wait()
	finishTime := time.Now()
	duration := finishTime.Sub(startTime)
	exitCode := subprocess.ProcessState.ExitCode()
	ctxlog.Infof(ctx, logger, "Subprocess finished. Duration: %v", duration)

	outcome := &ExecutionOutcome{
		result: &taskletv2.ProcessingResult{},
		stats: &taskletv2.ExecutionStats{
			CreatedAt:  timestamppb.New(startTime),
			FinishedAt: timestamppb.New(finishTime),
			Duration:   durationpb.New(duration),
			ExitCode:   int32(exitCode),
		},
	}

	ctxlog.Infof(ctx, logger, "Subprocess #%d finished with code %d", pid, exitCode)

	if subprocessErr != nil {
		switch {
		case ctx.Err() != nil:
			// The root context was aborted. This is abnormal; report a generic error for now.
			outcome.result.Kind = &taskletv2.ProcessingResult_ServerError{
				ServerError: &taskletv2.ServerError{
					Code:        taskletv2.ErrorCodes_ERROR_CODE_GENERIC,
					Description: fmt.Sprintf("Aborted: %v", ctx.Err()),
				},
			}
		case errors.Is(subprocessCtx.Err(), context.DeadlineExceeded):
			// Only the per-run timeout can be the cause here; a parent deadline is caught above.
			outcome.result.Kind = &taskletv2.ProcessingResult_ServerError{
				ServerError: &taskletv2.ServerError{
					Code:        taskletv2.ErrorCodes_ERROR_CODE_TIMEOUT,
					Description: fmt.Sprintf("Subprocess %v was killed by timeout %v", pid, subprocessTTL),
				},
			}
		case subprocessCtx.Err() != nil:
			// subprocessCtx ended for another reason. For example, the parent was
			// cancelled right after the ctx check above.
			ctxErr := subprocessCtx.Err()
			ctxlog.Error(ctx, logger, "Subprocess error", log.Error(ctxErr))
			outcome.result.Kind = &taskletv2.ProcessingResult_ServerError{
				ServerError: &taskletv2.ServerError{
					Code:        taskletv2.ErrorCodes_ERROR_CODE_CRASHED,
					Description: fmt.Sprintf("Subprocess error: %+v", ctxErr),
				},
			}
		default:
			// The user payload itself failed (non-zero exit or termination by a signal).
			outcome.result.Kind = &taskletv2.ProcessingResult_ServerError{
				ServerError: &taskletv2.ServerError{
					Code:        taskletv2.ErrorCodes_ERROR_CODE_CRASHED,
					Description: fmt.Sprintf("User job error: %+v", subprocessErr),
				},
			}
		}
		return outcome, nil
	}

	// Exit code is 0: the payload reported either a user error or an output.
	errorFilePath := env.P.ErrorFilePath()
	errorFileStat, statErr := os.Stat(errorFilePath)
	if statErr != nil && !errors.Is(statErr, os.ErrNotExist) {
		// A missing error file is the normal success case. Any other stat failure is
		// unexpected, so surface it rather than swallowing it, then fall back to output
		// processing as before.
		ctxlog.Errorf(ctx, logger, "Failed to stat user error file: %v", statErr)
	}

	if statErr == nil && errorFileStat.Size() > 0 {
		ctxlog.Infof(ctx, logger, "Processing user error file. Size: %v", errorFileStat.Size())
		userError := &taskletv2.UserError{}
		if _, parseErr := readSerializedProto(ctx, errorFilePath, userError); parseErr != nil {
			ctxlog.Infof(ctx, logger, "Failed to parse user error file: %v", parseErr)
			outcome.result.Kind = &taskletv2.ProcessingResult_ServerError{
				ServerError: &taskletv2.ServerError{
					Code:        taskletv2.ErrorCodes_ERROR_CODE_BAD_OUTPUT,
					Description: fmt.Sprintf("Failed to parse error file: %v", parseErr),
				},
			}
			return outcome, nil
		}
		ctxlog.Info(ctx, logger, "User error parsed")
		outcome.result.Kind = &taskletv2.ProcessingResult_UserError{
			UserError: userError,
		}
		return outcome, nil
	}

	// No user error reported: the output file must hold the serialized result.
	ctxlog.Info(ctx, logger, "Processing output file")
	serializedOutput, outputReadError := readSerializedProto(ctx, env.P.OutputFilePath(), info.outputMessage)
	if outputReadError != nil {
		ctxlog.Info(ctx, logger, "Bad output", log.Error(outputReadError))
		outcome.result.Kind = &taskletv2.ProcessingResult_ServerError{
			ServerError: &taskletv2.ServerError{
				Code:        taskletv2.ErrorCodes_ERROR_CODE_BAD_OUTPUT,
				Description: fmt.Sprintf("output file processing error: %v", outputReadError),
			},
		}
		return outcome, nil
	}

	ctxlog.Infof(ctx, logger, "Finished read of tasklet result (%d bytes)", len(serializedOutput))
	outcome.result.Kind = &taskletv2.ProcessingResult_Output{
		Output: &taskletv2.ExecutionOutput{SerializedOutput: serializedOutput},
	}
	return outcome, nil
}

// reportExecutionResult sends the final status of execution to the tasklet service via
// client.ReportExecutionStatus, logging the request and, on success, the service response.
// The RPC error, if any, is returned unwrapped.
func reportExecutionResult(
	ctx context.Context,
	client privatetaskletv1.InternalServiceClient,
	execution *taskletv2.Execution,
	stats *taskletv2.ExecutionStats,
	processingResult *taskletv2.ProcessingResult,
	logger log.Logger,
) error {
	request := &privatetaskletv1.ReportExecutionStatusRequest{
		// GetId is nil-safe; reading .Id directly would panic when Meta is unset.
		Id:               execution.GetMeta().GetId(),
		Stats:            stats,
		ProcessingResult: processingResult,
	}
	ctxlog.Infof(ctx, logger, "Reporting result to tasklet service:\n%s", prototext.Format(request))

	resp, err := client.ReportExecutionStatus(ctx, request)
	if err != nil {
		return err
	}
	ctxlog.Infof(ctx, logger, "Result reported. Service response:\n%s", prototext.Format(resp))
	return nil
}

// detachedReportContext exposes the values of parent (logging fields, request metadata, ...)
// while ignoring its cancellation and deadline. It mirrors context.WithoutCancel from Go 1.21,
// which this package does not assume is available.
type detachedReportContext struct {
	parent context.Context
}

func (detachedReportContext) Deadline() (time.Time, bool) {
	return time.Time{}, false
}

func (detachedReportContext) Done() <-chan struct{} {
	return nil
}

func (detachedReportContext) Err() error {
	return nil
}

func (c detachedReportContext) Value(key any) any {
	return c.parent.Value(key)
}

// buildAndRunAndReportTasklet runs the tasklet subprocess and always attempts to report a
// final status for info.execution to the tasklet service.
//
// An infrastructure error from runTaskletSubprocess is reported as a GENERIC server error.
// After a successful report it is returned wrapped. A result without a Kind is reported as
// a GENERIC "no data" server error. If reporting itself fails, that error is returned.
//
// If ctx is already done when it is time to report (e.g. the run was aborted), the report
// is sent on a context that keeps ctx's values but not its cancellation, bounded by a
// timeout. Otherwise the RPC would fail immediately and the abort would never be reported.
func buildAndRunAndReportTasklet(
	ctx context.Context,
	taskletEnv *TaskletEnvironment,
	internalAPIClient privatetaskletv1.InternalServiceClient,
	info *executionInfo,
	logger log.Logger,
) error {
	// Upper bound for the status report when ctx can no longer bound it, so shutdown is not
	// blocked indefinitely by an unresponsive service.
	const detachedReportTimeout = 30 * time.Second

	outcome, executionError := runTaskletSubprocess(ctx, taskletEnv, info, logger)
	if executionError != nil {
		ctxlog.Error(ctx, logger, "Execution error", log.Error(executionError))
		outcome = &ExecutionOutcome{
			result: &taskletv2.ProcessingResult{
				Kind: &taskletv2.ProcessingResult_ServerError{
					ServerError: &taskletv2.ServerError{
						Code:        taskletv2.ErrorCodes_ERROR_CODE_GENERIC,
						Description: fmt.Sprintf("ERROR: %+v", executionError),
					},
				},
			},
			stats: &taskletv2.ExecutionStats{
				FinishedAt: timestamppb.Now(),
			},
		}
	}
	if outcome.result.Kind == nil {
		outcome.result.Kind = &taskletv2.ProcessingResult_ServerError{
			ServerError: &taskletv2.ServerError{
				Code:        taskletv2.ErrorCodes_ERROR_CODE_GENERIC,
				Description: "no data",
			},
		}
	}

	reportCtx := ctx
	if ctx.Err() != nil {
		ctxlog.Infof(ctx, logger, "Context is done (%v); reporting execution result on a detached context", ctx.Err())
		var cancelReport context.CancelFunc
		reportCtx, cancelReport = context.WithTimeout(detachedReportContext{parent: ctx}, detachedReportTimeout)
		defer cancelReport()
	}

	if err := reportExecutionResult(
		reportCtx,
		internalAPIClient,
		info.execution,
		outcome.stats,
		outcome.result,
		logger,
	); err != nil {
		return xerrors.Errorf("Failed to report execution result: %w", err)
	}

	if executionError != nil {
		return xerrors.Errorf("Unexpected error: %w", executionError)
	}
	return nil
}
