package config

import (
	"errors"
	"fmt"
	"io"
	"os"
	"strings"

	"go.uber.org/zap"
	"gopkg.in/yaml.v2"
)

// Tristate is a three-valued setting. The only valid values are the lowercase
// strings held by the False, True and Optional constants below.
type Tristate string

// The complete set of valid Tristate values; see validateTristate.
const (
	False    Tristate = "false"
	True     Tristate = "true"
	Optional Tristate = "optional"
)

// validateTristate returns nil when t is exactly one of False, True or
// Optional (the comparison is case-sensitive) and a descriptive error
// otherwise.
func validateTristate(t Tristate) error {
	if t == False || t == True || t == Optional {
		return nil
	}
	return fmt.Errorf("invalid tristate value: %s", t)
}

// isTristateValid reports whether validateTristate accepts t.
func isTristateValid(t Tristate) bool {
	err := validateTristate(t)
	return err == nil
}

// UnmarshalYAML implements the yaml.v2 Unmarshaler interface. The YAML scalar
// is decoded as a string and lowercased before validation, so "True" or
// "OPTIONAL" are accepted. If decoding or validation fails, the error is
// returned and *t is left unchanged.
func (t *Tristate) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var raw string
	if err := unmarshal(&raw); err != nil {
		return err
	}

	candidate := Tristate(strings.ToLower(raw))
	if err := validateTristate(candidate); err != nil {
		return err
	}

	*t = candidate
	return nil
}

// Configuration is the fully resolved set of runtime settings. It is built by
// applying Option values with ApplyOptions; this file provides options for the
// package defaults (ApplyDefaultConfig), the YAML config file (ApplyConfFile)
// and command-line flags (the Cmdline* functions).
type Configuration struct {
	// NVML mode. When NvmlMode is false, TurnOffDisabledModeOptions forces
	// every setting from NvidiaPersistenced through EnableLimitsFixup off
	// (Tristate settings become False).
	NvmlMode               bool
	NvidiaPersistenced     bool
	NvidiaFabricmanager    Tristate
	NvidiaDCGM             Tristate
	PeriodicHungTest       Tristate
	NvidiaUvm              Tristate
	PeerMem                bool
	IbMetrics              bool
	AllocIbDevs            bool
	AllocIncludeIbUverbs   bool
	AllocIncludeRoceUverbs bool
	AllocNvgpuUnixSocket   bool
	EnableSocketActivation bool
	EnableLimitsFixup      bool
	// VFIO mode. When VfioMode is False, TurnOffDisabledModeOptions clears
	// ForceVfioInit.
	VfioMode               Tristate
	ForceVfioInit          bool
	// The settings below are not read from the config file (ConfFile has no
	// keys for them). Within this file they get package defaults and, except
	// for LogFile and CUDARoot, can be overridden by a Cmdline* option.
	MockDevices            string
	Secure                 bool
	GrpcEndpoint           string
	LogFile                string
	CUDARoot               string
	DCGMSocketAddr         string
	// NOTE(review): presumably tells whether DCGMSocketAddr is a unix socket
	// path rather than a network address — confirm at the use site.
	DCGMIsUnixSocket       bool
}

// NvmlModeConf mirrors the `nvml_mode` section of the YAML config file.
//
// Every field is a pointer so that ApplyConfFile can distinguish a key that is
// absent (nil: the value already in Configuration is kept) from one that is
// explicitly set. Each key overrides one Configuration field, e.g.
// nvml_enable -> NvmlMode, persistenced_enable -> NvidiaPersistenced,
// fabricmanager_enable -> NvidiaFabricmanager, dcgm_enable -> NvidiaDCGM;
// the alloc_* and enable_* keys map to the identically named fields.
// Tristate keys accept "true", "false" or "optional" in any letter case
// (Tristate.UnmarshalYAML lowercases the value before validating it).
type NvmlModeConf struct {
	NvmlEnable             *bool     `yaml:"nvml_enable"`
	PersistencedEnable     *bool     `yaml:"persistenced_enable"`
	FabricmanagerEnable    *Tristate `yaml:"fabricmanager_enable"`
	DCGMEnable             *Tristate `yaml:"dcgm_enable"`
	PeriodicHungTestEnable *Tristate `yaml:"periodic_hung_test_enable"`
	NvidiaUvmEnable        *Tristate `yaml:"nvidia_uvm_enable"`
	PeerMemEnable          *bool     `yaml:"peer_mem_enable"`
	IbMetricsEnable        *bool     `yaml:"ib_metrics_enable"`
	AllocIbDevs            *bool     `yaml:"alloc_ib_devs"`
	AllocIncludeIbUverbs   *bool     `yaml:"alloc_include_ib_uverbs"`
	AllocIncludeRoceUverbs *bool     `yaml:"alloc_include_roce_uverbs"`
	AllocNvgpuUnixSocket   *bool     `yaml:"alloc_nvgpu_unix_socket"`
	EnableSocketActivation *bool     `yaml:"enable_socket_activation"`
	EnableLimitsFixup      *bool     `yaml:"enable_limits_fixup"`
}

