package executor

import (
	"context"
	"net"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"time"

	"google.golang.org/protobuf/encoding/prototext"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/dynamicpb"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/tasklet/api/v2"
	"a.yandex-team.ru/tasklet/experimental/internal/consts"
	"golang.org/x/sys/unix"
)

// Names of the executor's working subdirectories (joined under the working
// directory by NewPaths) and of the environment variable that carries the
// context file path to the tasklet process (set by prepareEnvVars).
const (
	taskletContextEnvName    = "TASKLET_CONTEXT"
	taskletSandboxPath       = "sandbox"
	resourcesDataPath        = "service_resources"
	logsPath                 = "tasklet_logs"
	sandboxResourcesDataPath = "sandbox_resources"
)

// File names used by the Paths helpers: input/output/error files live in the
// tasklet sandbox directory, stdout/stderr captures live in the logs directory.
const (
	inputFileName  = "input.bin"
	outputFileName = "output.bin"
	errorFileName  = "error.bin"
	taskletStdout  = "tasklet.stdout"
	taskletStderr  = "tasklet.stderr"
)

// Paths describes the executor's on-disk directory layout.
//
// Root is the working directory; NewPaths places TaskletSandbox, Logs,
// Resources (service resources) and SandboxResources directly under it.
// Build creates the subdirectories on disk.
type Paths struct {
	Root             string
	TaskletSandbox   string
	Logs             string
	Resources        string
	SandboxResources string
}

// NewPaths computes the directory layout rooted at cwd: cwd itself becomes
// Root and each other directory is a fixed-name child of it. Nothing is
// created on disk here; call Build for that.
func NewPaths(cwd string) Paths {
	under := func(name string) string {
		return filepath.Join(cwd, name)
	}

	var p Paths
	p.Root = cwd
	p.TaskletSandbox = under(taskletSandboxPath)
	p.Logs = under(logsPath)
	p.Resources = under(resourcesDataPath)
	p.SandboxResources = under(sandboxResourcesDataPath)
	return p
}

// GetContextPath returns the location of the tasklet context file inside the
// tasklet sandbox directory.
//
// NOTE(review): unlike the sibling helpers this uses path.Join (always
// '/'-separated) instead of filepath.Join; the two agree on Unix-like systems.
func (p Paths) GetContextPath() string {
	const contextFileName = "tasklet_context.message"
	return path.Join(p.TaskletSandbox, contextFileName)
}

// Build creates required directories.
// Assumption: all folders reside inside the Root catalog.
//
// Root must already exist and be a directory. The tasklet sandbox, logs,
// service resources and sandbox resources directories are then created
// (including parents, mode 0755) if they are missing.
func (p Paths) Build() error {
	rootInfo, statErr := os.Stat(p.Root)
	if statErr != nil {
		return statErr
	}
	if !rootInfo.IsDir() {
		return xerrors.Errorf("Bad root folder. Path: %q, Mode: %q", p.Root, rootInfo.Mode().String())
	}

	// TODO: check subfolder is inside workdir?
	subdirs := [...]string{p.TaskletSandbox, p.Logs, p.Resources, p.SandboxResources}
	for i := range subdirs {
		if mkErr := os.MkdirAll(subdirs[i], 0755); mkErr != nil {
			return xerrors.Errorf("Can not create folder. Path: %q, Err: %w", subdirs[i], mkErr)
		}
	}
	return nil
}

// InputFilePath returns the path of the tasklet input file (input.bin) in the
// tasklet sandbox directory.
func (p Paths) InputFilePath() string {
	dir := p.TaskletSandbox
	return filepath.Join(dir, inputFileName)
}

// OutputFilePath returns the path of the tasklet output file (output.bin) in
// the tasklet sandbox directory.
func (p Paths) OutputFilePath() string {
	dir := p.TaskletSandbox
	return filepath.Join(dir, outputFileName)
}

// ErrorFilePath returns the path of the tasklet error file (error.bin) in the
// tasklet sandbox directory.
func (p Paths) ErrorFilePath() string {
	dir := p.TaskletSandbox
	return filepath.Join(dir, errorFileName)
}

