package cmd

import (
	"context"
	"fmt"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"syscall"
	"time"

	"github.com/gofrs/flock"
	"github.com/spf13/cobra"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"

	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/config"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/ilog"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/server"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/utils"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/watcher"
	"a.yandex-team.ru/library/go/core/log/zap/logrotate"
)

// Values of the command-line flags. Every variable below is bound to a flag
// of rootCmd in init and is populated when cobra parses the command line.
//
// NOTE(review): dcgmSocketAddr and dcgmIsUnixSocket are bound to flags but are
// not read anywhere in this file - confirm they are consumed elsewhere.
var (
	cfgFile                string
	serviceEndpoint        string
	logFile                string
	lockPath               string
	enableNvml             bool
	enablePersistenced     bool
	enableFabricmanager    string
	enableDCGM             string
	enablePeriodicHungTest string
	enableNvidiaUvm        string
	enablePeerMem          bool
	enableVfio             string
	forceVfioInit          bool
	enableIbMetrics        bool
	allocIbDevs            bool
	allocIncludeIbUverbs   bool
	allocIncludeRoceUverbs bool
	allocNvgpuUnixSocket   bool
	enableSocketActivation bool
	enableLimitsFixup      bool
	debug                  bool
	mockMode               string
	secure                 bool
	setupLayers            bool
	cleanLayers            bool
	dcgmSocketAddr         string
	dcgmIsUnixSocket       bool
)

// Filesystem locations used by doSetupLayers and doCleanupLayers:
// cudaSrcPath is the file copied into <cudaLayerBase>/lib64/libcuda.so,
// cudaLayerBase is the layer directory, and cudaLayerLink is the symlink
// created to point at cudaLayerBase.
const (
	cudaSrcPath   = "/usr/lib/x86_64-linux-gnu/libcuda.so"
	cudaLayerBase = "/opt/nvgpu-manager2/layers/libcuda-dir"
	cudaLayerLink = "/opt/nvgpu-manager2/layers/libcuda"
)

// rootCmd is the top-level cobra command of the nvgpu-manager binary.
// Its RunE handler ignores the command and positional arguments and simply
// delegates to doMain.
var rootCmd = &cobra.Command{
	Use:   "nvgpu-manager",
	Short: "NVIDIA GPU manager service",
	Long:  "NVIDIA GPU manager service, Docs https://doc.yandex-team.ru/nvgpu_manager",
	RunE: func(_ *cobra.Command, _ []string) error {
		return doMain()
	},
}

// rootCmd2 is an alias of rootCmd, assigned in init. doMain reads flag state
// through it instead of through rootCmd, because rootCmd's initializer already
// refers to doMain and a direct reference back would form an initialization
// loop.
var rootCmd2 *cobra.Command

// Execute runs the root command. If the command returns an error, the error
// is printed to standard output and the process exits with status 1.
func Execute() {
	err := rootCmd.Execute()
	if err == nil {
		return
	}
	fmt.Println(err)
	os.Exit(1)
}