// VfioModeConf mirrors the `vfio_mode` section of the YAML config file.
// vfio_enable overrides Configuration.VfioMode and force_vfio_init overrides
// Configuration.ForceVfioInit; a nil field means the key was absent and the
// existing value is kept.
type VfioModeConf struct {
	VfioEnable    *Tristate `yaml:"vfio_enable"`
	ForceVfioInit *bool     `yaml:"force_vfio_init"`
}

// ConfFile is the top-level layout of the YAML config file read by
// parseConfFile. That decoder runs in strict mode (yaml.v2 SetStrict), so keys
// that do not match a field here are reported as errors.
type ConfFile struct {
	NvmlMode NvmlModeConf `yaml:"nvml_mode"`
	VfioMode VfioModeConf `yaml:"vfio_mode"`
}

// Option mutates a Configuration in place. ApplyOptions runs options in the
// order they are passed, so a later option overwrites fields set by an earlier
// one.
type Option func(c *Configuration)

// ApplyDefaultConfig returns an Option that replaces the whole Configuration
// with the package-level Default* values. Every field of Configuration has a
// default, so nothing from the previous contents survives.
func ApplyDefaultConfig() Option {
	return func(c *Configuration) {
		*c = Configuration{
			NvmlMode:               DefaultNvmlMode,
			NvidiaPersistenced:     DefaultNvidiaPersistenced,
			NvidiaFabricmanager:    DefaultNvidiaFabricmanager,
			NvidiaDCGM:             DefaultNvidiaDCGM,
			PeriodicHungTest:       DefaultPeriodicHungTest,
			NvidiaUvm:              DefaultNvidiaUvm,
			PeerMem:                DefaultPeerMem,
			IbMetrics:              DefaultIbMetrics,
			AllocIbDevs:            DefaultAllocIbDevs,
			AllocIncludeIbUverbs:   DefaultAllocIncludeIbUverbs,
			AllocIncludeRoceUverbs: DefaultAllocIncludeRoceUverbs,
			AllocNvgpuUnixSocket:   DefaultAllocNvgpuUnixSocket,
			EnableSocketActivation: DefaultEnableSocketActivation,
			EnableLimitsFixup:      DefaultEnableLimitsFixup,
			VfioMode:               DefaultVfioMode,
			ForceVfioInit:          DefaultForceVfioInit,
			MockDevices:            DefaultMockMode,
			Secure:                 DefaultSecureMode,
			GrpcEndpoint:           DefaultServerAddress,
			LogFile:                DefaultServerLog,
			CUDARoot:               DefaultCUDALayer,
			DCGMSocketAddr:         DefaultDCGMSocketAddr,
			DCGMIsUnixSocket:       DefaultDCGMIsUnixSocket,
		}
	}
}

// parseConfFile reads and decodes the YAML config file at path.
//
// Decoding is strict (dec.SetStrict(true)): unknown or duplicated keys are
// rejected. A file that contains no YAML document at all (empty, or only
// comments) is not an error. It yields a ConfFile whose fields are all nil,
// so applying it overrides nothing.
//
// Errors are logged here as well as returned, because ApplyConfFile, the
// caller in this file, only uses the error to skip the file.
func parseConfFile(log *zap.Logger, path string) (*ConfFile, error) {
	log.Info("parseConfFile", zap.String("path", path))

	f, err := os.Open(path)
	if err != nil {
		err = fmt.Errorf("failed to open config file at: %s, err: %w", path, err)
		log.Error("parseConfFile()", zap.Error(err))
		return nil, err
	}
	defer f.Close()

	var parsedConfig ConfFile
	dec := yaml.NewDecoder(f)
	dec.SetStrict(true)
	if err = dec.Decode(&parsedConfig); err != nil {
		// yaml.v2 reports input without any document as io.EOF. That is an
		// empty configuration, not a malformed one.
		if errors.Is(err, io.EOF) {
			log.Info("config file is empty, no overrides applied", zap.String("path", path))
			return &parsedConfig, nil
		}
		err = fmt.Errorf("failed to parse config: %w", err)
		log.Error("parseConfFile()", zap.Error(err))
		return nil, err
	}

	return &parsedConfig, nil
}