// StdoutFilePath returns the path of the tasklet stdout capture
// (tasklet.stdout) in the logs directory.
func (p Paths) StdoutFilePath() string {
	dir := p.Logs
	return filepath.Join(dir, taskletStdout)
}

// StderrFilePath returns the path of the tasklet stderr capture
// (tasklet.stderr) in the logs directory.
func (p Paths) StderrFilePath() string {
	dir := p.Logs
	return filepath.Join(dir, taskletStderr)
}

// executionInfo bundles the execution and build documents a tasklet run is
// prepared from. getSerializedInput reads the input bytes from the execution
// spec and getBuildSpec reads the build spec.
//
// NOTE(review): inputMessage/outputMessage are not used in this part of the
// file; presumably they hold the decoded input/output messages — confirm at
// their call sites.
type executionInfo struct {
	execution     *taskletv2.Execution
	build         *taskletv2.Build
	inputMessage  *dynamicpb.Message
	outputMessage *dynamicpb.Message
}

// getSerializedInput returns the serialized input bytes stored in the
// execution spec (presumably nil when any level is unset, assuming the usual
// nil-safe generated protobuf getters).
func (e *executionInfo) getSerializedInput() []byte {
	input := e.execution.GetSpec().GetInput()
	return input.GetSerializedData()
}

// getBuildSpec returns the spec of the build document attached to this
// execution.
func (e *executionInfo) getBuildSpec() *taskletv2.BuildSpec {
	spec := e.build.GetSpec()
	return spec
}

// TaskletEnvironment holds everything prepared for launching a tasklet:
//   - l: logger used during preparation;
//   - TaskletPath: set by prepareEnvironment from its taskletPath argument;
//   - EnvVars: built by prepareEnvVars (context path plus a few inherited vars);
//   - JavaBinPath: set by discoverJavaEnvironment to the JDK matching the
//     launch type, or left empty;
//   - Ref: filled by the discover* functions and handed to the context
//     provider via SetEnvironmentRef;
//   - P: directory layout built by buildPaths;
//   - ctxProvider: supplies the context that dumpContextFile serializes.
type TaskletEnvironment struct {
	l           log.Logger
	TaskletPath string
	EnvVars     map[string]string
	JavaBinPath string
	Ref         *taskletv2.EnvironmentRef
	P           Paths

	ctxProvider *ContextProvider
}

// NewTaskletEnvironment creates an environment bound to the given logger and
// context provider. Ref starts as an empty EnvironmentRef that the discovery
// steps fill in; paths, env vars and binaries are populated later by
// prepareEnvironment.
func NewTaskletEnvironment(logger log.Logger, ctxProvider *ContextProvider) *TaskletEnvironment {
	env := new(TaskletEnvironment)
	env.l = logger
	env.ctxProvider = ctxProvider
	env.Ref = &taskletv2.EnvironmentRef{}
	return env
}

// buildPaths derives the directory layout from cwd, stores it in te.P and
// creates the directories on disk.
func (te *TaskletEnvironment) buildPaths(cwd string) error {
	te.l.Info("Building paths")
	layout := NewPaths(cwd)
	te.P = layout
	return layout.Build()
}

// prepareEnvVars replaces te.EnvVars with a fresh map holding the context
// file path (under taskletContextEnvName) plus USER, HOME, TEMP and TMP copied
// from the executor's own environment. Unset variables are copied as empty
// strings.
func (te *TaskletEnvironment) prepareEnvVars(contextPath string) {
	inherited := [...]string{"USER", "HOME", "TEMP", "TMP"}

	vars := make(map[string]string, len(inherited)+1)
	vars[taskletContextEnvName] = contextPath
	for _, name := range inherited {
		vars[name] = os.Getenv(name)
	}
	te.EnvVars = vars
}