// init binds every command-line flag of rootCmd to its package-level variable
// and publishes rootCmd through rootCmd2 for use by doMain.
func init() {
	//cobra.OnInitialize(initConfig)

	// Persistent flags.
	persistent := rootCmd.PersistentFlags()
	persistent.StringVar(&cfgFile, "config", config.DefaultServerConf, "config file")
	persistent.StringVar(&serviceEndpoint, "service", config.DefaultServerAddress, "gprc service endpoint")
	persistent.StringVar(&logFile, "log-file", config.DefaultServerLog, "log file")
	persistent.StringVar(&lockPath, "lock-file", config.DefaultServerLock, "server flock file")

	// Flags local to the root command.
	local := rootCmd.Flags()
	local.BoolVar(&enableNvml, "enable-nvml", config.DefaultNvmlMode, "enable nvml mode")
	local.BoolVar(&enablePersistenced, "enable-persistenced", config.DefaultNvidiaPersistenced, "enable nvidia-persistenced service")
	local.StringVar(&enableFabricmanager, "enable-fabricmanager", string(config.DefaultNvidiaFabricmanager), "enable nvidia-fabricmanager service")
	local.StringVar(&enableDCGM, "enable-dcgm", string(config.DefaultNvidiaDCGM), "enable nvidia-dcgm (Data Center GPU Manager)")
	local.StringVar(&enablePeriodicHungTest, "enable-periodic-hung-test", string(config.DefaultPeriodicHungTest), "enable periodic hung test")
	local.StringVar(&enableNvidiaUvm, "enable-nvidia-uvm", string(config.DefaultNvidiaUvm), "enable nvidia_uvm")
	local.BoolVar(&enablePeerMem, "enable-peer-mem", config.DefaultPeerMem, "enable nv_peer_mem")
	local.StringVar(&enableVfio, "enable-vfio", string(config.DefaultVfioMode), "enable vfio mode")
	local.BoolVar(&forceVfioInit, "force-vfio-init", config.DefaultForceVfioInit, "force init gpus to vfio mode")
	local.BoolVar(&enableIbMetrics, "enable-ib-metrics", config.DefaultIbMetrics, "enable collecting infiniband metrics (and sending to Yasm)")
	local.BoolVar(&allocIbDevs, "alloc-ib-devs", config.DefaultAllocIbDevs, "add infiniband devices to Alloc cmd response")
	local.BoolVar(&allocIncludeIbUverbs, "alloc-include-ib-uverbs", config.DefaultAllocIncludeIbUverbs, "add InfiniBand uverbs to Alloc cmd response")
	local.BoolVar(&allocIncludeRoceUverbs, "alloc-include-roce-uverbs", config.DefaultAllocIncludeRoceUverbs, "add RoCE uverbs to Alloc cmd response")
	local.BoolVar(&allocNvgpuUnixSocket, "alloc-nvgpu-unix-socket", config.DefaultAllocNvgpuUnixSocket, "add nvgpumanager unix socket to bind section of Alloc cmd response")
	local.BoolVar(&enableSocketActivation, "enable-socket-activation", config.DefaultEnableSocketActivation, "enable systemd socket activation for nvgpumanager")
	local.BoolVar(&enableLimitsFixup, "enable-limits-fixup", config.DefaultEnableLimitsFixup, "turn on fixing up gpu freqs, power, etc.")
	local.StringVar(&mockMode, "mock-mode", config.DefaultMockMode, "set mock device mode, for testing only")
	local.BoolVar(&secure, "secure", config.DefaultSecureMode, "set permissions for endpoint socket")
	local.BoolVar(&debug, "debug", false, "enable debug log level")
	local.BoolVar(&setupLayers, "setup-layers", false, "setup layers and exit")
	local.BoolVar(&cleanLayers, "clean-layers", false, "cleanup layers and exit")
	local.StringVar(&dcgmSocketAddr, "dcgm-socket-addr", config.DefaultDCGMSocketAddr, "address of dcgm api socket")
	local.BoolVar(&dcgmIsUnixSocket, "dcgm-unix-socket", config.DefaultDCGMIsUnixSocket, "set the dcgm-socket-addr unix mode")

	// Expose the command to doMain without creating an initialization loop.
	rootCmd2 = rootCmd
}

// timeLayout is the timestamp layout used by timeFormat.
// The code is taken from infra/hostctl/rpc/helpers.go.
const (
	/*
		Inherited from the original ya-salt:
		 * milliseconds are needed to be able to "profile"
		 * no time zone is needed - all hosts are on MSK
	*/
	timeLayout = "2006-01-02 15:04:05.000"
)

// timeFormat writes t to enc using timeLayout. Encoders that provide
// AppendTimeLayout receive the time and the layout directly; all others
// receive the already formatted string.
func timeFormat(t time.Time, enc zapcore.PrimitiveArrayEncoder) {
	// The optional interface is borrowed from zapcore/encoder.go - there
	// seems to be no other way to reach it.
	type layoutAppender interface {
		AppendTimeLayout(time.Time, string)
	}

	switch e := enc.(type) {
	case layoutAppender:
		e.AppendTimeLayout(t, timeLayout)
	default:
		e.AppendString(t.Format(timeLayout))
	}
}