// ApplyConfFile parses the YAML config file at path immediately, when the
// Option is constructed, and returns an Option that copies every key present
// in the file onto the Configuration. Keys absent from the file (nil pointers)
// leave the corresponding field untouched. If the file cannot be opened or
// parsed, the error has already been logged by parseConfFile and the returned
// Option does nothing.
func ApplyConfFile(log *zap.Logger, path string) Option {
	parsedConfig, err := parseConfFile(log, path)
	if err != nil {
		return func(*Configuration) {}
	}

	overrideBool := func(dst, src *bool) {
		if src != nil {
			*dst = *src
		}
	}
	overrideTristate := func(dst, src *Tristate) {
		if src != nil {
			*dst = *src
		}
	}

	return func(c *Configuration) {
		nvml := parsedConfig.NvmlMode
		vfio := parsedConfig.VfioMode

		overrideBool(&c.NvmlMode, nvml.NvmlEnable)
		overrideBool(&c.NvidiaPersistenced, nvml.PersistencedEnable)
		overrideTristate(&c.NvidiaFabricmanager, nvml.FabricmanagerEnable)
		overrideTristate(&c.NvidiaDCGM, nvml.DCGMEnable)
		overrideTristate(&c.PeriodicHungTest, nvml.PeriodicHungTestEnable)
		overrideTristate(&c.NvidiaUvm, nvml.NvidiaUvmEnable)
		overrideBool(&c.PeerMem, nvml.PeerMemEnable)
		overrideBool(&c.IbMetrics, nvml.IbMetricsEnable)
		overrideBool(&c.AllocIbDevs, nvml.AllocIbDevs)
		overrideBool(&c.AllocIncludeIbUverbs, nvml.AllocIncludeIbUverbs)
		overrideBool(&c.AllocIncludeRoceUverbs, nvml.AllocIncludeRoceUverbs)
		overrideBool(&c.AllocNvgpuUnixSocket, nvml.AllocNvgpuUnixSocket)
		overrideBool(&c.EnableSocketActivation, nvml.EnableSocketActivation)
		overrideBool(&c.EnableLimitsFixup, nvml.EnableLimitsFixup)

		overrideTristate(&c.VfioMode, vfio.VfioEnable)
		overrideBool(&c.ForceVfioInit, vfio.ForceVfioInit)
	}
}

// CmdlineNvmlMode returns an Option that sets Configuration.NvmlMode to
// nvmlMode if the flag was given on the command line; otherwise the returned
// Option is a no-op.
func CmdlineNvmlMode(nvmlMode bool, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.NvmlMode = nvmlMode
	}
}

// CmdlineNvidiaPersistenced returns an Option that sets
// Configuration.NvidiaPersistenced if the flag was given on the command line;
// otherwise the returned Option is a no-op.
func CmdlineNvidiaPersistenced(nvidiaPersistenced bool, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.NvidiaPersistenced = nvidiaPersistenced
	}
}

// CmdlineNvidiaFabricmanager returns an Option that overrides
// Configuration.NvidiaFabricmanager with the command-line value.
//
// The value is matched case-insensitively ("True", "OPTIONAL", ...), the same
// way Tristate keys in the YAML config file are. An invalid value is logged
// and ignored, leaving the previously configured value in place. When the flag
// was not set on the command line, nothing is validated or logged and the
// returned Option is a no-op.
func CmdlineNvidiaFabricmanager(nvidiaFabricmanager string, setOnCmdline bool, log *zap.Logger) Option {
	noop := func(*Configuration) {}
	if !setOnCmdline {
		return noop
	}

	t := Tristate(strings.ToLower(nvidiaFabricmanager))
	if err := validateTristate(t); err != nil {
		log.Error("invalid fabricmanager_enable value passed on cmdline", zap.Error(err))
		return noop
	}

	return func(c *Configuration) {
		c.NvidiaFabricmanager = t
	}
}