// prepareEnvironment sets up everything needed to launch the tasklet:
//   - builds the directory layout under cwd and the env vars;
//   - runs the Java, arc client and sandbox resource manager discovery steps
//     in that order, stopping at the first error;
//   - publishes the resulting environment ref to the context provider.
//
// javaBinary optionally overrides JDK discovery (see discoverJavaEnvironment).
func (te *TaskletEnvironment) prepareEnvironment(
	taskletPath string,
	javaBinary string,
	info *executionInfo,
	cwd string,
) error {
	if err := te.buildPaths(cwd); err != nil {
		return err
	}
	te.TaskletPath = taskletPath
	te.prepareEnvVars(te.P.GetContextPath())

	te.l.Debugf("Execution:\n%s", prototext.Format(info.execution))
	te.l.Debugf("Build:\n%s", prototext.Format(info.build))

	buildSpec := info.getBuildSpec()
	steps := []struct {
		title    string
		discover func() error
	}{
		{
			title: "Discovering java",
			discover: func() error {
				return discoverJavaEnvironment(javaBinary, buildSpec, te, te.l)
			},
		},
		{
			title: "Discovering arc_client",
			discover: func() error {
				return discoverArcClientEnvironment(buildSpec, info.execution, te, te.l)
			},
		},
		{
			title: "Discovering resource manager",
			discover: func() error {
				return discoverSandboxResourceManagerEnvironment(buildSpec, info.execution, te, te.l)
			},
		},
	}
	for _, step := range steps {
		te.l.Info(step.title)
		if err := step.discover(); err != nil {
			return err
		}
	}

	te.ctxProvider.SetEnvironmentRef(te.Ref)
	return nil
}

// dumpContextFile serializes the context returned by the context provider
// (proto wire format) and writes it to te.P.GetContextPath(), truncating any
// existing file.
//
// The context is marshalled before the file is created, so a marshalling
// failure does not leave an empty context file behind. The file descriptor is
// always closed, and a close failure is reported because write errors may
// only surface on close.
func (te *TaskletEnvironment) dumpContextFile() error {
	contextPath := te.P.GetContextPath()

	contextData, err := proto.Marshal(te.ctxProvider.GetContext())
	if err != nil {
		return err
	}

	contextFile, err := os.Create(contextPath)
	if err != nil {
		return xerrors.NewSentinel("failed to create context file").Wrap(err)
	}

	n, writeErr := contextFile.Write(contextData)
	closeErr := contextFile.Close()
	switch {
	case writeErr != nil:
		return writeErr
	case n != len(contextData):
		return xerrors.Errorf("serialized %v bytes of context instead of %v", n, len(contextData))
	case closeErr != nil:
		return xerrors.Errorf("failed to close context file %q: %w", contextPath, closeErr)
	}

	te.l.Infof("Context dumped to %q", contextPath)
	return nil
}

// getSandboxResource looks up the resolved resource id for resourceType in the
// execution status and asks the sandbox support API for its local path.
//
// It fails when the execution's resources are not yet resolved, when no
// (non-zero) id is recorded for the type, or when the support client cannot
// be created or the fetch fails.
func getSandboxResource(
	execution *taskletv2.Execution, resourceType consts.SandboxResourceType, logger log.Logger,
) (string, error) {
	resources := execution.GetStatus().GetResources()
	if !resources.GetResolved() {
		return "", xerrors.New("Resources are not resolved in execution document")
	}

	typeName := resourceType.String()
	id := resources.GetResources()[typeName].GetResourceId()
	if id == 0 {
		return "", xerrors.Errorf("Missing %v resource id in resolve cache", typeName)
	}

	client, err := newSandboxSupportAPIClient(logger)
	if err != nil {
		return "", err
	}
	localPath, err := client.GetResource(context.TODO(), id)
	if err != nil {
		return "", err
	}
	return localPath, nil
}