// doInitLog builds the process logger and installs it globally.
//
// fname is the path of the log file; it is made absolute and handed to a
// logrotate sink that is registered with SIGHUP. l is the minimum level for
// both outputs. Every entry is written twice: to the file (console encoding,
// capital level names, timeLayout timestamps, " - " separator) and to stdout
// (zap development console encoding).
//
// On success the logger is also installed as the zap global logger and passed
// to ilog.SetLogger. An error is returned if the path cannot be resolved or
// parsed, or if the sink cannot be created.
func doInitLog(fname string, l zapcore.Level) (*zap.Logger, error) {
	// Use the path supplied by the caller rather than the logFile global, so
	// the parameter is actually honoured.
	absPath, err := filepath.Abs(fname)
	if err != nil {
		return nil, fmt.Errorf("resolve log path %q: %w", fname, err)
	}
	u, err := url.ParseRequestURI(absPath)
	if err != nil {
		return nil, fmt.Errorf("parse log path %q: %w", absPath, err)
	}
	sink, err := logrotate.NewLogrotateSink(u, syscall.SIGHUP)
	if err != nil {
		return nil, fmt.Errorf("create logrotate sink for %q: %w", absPath, err)
	}

	encConf := zap.NewProductionEncoderConfig()
	encConf.EncodeLevel = zapcore.CapitalLevelEncoder
	encConf.EncodeTime = timeFormat
	encConf.ConsoleSeparator = " - "
	fileEnc := zapcore.NewConsoleEncoder(encConf)

	consoleEnc := zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig())
	// One atomic level is shared by both cores.
	al := zap.NewAtomicLevelAt(l)

	core := zapcore.NewTee(
		zapcore.NewCore(fileEnc, sink, al),
		zapcore.NewCore(consoleEnc, zapcore.Lock(os.Stdout), al))

	logger := zap.New(core)
	logger.Info("Init logger, log should be reopened by SIGHUP")
	// The restore function returned by ReplaceGlobals is intentionally
	// dropped: the globals are never switched back.
	_ = zap.ReplaceGlobals(logger)
	ilog.SetLogger(logger)
	return logger, nil
}

// NewNVGPUManager prepares the configuration c and constructs the server.
//
// The given options are applied to c via config.ApplyOptions, then
// config.TurnOffDisabledModeOptions is run on the result, and finally c is
// handed to server.NewNVGPUManager together with log. c is modified in place.
func NewNVGPUManager(log *zap.Logger, c *config.Configuration, options ...config.Option) (*server.NVGPUManager, error) {
	// Order matters: options first, then the post-processing pass over them.
	config.ApplyOptions(c, options...)
	config.TurnOffDisabledModeOptions(c)

	manager, err := server.NewNVGPUManager(log, c)
	return manager, err
}

// doMain is the body of the root command.
//
// It takes an exclusive, non-blocking file lock on lockPath (exiting with
// status 1 if the lock cannot be obtained), initializes logging, and then
// either handles one of the pseudo flags (--setup-layers / --clean-layers) or
// builds the NVGPUManager from defaults, the config file and command-line
// overrides and serves until an error occurs or a signal arrives.
//
// Returned errors are reported by the caller (cobra / Execute). Failures to
// create or run the server are logged at Fatal level, which terminates the
// process immediately.
//
// NOTE(review): the dcgm-socket-addr and dcgm-unix-socket flag values are not
// forwarded to the configuration here - confirm whether that is intended.
func doMain() error {
	fileLock := flock.New(lockPath)
	locked, err := fileLock.TryLock()
	if err != nil || !locked {
		if err == nil {
			// The lock is held by another process.
			err = os.NewSyscallError("flock", syscall.EWOULDBLOCK)
		}
		fmt.Printf("error on getting lock %s, err:%v\n", lockPath, err)
		os.Exit(1)
	}
	// Register the release only once the lock is actually held. An unlock
	// failure at shutdown is not actionable, so the error is deliberately
	// ignored.
	defer func() { _ = fileLock.Unlock() }()

	lvl := zap.InfoLevel
	if debug {
		lvl = zap.DebugLevel
	}
	ll, err := doInitLog(logFile, lvl)
	if err != nil {
		// Report through the normal error path instead of panicking, so the
		// deferred unlock runs and the caller prints the error.
		return fmt.Errorf("fail to init logger: %w", err)
	}
	// Best-effort flush of buffered log entries on the way out.
	defer func() { _ = ll.Sync() }()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Handle pseudo flags first
	if setupLayers {
		return doSetupLayers(ctx, ll)
	}
	if cleanLayers {
		return doCleanupLayers(ctx, ll)
	}

	ll.Info("Starting OS watcher")
	sigs := watcher.NewOSWatcher(syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)

	// Options are applied in order: defaults, then the config file, then
	// each command-line flag together with whether it was explicitly set.
	flags := rootCmd2.Flags()
	var serverOpts config.Configuration
	s, err := NewNVGPUManager(
		ll,
		&serverOpts,
		config.ApplyDefaultConfig(),
		config.ApplyConfFile(ll, cfgFile),
		config.CmdlineNvmlMode(enableNvml, flags.Changed("enable-nvml")),
		config.CmdlineNvidiaPersistenced(enablePersistenced, flags.Changed("enable-persistenced")),
		config.CmdlineNvidiaFabricmanager(enableFabricmanager, flags.Changed("enable-fabricmanager"), ll),
		config.CmdlineNvidiaDCGM(enableDCGM, flags.Changed("enable-dcgm"), ll),
		config.CmdlinePeriodicHungTest(enablePeriodicHungTest, flags.Changed("enable-periodic-hung-test"), ll),
		config.CmdlineNvidiaUvm(enableNvidiaUvm, flags.Changed("enable-nvidia-uvm"), ll),
		config.CmdlinePeerMem(enablePeerMem, flags.Changed("enable-peer-mem")),
		config.CmdlineVfioMode(enableVfio, flags.Changed("enable-vfio"), ll),
		config.CmdlineForceVfioInit(forceVfioInit, flags.Changed("force-vfio-init")),
		config.CmdlineIbMetrics(enableIbMetrics, flags.Changed("enable-ib-metrics")),
		config.CmdlineAllocIbDevs(allocIbDevs, flags.Changed("alloc-ib-devs")),
		config.CmdlineAllocIncludeIbUverbs(allocIncludeIbUverbs, flags.Changed("alloc-include-ib-uverbs")),
		config.CmdlineAllocIncludeRoceUverbs(allocIncludeRoceUverbs, flags.Changed("alloc-include-roce-uverbs")),
		config.CmdlineAllocNvgpuUnixSocket(allocNvgpuUnixSocket, flags.Changed("alloc-nvgpu-unix-socket")),
		config.CmdlineEnableSocketActivation(enableSocketActivation, flags.Changed("enable-socket-activation")),
		config.CmdlineEnableLimitsFixup(enableLimitsFixup, flags.Changed("enable-limits-fixup")),
		config.CmdlineMockMode(mockMode, flags.Changed("mock-mode")),
		config.CmdlineSecure(secure, flags.Changed("secure")),
		config.CmdlineService(serviceEndpoint, flags.Changed("service")),
	)
	if err != nil {
		// Fatal logs the entry and then exits the process with status 1;
		// nothing after this call runs.
		ll.Fatal("Fail to create NVGPUManager", zap.Error(err))
	}

	// Stop the server as soon as one of the watched signals is delivered.
	go func() {
		<-sigs
		s.Stop("stopping by signal")
	}()

	if err := s.Serve(ctx); err != nil {
		// Fatal terminates the process, see above.
		ll.Fatal("failed to serve", zap.Error(err))
	}
	return nil
}