// CmdlineNvidiaDCGM returns an Option that overrides Configuration.NvidiaDCGM
// with the command-line value.
//
// The value is matched case-insensitively, the same way Tristate keys in the
// YAML config file are. An invalid value is logged and ignored, leaving the
// previously configured value in place. When the flag was not set on the
// command line, nothing is validated or logged and the returned Option is a
// no-op.
func CmdlineNvidiaDCGM(nvidiaDCGM string, setOnCmdline bool, log *zap.Logger) Option {
	noop := func(*Configuration) {}
	if !setOnCmdline {
		return noop
	}

	t := Tristate(strings.ToLower(nvidiaDCGM))
	if err := validateTristate(t); err != nil {
		log.Error("invalid dcgm_enable value passed on cmdline", zap.Error(err))
		return noop
	}

	return func(c *Configuration) {
		c.NvidiaDCGM = t
	}
}

// CmdlinePeriodicHungTest returns an Option that overrides
// Configuration.PeriodicHungTest with the command-line value.
//
// The value is matched case-insensitively, the same way Tristate keys in the
// YAML config file are. An invalid value is logged and ignored, leaving the
// previously configured value in place. When the flag was not set on the
// command line, nothing is validated or logged and the returned Option is a
// no-op.
func CmdlinePeriodicHungTest(periodicHungTest string, setOnCmdline bool, log *zap.Logger) Option {
	noop := func(*Configuration) {}
	if !setOnCmdline {
		return noop
	}

	t := Tristate(strings.ToLower(periodicHungTest))
	if err := validateTristate(t); err != nil {
		log.Error("invalid periodic_hung_test_enable value passed on cmdline", zap.Error(err))
		return noop
	}

	return func(c *Configuration) {
		c.PeriodicHungTest = t
	}
}

// CmdlineNvidiaUvm returns an Option that overrides Configuration.NvidiaUvm
// with the command-line value.
//
// The value is matched case-insensitively, the same way Tristate keys in the
// YAML config file are. An invalid value is logged and ignored, leaving the
// previously configured value in place. When the flag was not set on the
// command line, nothing is validated or logged and the returned Option is a
// no-op.
func CmdlineNvidiaUvm(nvidiaUvm string, setOnCmdline bool, log *zap.Logger) Option {
	noop := func(*Configuration) {}
	if !setOnCmdline {
		return noop
	}

	t := Tristate(strings.ToLower(nvidiaUvm))
	if err := validateTristate(t); err != nil {
		log.Error("invalid nvidia_uvm_enable value passed on cmdline", zap.Error(err))
		return noop
	}

	return func(c *Configuration) {
		c.NvidiaUvm = t
	}
}

// CmdlinePeerMem returns an Option that sets Configuration.PeerMem if the flag
// was given on the command line; otherwise the returned Option is a no-op.
func CmdlinePeerMem(peerMem bool, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.PeerMem = peerMem
	}
}

// CmdlineIbMetrics returns an Option that sets Configuration.IbMetrics if the
// flag was given on the command line; otherwise the returned Option is a
// no-op.
func CmdlineIbMetrics(ibMetrics bool, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.IbMetrics = ibMetrics
	}
}

// CmdlineAllocIbDevs returns an Option that sets Configuration.AllocIbDevs if
// the flag was given on the command line; otherwise the returned Option is a
// no-op.
func CmdlineAllocIbDevs(allocIbDevs bool, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.AllocIbDevs = allocIbDevs
	}
}

// CmdlineAllocIncludeIbUverbs returns an Option that sets
// Configuration.AllocIncludeIbUverbs if the flag was given on the command
// line; otherwise the returned Option is a no-op.
func CmdlineAllocIncludeIbUverbs(allocIncludeIbUverbs bool, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.AllocIncludeIbUverbs = allocIncludeIbUverbs
	}
}

// CmdlineAllocIncludeRoceUverbs returns an Option that sets
// Configuration.AllocIncludeRoceUverbs if the flag was given on the command
// line; otherwise the returned Option is a no-op.
func CmdlineAllocIncludeRoceUverbs(allocIncludeRoceUverbs bool, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.AllocIncludeRoceUverbs = allocIncludeRoceUverbs
	}
}

// CmdlineAllocNvgpuUnixSocket returns an Option that sets
// Configuration.AllocNvgpuUnixSocket if the flag was given on the command
// line; otherwise the returned Option is a no-op.
func CmdlineAllocNvgpuUnixSocket(allocNvgpuUnixSocket bool, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.AllocNvgpuUnixSocket = allocNvgpuUnixSocket
	}
}

// CmdlineEnableSocketActivation returns an Option that sets
// Configuration.EnableSocketActivation if the flag was given on the command
// line; otherwise the returned Option is a no-op.
func CmdlineEnableSocketActivation(enableSocketActivation bool, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.EnableSocketActivation = enableSocketActivation
	}
}

