package main

import (
	"errors"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"syscall"

	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/utils"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/pkg/modprobe"
)

// main ensures the 'nvidia' kernel module is loaded, resolves the versioned
// binary matching the name this wrapper was invoked as, and runs it with
// LD_LIBRARY_PATH pointing at the matching NVIDIA libraries:
//
//	$ nvidia-smi -> /usr/bin/nvidia-smi -> /opt/yandex-nvidia-utils/bin/nvidia-smi ->
//	-> /opt/yandex-nvidia-utils/bin/nvidia-utils-wrapper ->
//	-> (wrapper gets current version of nvidia libs, e.g. 450.119.04) ->
//	-> /opt/yandex-nvidia-utils/450.119.04/nvidia-smi
//
// The child inherits this process's stdin, stdout and stderr. The wrapper
// exits with the child's exit status (see childExitCode). Any failure before
// the child starts is reported on stderr and exits with status 1.
func main() {
	if err := modprobe.LoadModuleIfUnloaded("nvidia"); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to load 'nvidia' kernel module, err: %v\n", err)
		os.Exit(1)
	}

	libPath, err := utils.GetNvidiaLibraryPath()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Fail to get lib path: %s\n", err)
		os.Exit(1)
	}
	// The utility to run is chosen by the name we were invoked as
	// (e.g. the "nvidia-smi" symlink pointing at this wrapper).
	binPath, err := utils.GetNvidiaBinPath(filepath.Base(os.Args[0]))
	if err != nil {
		fmt.Fprintf(os.Stderr, "Fail to get binary path: %s\n", err)
		os.Exit(1)
	}

	// exec.Command sets argv[0] to binPath and forwards our own arguments.
	cmd := exec.Command(binPath, os.Args[1:]...)
	// Pass stdin through as well, so the wrapper stays transparent for
	// utilities that read input; otherwise the child would get /dev/null.
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	// os/exec uses the last value for duplicate keys, so this overrides any
	// inherited LD_LIBRARY_PATH.
	cmd.Env = append(os.Environ(), "LD_LIBRARY_PATH="+libPath)

	// NOTE(review): no signal handling is installed. A signal sent only to
	// this wrapper's PID (e.g. `kill <pid>`) is not relayed to the child.
	if err := cmd.Start(); err != nil {
		fmt.Fprintf(os.Stderr, "Exec failed: %v\n", err)
		os.Exit(1)
	}
	if err := cmd.Wait(); err != nil {
		os.Exit(childExitCode(err))
	}
}

// childExitCode maps a non-nil error returned by (*exec.Cmd).Wait to the
// status this wrapper should exit with:
//   - a normal non-zero exit of the child is propagated unchanged;
//   - a child terminated by a signal yields 128+signal, as shells report it.
//     ExitCode() alone returns -1 here, which os.Exit would turn into 255;
//   - any other error (e.g. an I/O failure) is printed to stderr and yields 1.
func childExitCode(err error) int {
	var exitErr *exec.ExitError
	if !errors.As(err, &exitErr) {
		fmt.Fprintf(os.Stderr, "Wait failed: %v\n", err)
		return 1
	}
	if ws, ok := exitErr.Sys().(syscall.WaitStatus); ok && ws.Signaled() {
		return 128 + int(ws.Signal())
	}
	return exitErr.ExitCode()
}