// doSetupLayers rebuilds the libcuda layer from scratch.
//
// It first removes any previous layer via doCleanupLayers, then creates
// <cudaLayerBase>/lib64, copies cudaSrcPath to
// <cudaLayerBase>/lib64/libcuda.so, and finally creates the symlink
// cudaLayerLink pointing at cudaLayerBase. The first failing step aborts the
// sequence and its error is returned; nothing is rolled back.
func doSetupLayers(ctx context.Context, log *zap.Logger) error {
	log.Info("setupLayers")
	if err := doCleanupLayers(ctx, log); err != nil {
		return err
	}

	libDir := path.Join(cudaLayerBase, "lib64")
	if err := os.MkdirAll(libDir, 0755); err != nil {
		return fmt.Errorf("fail to create layer dir,  %w", err)
	}
	if err := utils.CopyFile(cudaSrcPath, path.Join(libDir, "libcuda.so")); err != nil {
		return fmt.Errorf("fail to copy libcuda, %w", err)
	}
	// Wrap the symlink error as well, so every failure names its step.
	if err := os.Symlink(cudaLayerBase, cudaLayerLink); err != nil {
		return fmt.Errorf("fail to create layer symlink, %w", err)
	}
	return nil
}

// doCleanupLayers removes the libcuda layer: the symlink cudaLayerLink and
// the directory tree cudaLayerBase. Paths that do not exist are not an error,
// so the function can be called repeatedly. Any other removal failure is
// returned.
func doCleanupLayers(ctx context.Context, log *zap.Logger) error {
	log.Info("cleanupLayers")

	// Remove the link unconditionally instead of probing with os.Stat first:
	// Stat follows symlinks, so a dangling link would be reported as missing,
	// left in place, and would make a later os.Symlink fail.
	if err := os.Remove(cudaLayerLink); err != nil && !os.IsNotExist(err) {
		return fmt.Errorf("fail to remove layer link, %w", err)
	}
	// RemoveAll already returns nil when the path does not exist.
	if err := os.RemoveAll(cudaLayerBase); err != nil {
		return fmt.Errorf("fail to remove layer dir, %w", err)
	}
	return nil
}