// CmdlineEnableLimitsFixup returns an Option that sets
// Configuration.EnableLimitsFixup if the flag was given on the command line;
// otherwise the returned Option is a no-op.
func CmdlineEnableLimitsFixup(enableLimitsFixup bool, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.EnableLimitsFixup = enableLimitsFixup
	}
}

// CmdlineVfioMode returns an Option that overrides Configuration.VfioMode with
// the command-line value.
//
// The value is matched case-insensitively, the same way Tristate keys in the
// YAML config file are. An invalid value is logged and ignored, leaving the
// previously configured value in place. When the flag was not set on the
// command line, nothing is validated or logged and the returned Option is a
// no-op.
func CmdlineVfioMode(vfioMode string, setOnCmdline bool, log *zap.Logger) Option {
	noop := func(*Configuration) {}
	if !setOnCmdline {
		return noop
	}

	t := Tristate(strings.ToLower(vfioMode))
	if err := validateTristate(t); err != nil {
		log.Error("invalid enable-vfio value passed on cmdline", zap.Error(err))
		return noop
	}

	return func(c *Configuration) {
		c.VfioMode = t
	}
}

// CmdlineForceVfioInit returns an Option that sets
// Configuration.ForceVfioInit if the flag was given on the command line;
// otherwise the returned Option is a no-op.
func CmdlineForceVfioInit(forceVfioInit bool, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.ForceVfioInit = forceVfioInit
	}
}

// CmdlineMockMode returns an Option that sets Configuration.MockDevices to
// mockMode if the flag was given on the command line; otherwise the returned
// Option is a no-op.
func CmdlineMockMode(mockMode string, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.MockDevices = mockMode
	}
}

// CmdlineSecure returns an Option that sets Configuration.Secure if the flag
// was given on the command line; otherwise the returned Option is a no-op.
func CmdlineSecure(secure bool, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.Secure = secure
	}
}

// CmdlineService returns an Option that sets Configuration.GrpcEndpoint to
// serviceEndpoint if the flag was given on the command line; otherwise the
// returned Option is a no-op.
func CmdlineService(serviceEndpoint string, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.GrpcEndpoint = serviceEndpoint
	}
}

// CmdlineDCGMSocketAddr returns an Option that sets
// Configuration.DCGMSocketAddr to socketAddr if the flag was given on the
// command line; otherwise the returned Option is a no-op.
func CmdlineDCGMSocketAddr(socketAddr string, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.DCGMSocketAddr = socketAddr
	}
}

// CmdlineDCGMIsUnixSocket returns an Option that sets
// Configuration.DCGMIsUnixSocket to isUnixSocket if the flag was given on the
// command line; otherwise the returned Option is a no-op.
func CmdlineDCGMIsUnixSocket(isUnixSocket bool, setOnCmdline bool) Option {
	if !setOnCmdline {
		return func(*Configuration) {}
	}
	return func(c *Configuration) {
		c.DCGMIsUnixSocket = isUnixSocket
	}
}

// ApplyOptions runs each option against c in the order given, so an option
// later in the list overwrites fields set by earlier ones.
func ApplyOptions(c *Configuration, options ...Option) {
	for i := range options {
		apply := options[i]
		apply(c)
	}
	/* TODO fixup default values here */
}

// TurnOffDisabledModeOptions clears settings that belong to a disabled mode.
// If NVML mode is off, every NVML-dependent setting is switched off
// (Tristate settings become False). If VfioMode is False, ForceVfioInit is
// cleared.
func TurnOffDisabledModeOptions(c *Configuration) {
	if !c.NvmlMode {
		nvmlTristates := []*Tristate{
			&c.NvidiaFabricmanager,
			&c.NvidiaDCGM,
			&c.PeriodicHungTest,
			&c.NvidiaUvm,
		}
		for _, setting := range nvmlTristates {
			*setting = False
		}

		nvmlFlags := []*bool{
			&c.NvidiaPersistenced,
			&c.PeerMem,
			// TODO: turn off metrics too?
			&c.IbMetrics,
			&c.AllocIbDevs,
			&c.AllocIncludeIbUverbs,
			&c.AllocIncludeRoceUverbs,
			&c.AllocNvgpuUnixSocket,
			&c.EnableSocketActivation,
			&c.EnableLimitsFixup,
		}
		for _, flag := range nvmlFlags {
			*flag = false
		}
	}

	if c.VfioMode == False {
		c.ForceVfioInit = false
	}
}
