package launcher

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"os/signal"
	"syscall"

	"a.yandex-team.ru/security/skotty/launcher/internal/logger"
	"a.yandex-team.ru/security/skotty/launcher/internal/updater"
	"a.yandex-team.ru/security/skotty/launcher/internal/version"
)

// Restart runs "skotty service restart" through runSkotty and reports whether
// it succeeded.
//
// The child is given an empty in-memory stdin, its stdout is captured into a
// buffer that is never read, and its stderr goes to the launcher's stderr.
// Any error from runSkotty is returned as is; a non-zero exit code is turned
// into an error carrying that code.
func Restart() error {
	in, out := new(bytes.Buffer), new(bytes.Buffer)

	code, err := runSkotty(in, out, os.Stderr, "service", "restart")
	switch {
	case err != nil:
		return err
	case code != 0:
		return fmt.Errorf("'skotty service restart' exit with non-zero status code: %d", code)
	default:
		return nil
	}
}

// Launch runs skotty with args, attached to the launcher's own stdin, stdout
// and stderr.
//
// When the release info cannot be read, or reads back as zero, updater.Update
// is called first; if that fails, Launch returns -1 and the wrapped error.
// Otherwise the result is whatever runSkotty returns.
func Launch(args ...string) (int, error) {
	info, readErr := updater.ReadReleaseInfo()

	// A read failure is treated the same as "nothing installed yet".
	needsUpdate := readErr != nil || info.IsZero()
	if needsUpdate {
		logger.Info("no release info - updating skotty")
		if _, updateErr := updater.Update(); updateErr != nil {
			return -1, fmt.Errorf("update fail: %w", updateErr)
		}
	}

	return runSkotty(os.Stdin, os.Stdout, os.Stderr, args...)
}

// runSkotty executes the skotty binary of the currently recorded release with
// args and waits for it to finish.
//
// The binary path is resolved with updater.InReleasesPath; if nothing exists
// at that path, the release is fetched with updater.DownloadVersion first.
// The child is wired to the given stdin/stdout/stderr and inherits the
// launcher's environment extended with UNDER_LAUNCHER, SKOTTY_LAUNCHER_PATH,
// SKOTTY_LAUNCHER_VERSION and SKOTTY_CHANNEL. While the child runs, signals
// received by the launcher are relayed to it.
//
// It returns the child's exit code and a nil error when the child was started
// and waited for (a non-zero exit code is not an error here). It returns -1
// and a non-nil error when the binary could not be located, downloaded,
// started or waited for.
func runSkotty(stdin io.Reader, stdout, stderr io.Writer, args ...string) (int, error) {
	launcherExe, err := os.Executable()
	if err != nil {
		return -1, fmt.Errorf("can't determine self executable: %w", err)
	}

	release, err := updater.ReadReleaseInfo()
	if err != nil {
		return -1, fmt.Errorf("can't get skotty release: %w", err)
	}

	targetPath, err := updater.InReleasesPath(release.Version, updater.BinaryName)
	if err != nil {
		return -1, fmt.Errorf("can't determine target dir: %w", err)
	}

	if _, statErr := os.Stat(targetPath); os.IsNotExist(statErr) {
		logger.Infof("release not available - downloading skotty v%s", release.Version)
		targetPath, err = updater.DownloadVersion(release.Version)
		if err != nil {
			return -1, fmt.Errorf("update fail: %w", err)
		}
	}

	cmd := exec.Command(targetPath, args...)
	cmd.Stdin = stdin
	cmd.Stdout = stdout
	cmd.Stderr = stderr
	cmd.Env = append(
		os.Environ(),
		"UNDER_LAUNCHER=yes",
		"SKOTTY_LAUNCHER_PATH="+launcherExe,
		"SKOTTY_LAUNCHER_VERSION="+version.Full(),
		"SKOTTY_CHANNEL="+string(release.Channel),
	)

	if err := cmd.Start(); err != nil {
		return -1, fmt.Errorf("unable to start: %w", err)
	}

	// Notify without an explicit signal list subscribes to all incoming
	// signals, so the launcher acts as a transparent relay for the child.
	signals := make(chan os.Signal, 1)
	signal.Notify(signals)
	// Tear the relay down on every return path, including a failed Wait;
	// otherwise the process-wide handler stays installed and the forwarding
	// goroutine below never exits. Stop must precede close: after Stop returns
	// no more signals are delivered to the channel, so closing it is safe.
	defer func() {
		signal.Stop(signals)
		close(signals)
	}()

	go func() {
		for sig := range signals {
			// Best effort: the child may already have exited, in which case
			// there is nobody left to deliver the signal to.
			_ = cmd.Process.Signal(sig)
		}
	}()

	err = cmd.Wait()
	if err == nil {
		return 0, nil
	}

	var exitErr *exec.ExitError
	if !errors.As(err, &exitErr) {
		return -1, fmt.Errorf("wait failed: %w", err)
	}

	// The program has exited with an exit code != 0.
	if status, ok := exitErr.Sys().(syscall.WaitStatus); ok {
		return status.ExitStatus(), nil
	}

	// Sys() did not hold a syscall.WaitStatus; fall back to the portable
	// accessor instead of silently reporting success for a failed child.
	return exitErr.ExitCode(), nil
}