// discoverArcClientEnvironment provisions the arc client when the build spec
// enables it:
//   - fetches the arc client sandbox resource (a .tgz archive);
//   - unpacks it under <Resources>/<resource type>;
//   - checks that the unpacked "arc" binary is a regular file with at least
//     one execute bit set;
//   - records its path in te.Ref.ArcClient.
//
// It does nothing when the arc client is not requested.
func discoverArcClientEnvironment(
	buildSpec *taskletv2.BuildSpec,
	execution *taskletv2.Execution,
	te *TaskletEnvironment,
	logger log.Logger,
) error {
	if !buildSpec.GetEnvironment().GetArcClient().GetEnabled() {
		logger.Debug("Arc client is not requested")
		return nil
	}

	arcArchivePath, err := getSandboxResource(execution, consts.ArcClientResourceType, logger)
	if err != nil {
		return err
	}

	unpackDir := filepath.Join(te.P.Resources, consts.ArcClientResourceType.String())
	logger.Infof("Unpacking arc client to %q", unpackDir)
	if err := UnpackTgz(arcArchivePath, unpackDir); err != nil {
		return xerrors.NewSentinel("failed to unpack arc resource").Wrap(err)
	}

	arcClientPath := filepath.Join(unpackDir, "arc")
	stat, err := os.Stat(arcClientPath)
	if err != nil {
		return xerrors.NewSentinel("can not stat() arc client").Wrap(err)
	}
	// Require a regular file that is executable by at least someone.
	if !stat.Mode().IsRegular() || stat.Mode().Perm()&(unix.S_IXUSR|unix.S_IXGRP|unix.S_IXOTH) == 0 {
		return xerrors.Errorf("Bad arc client perms: %v", stat.Mode().String())
	}

	te.Ref.ArcClient = &taskletv2.ArcClientRef{
		Enabled: true,
		Path:    arcClientPath,
	}
	return nil
}

const (
	// sandboxResourceManagerAddress is the fixed loopback address passed to the
	// sandbox resource manager subprocess via --address, probed by waitTCPPort
	// and published in te.Ref.SandboxResourceManager.
	sandboxResourceManagerAddress = "[::1]:50000" // TODO: use normal port
)

// waitTCPPort polls address until a TCP connection can be established (and
// closed cleanly) or timeout elapses.
//
// The first probe happens immediately; later probes run every 250ms. Each
// dial is bounded by the time remaining (with a floor of one poll interval),
// so a hanging connection attempt cannot stall the wait far past timeout.
// Returns nil on the first successful probe, or an error once the timeout
// fires.
func waitTCPPort(timeout time.Duration, address string) error {
	const pollInterval = 250 * time.Millisecond

	deadline := time.Now().Add(timeout)
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()

	reachable := func() bool {
		dialTimeout := time.Until(deadline)
		if dialTimeout < pollInterval {
			dialTimeout = pollInterval
		}
		conn, err := net.DialTimeout("tcp", address, dialTimeout)
		if err != nil {
			return false
		}
		return conn.Close() == nil
	}

	for {
		if reachable() {
			return nil
		}
		select {
		case <-timer.C:
			return xerrors.Errorf("Tcp check timeout on address %v", address)
		case <-ticker.C:
		}
	}
}

// discoverSandboxResourceManagerEnvironment starts the sandbox resource
// manager when the build spec enables it:
//   - fetches the manager binary as a sandbox resource;
//   - launches it in te.P.SandboxResources with the runtime ("sandbox" or
//     "yt"), the fixed listen address and iteration 0;
//   - waits up to 10s for its TCP port;
//   - records the address in te.Ref.SandboxResourceManager.
//
// If the port never becomes reachable, the subprocess is killed and reaped
// before the error is returned, so no orphan is left running. On success the
// subprocess keeps running. A background goroutine waits for it, so it is
// reaped (not left as a zombie) and its exit is logged. That goroutine ends
// when the subprocess exits.
//
// It does nothing when the resource manager is not requested.
func discoverSandboxResourceManagerEnvironment(
	buildSpec *taskletv2.BuildSpec,
	execution *taskletv2.Execution,
	te *TaskletEnvironment,
	logger log.Logger,
) error {
	if !buildSpec.GetEnvironment().GetSandboxResourceManager().GetEnabled() {
		logger.Debug("Sandbox resource manager is not requested")
		return nil
	}

	sandboxResourceManagerPath, err := getSandboxResource(execution, consts.SandboxResourceManagerType, logger)
	if err != nil {
		return err
	}

	runtime := "yt"
	if isSandbox() {
		runtime = "sandbox"
	}

	subprocess := exec.Command(
		sandboxResourceManagerPath, "--runtime", runtime, "--address", sandboxResourceManagerAddress,
		"--iteration", "0",
	)
	logger.Infof("Starting subprocess %v", subprocess)
	subprocess.Env = os.Environ()
	subprocess.Dir = te.P.SandboxResources
	if err := subprocess.Start(); err != nil {
		return xerrors.Errorf("failed to start sandbox resource manager %q: %w", sandboxResourceManagerPath, err)
	}

	if err := waitTCPPort(time.Second*10, sandboxResourceManagerAddress); err != nil {
		// The manager never became reachable: do not leave it running.
		if killErr := subprocess.Process.Kill(); killErr != nil {
			logger.Infof("Failed to kill sandbox resource manager (pid %d): %v", subprocess.Process.Pid, killErr)
		}
		// Reap the killed child; its exit status after SIGKILL carries no useful information.
		_ = subprocess.Wait()
		return xerrors.Errorf("sandbox resource manager did not become ready: %w", err)
	}

	go func() {
		waitErr := subprocess.Wait()
		logger.Infof("Sandbox resource manager exited: %v", waitErr)
	}()

	te.Ref.SandboxResourceManager = &taskletv2.SandboxResourceManagerRef{
		Enabled: true,
		Address: sandboxResourceManagerAddress,
	}
	return nil
}

// discoverJavaEnvironment fills te.Ref.Java for JDK 11 and JDK 17.
//
// A JDK is marked enabled when it matches the build's launch type or is
// explicitly requested in the environment spec. For each enabled JDK:
//   - the binary is looked up at /opt/<jdk name>/bin/java;
//   - if that path is missing and javaBinary is non-empty, javaBinary is used
//     instead (test-only override, applied to every JDK);
//   - otherwise an error is returned.
//
// te.JavaBinPath is reset to "" and set only to the path of the JDK matching
// the launch type.
func discoverJavaEnvironment(
	javaBinary string,
	buildSpec *taskletv2.BuildSpec,
	te *TaskletEnvironment,
	logger log.Logger,
) error {
	javaRef := &taskletv2.JavaEnvironmentRef{
		Jdk11: &taskletv2.JdkRef{},
		Jdk17: &taskletv2.JdkRef{},
	}
	te.JavaBinPath = ""
	te.Ref.Java = javaRef

	launchType := buildSpec.GetLaunchSpec().GetType()
	javaReq := buildSpec.GetEnvironment().GetJava()

	jdks := []struct {
		name string
		req  *taskletv2.JdkReq
		ref  *taskletv2.JdkRef
	}{
		{name: consts.LaunchTypeJava11, req: javaReq.GetJdk11(), ref: javaRef.Jdk11},
		{name: consts.LaunchTypeJava17, req: javaReq.GetJdk17(), ref: javaRef.Jdk17},
	}

	for _, jdk := range jdks {
		if jdk.name != launchType && !jdk.req.GetEnabled() {
			continue
		}
		jdk.ref.Enabled = true

		javaBinPath := filepath.Join("/opt", jdk.name, "bin", "java")
		if _, err := os.Stat(javaBinPath); err == nil {
			logger.Infof("Discovered jdk: %q", javaBinPath)
			jdk.ref.Path = javaBinPath
		} else if javaBinary != "" {
			// Java binary path is provided from outside of executor.
			// It should be used in tests only. So, use it for each JDK.
			logger.Infof("Overriding java bin by command line argument: %q", javaBinary)
			jdk.ref.Path = javaBinary
		} else {
			return xerrors.Errorf("%q was not found at %q", jdk.name, javaBinPath)
		}

		if jdk.name == launchType {
			te.JavaBinPath = jdk.ref.Path
		}
	}

	return nil
}
