package device

import (
	"fmt"
	"path"
	"regexp"
	"strings"
	"sync"
	"time"

	"a.yandex-team.ru/infra/rsm/nvgpumanager/vendor/github.com/NVIDIA/go-nvml/pkg/nvml"
	"go.uber.org/zap"

	pb "a.yandex-team.ru/infra/rsm/nvgpumanager/api"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/config"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/ilog"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/utils"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/pkg/yasm"
)

var (
	// nvmlLib is the package-wide NvmlLib value.
	// NOTE(review): presumably shared by the NVML wrappers elsewhere in the
	// package; nothing in this file's visible code reads it — confirm.
	nvmlLib = NvmlLib{}

	// fixupGpuParamsOnInit makes fixupGpuParams run its unconditional
	// frequency/power setup only on the first call.
	fixupGpuParamsOnInit sync.Once

	// NVToYpPrefixMap maps the model part of an NVIDIA MIG device name (the
	// text before " MIG ") to the prefix processNvDeviceModelName uses when it
	// builds the reported model name.
	NVToYpPrefixMap = map[string]string{
		"A100-SXM-80GB": "gpu_tesla_a100",
	}
)

// NvmlDevice bundles everything known about one NVIDIA GPU (or MIG device):
// its NVML handle, the matching PCI device entry, the latest status snapshots
// and the conditions derived from them.
type NvmlDevice struct {
	// Ready is set by NewNvmlDevices: false ("Unknown device") when the PCI
	// entry has MemoryGb == 0, true otherwise. TransferStatus may overwrite
	// it with the condition of the previous generation.
	Ready         pb.Condition
	// Index is the index passed in by NewNvmlDevices: the NVML enumeration
	// index for parent devices, a running counter for reported devices.
	Index         uint
	// PciDev is the PCI cache entry found by the device's bus id.
	PciDev        *PciDevice
	// Device is the NVML API handle; it may be a MIG device (IsMigDevice).
	Device        NvmlAPIDevice
	// Status is the snapshot returned by api.DeviceStatus at construction.
	Status        *NvmlAPIDeviceStatus
	// DcgmStatus holds DCGM values; nil when no DCGM API was supplied, in
	// which case ProtoMarshal reports -1 for the SM metrics.
	DcgmStatus    *DcgmValues
	// DriverVersion is the driver version reported by the API at construction.
	DriverVersion string
	// CudaMajor and CudaMinor are the CUDA driver version components.
	CudaMajor     uint32
	CudaMinor     uint32
	// TotalEcc tracks the value of totalEcc(Status) across generations; its
	// DiffPS() is exported by YasmMetrics.
	TotalEcc      yasm.Int64TSS
	// Stuck is set to true by NewNvmlDevices when utilization, power and
	// temperature match the config.StuckGPU* thresholds.
	Stuck         pb.Condition
}

// migModelNameRe matches NVIDIA MIG device names of the form
// "<model> MIG <N>g.<M>gb" and captures <model>, <N> and <M>. It is compiled
// once at package initialisation instead of on every call.
var migModelNameRe = regexp.MustCompile(`(.+) MIG ([0-9]+)g\.([0-9]+)gb`)

// processNvDeviceModelName converts a MIG device name reported by the driver
// (e.g. "A100-SXM-80GB MIG 1g.10gb") into "<prefix>_mig_<N>c<M>g", where
// <prefix> is looked up in NVToYpPrefixMap by the model part of the name.
//
// It returns an empty string when name is not a MIG device name or when the
// model has no entry in NVToYpPrefixMap. The error result is kept for
// interface compatibility and is always nil.
func processNvDeviceModelName(name string) (string, error) {
	match := migModelNameRe.FindStringSubmatch(name)
	if len(match) == 0 {
		return "", nil
	}

	prefix, ok := NVToYpPrefixMap[match[1]]
	if !ok {
		return "", nil
	}

	return fmt.Sprintf("%s_mig_%sc%sg", prefix, match[2], match[3]), nil
}

// ProtoMarshal builds the protobuf GpuDevice message for the device from its
// NVML handle, PCI entry and the stored status snapshots.
//
// For a MIG device the meta id is the NVML UUID and the bus id is suffixed
// with the device minor; otherwise both come from the PCI entry. SM
// utilization and occupancy are reported as -1 when no DCGM values exist.
func (d *NvmlDevice) ProtoMarshal() *pb.GpuDevice {
	// Processes from the status snapshot.
	var procs []*pb.ProcessInfo
	for _, proc := range d.Status.Processes {
		info := &pb.ProcessInfo{Pid: uint32(proc.Pid), MemoryUsedMb: proc.UsedGpuMemory}
		procs = append(procs, info)
	}

	// The throttle condition stays at its zero value unless throttled.
	var throttle pb.Condition
	if d.Status.isThrottled() {
		utils.SetCondition(&throttle, true, d.Status.Throttle.String())
	}

	// Identity of the device in the message meta.
	id, busID := d.Device.GetUUID(), d.PciDev.BusID
	if !d.Device.IsMigDevice() {
		id = d.PciDev.UUID
	}
	if d.Device.IsMigDevice() {
		busID = fmt.Sprintf("%s.%d", d.PciDev.BusID, d.Device.GetMinor())
	}

	// DCGM values are scaled by 100; -1 is reported when they are absent.
	smUtilization, smOccupancy := float32(-1), float32(-1)
	if dcgm := d.DcgmStatus; dcgm != nil {
		smUtilization = float32(dcgm.SmUtilization * 100)
		smOccupancy = float32(dcgm.SmOccupancy * 100)
	}

	spec := &pb.GpuDeviceSpec{PciDevice: d.PciDev.ProtoMarshal()}
	spec.Driver = &pb.GpuDeviceSpec_Nvidia{
		Nvidia: &pb.NvGpuSpec{
			Uuid:          d.Device.GetUUID(),
			DevicePath:    d.Device.GetUniquePath(),
			Model:         d.Device.GetModel(),
			Power:         d.Device.GetPower(),
			MemorySizeMb:  d.Device.GetMemory(),
			NumaNode:      int32(d.Device.GetCPUAffinity()),
			DriverVersion: d.DriverVersion,
			CudaVersion:   &pb.VersionInfo{Major: d.CudaMajor, Minor: d.CudaMinor},
		},
	}

	nvStatus := &pb.NvGpuStatus{
		Power:             d.Status.Power,
		Temperature:       d.Status.Temperature,
		MemoryUsedMb:      d.Status.Memory.Global.Used,
		MemoryFreeMb:      d.Status.Memory.Global.Free,
		Processes:         procs,
		Throttle:          &throttle,
		PciThroughputRxMb: d.Status.PCI.Throughput.RX,
		PciThroughputTxMb: d.Status.PCI.Throughput.TX,
		GpuUtilization:    d.Status.Utilization.GPU,
		MemoryUtilization: d.Status.Utilization.Memory,
		SmUtilization:     smUtilization,
		SmOccupancy:       smOccupancy,
	}
	status := &pb.GpuDeviceStatus{
		Ready:  &d.Ready,
		Driver: &pb.GpuDeviceStatus_Nvidia{Nvidia: nvStatus},
	}

	msg := &pb.GpuDevice{
		Meta:   &pb.PciDeviceMeta{Id: id, BusId: busID},
		Spec:   spec,
		Status: status,
	}

	// Override the PCI model name with the one derived from the driver's
	// device name, if such a name could be derived.
	migName, err := processNvDeviceModelName(d.Device.GetModel())
	if err != nil {
		ilog.Log().Error("Error on nvDevice model name processing: " + err.Error())
	}
	if migName != "" {
		msg.Spec.PciDevice.ModelName = migName
	}

	// The memory size always comes from the driver value, shifted right by 10.
	msg.Spec.PciDevice.MemorySizeGb = uint32(d.Device.GetMemory()) >> 10

	return msg
}
// getValue64 dereferences ptr, returning 0 when ptr is nil.
func getValue64(ptr *uint64) uint64 {
	if ptr == nil {
		return 0
	}
	return *ptr
}

// getValue dereferences ptr, returning 0 when ptr is nil.
func getValue(ptr *uint) uint {
	if ptr == nil {
		return 0
	}
	return *ptr
}

// YasmMetrics renders the device state as a YASM metrics record: a fixed tag
// set identifying the GPU, a TTL of 30 and one value per exported signal.
func (d *NvmlDevice) YasmMetrics() yasm.YasmMetrics {
	// The last element of the device path serves as the gpu_path tag.
	gpuPath := path.Base(d.Device.GetUniquePath()) // TODO: Check

	tags := map[string]string{
		"itype":      "runtimecloud",
		"gpu_model":  d.PciDev.ModelName,
		"gpu_path":   gpuPath,
		"gpu_driver": "nvidia",
	}

	values := []yasm.YasmValue{
		{Name: "gpustat-device_model_count_tmmv", Value: 1},
		{Name: "gpustat-device_ready_tmmv", Value: d.Ready.Status},
		{Name: "gpustat-utilization_gpu_tvvv", Value: d.Status.Utilization.GPU},
		{Name: "gpustat-utilization_memory_tvvv", Value: d.Status.Utilization.Memory},
		{Name: "gpustat-utilization_encoder_tmmv", Value: d.Status.Utilization.Encoder},
		{Name: "gpustat-utilization_decoder_tmmv", Value: d.Status.Utilization.Decoder},
		{Name: "gpustat-gpu_temperature_tvvv", Value: d.Status.Temperature},
		{Name: "gpustat-gpu_cpu_clock_tvvv", Value: d.Status.Clocks.Cores},
		{Name: "gpustat-gpu_memory_clock_tvvv", Value: d.Status.Clocks.Memory},
		{Name: "gpustat-gpu_pci_rx_tmmv", Value: d.Status.PCI.Throughput.RX},
		{Name: "gpustat-gpu_pci_tx_tmmv", Value: d.Status.PCI.Throughput.TX},
		{Name: "gpustat-gpu_power_limit_tmmv", Value: d.Device.GetPower()},
		{Name: "gpustat-gpu_power_usage_tmmv", Value: d.Status.Power},
		{Name: "gpustat-ecc_total_tmmv", Value: d.TotalEcc.DiffPS()},
		{Name: "gpustat-throttled_tmmv", Value: d.Status.isThrottled()},
		{Name: "gpustat-gpu_stuck_count_tmmv", Value: d.isStuck()},
	}

	return yasm.YasmMetrics{
		Tags:   tags,
		TTL:    30,
		Values: values,
	}
}

// isThrottled reports whether the throttle reason in the status is anything
// other than "none" or "GPU idle".
func (s *NvmlAPIDeviceStatus) isThrottled() bool {
	noneOrIdle := s.Throttle == nvml.ClocksThrottleReasonNone ||
		s.Throttle == nvml.ClocksThrottleReasonGpuIdle
	return !noneOrIdle
}

// isStuck reports whether the Stuck condition is set and its last transition
// happened at least config.StuckGPUTime minutes ago.
func (d *NvmlDevice) isStuck() bool {
	stuckFor := time.Since(d.Stuck.LastTransitionTime.AsTime())
	return d.Stuck.Status && stuckFor.Minutes() >= config.StuckGPUTime
}

// totalEcc sums the L1 cache, L2 cache and device memory ECC error counters
// of the status and returns the sum as int64.
func totalEcc(s *NvmlAPIDeviceStatus) int64 {
	ecc := s.Memory.ECCErrors
	return int64(ecc.L1Cache + ecc.L2Cache + ecc.DeviceMemory)
}

// TransferStatus carries state over from orig, the object of the previous
// generation, into the freshly built d.
func (d *NvmlDevice) TransferStatus(orig *NvmlDevice) {
	// The Stuck condition is taken from orig only when the device was stuck
	// before and is stuck now.
	stillStuck := d.Stuck.Status && orig.Stuck.Status
	if stillStuck {
		utils.CopyCondition(&d.Stuck, &orig.Stuck)
	}

	// When the new state is Ready, the old Ready condition is inherited.
	if d.Ready.Status {
		utils.CopyCondition(&d.Ready, &orig.Ready)
	}

	// Continue the previous ECC tracker and feed it the current total.
	d.TotalEcc = orig.TotalEcc
	d.TotalEcc.Update(totalEcc(d.Status))
}

// NewNvmlDevices enumerates the GPUs known to api and builds an NvmlDevice for
// every device that should be reported.
//
// Parameters:
//   - api: NVML access layer used to count, open and query devices.
//   - dcgmAPI: optional DCGM access layer; when nil, devices get no DcgmStatus.
//   - pcache: PCI devices keyed by lower-cased bus id.
//   - conf: configuration; EnableLimitsFixup turns on fixupGpuParams.
//
// Returns the successfully built devices, the number of devices that had to
// be skipped, and an error. The error is non-nil only when the device count
// cannot be read or the DCGM values cannot be updated; per-device failures
// are logged and counted in badDevsCnt instead.
//
// For every NVML index a "parent" device is built first (it is used only for
// the limits fixup). If MIG is enabled on it, each MIG device is reported;
// otherwise the device itself is. Reported devices are numbered consecutively.
func NewNvmlDevices(api NvmlInterface, dcgmAPI DcgmInterface, pcache map[string]*PciDevice, conf *config.Configuration) (goodDevs []*NvmlDevice, badDevsCnt int, err error) {
	ll := ilog.Log()
	cnt, err := api.GetDeviceCount()
	if err != nil {
		return nil, 0, err
	}
	// Version lookups are best effort: a failure is logged and ignored.
	driverVersion, err := api.GetDriverVersion()
	if err != nil {
		ll.Error("fail to fetch driver version", zap.Error(err))
	}
	cudaMajor, cudaMinor, err := api.GetCudaDriverVersion()
	if err != nil {
		ll.Error("fail to fetch cuda driver version", zap.Error(err))
	}

	var parentDevices []*NvmlDevice

	// currentIndex numbers the reported devices; it advances only on success.
	currentIndex := 0

	// newNvmlDevice builds one NvmlDevice: status snapshot, optional DCGM
	// values, PCI lookup, Ready condition and the ECC tracker.
	newNvmlDevice := func(index int, device NvmlAPIDevice) (*NvmlDevice, error) {
		st, err := api.DeviceStatus(device)
		ll.Debug("nvml.Device.Status()", zap.Any("status", st), zap.Error(err))
		if err != nil {
			err = fmt.Errorf("skipping device with uuid=%s, path=%s: %w", device.GetUUID(), device.GetUniquePath(), err)
			ll.Error("nvml.DeviceStatus() failed", zap.Error(err))
			return nil, err
		}

		var dcgmSt *DcgmValues

		if dcgmAPI != nil {
			dcgmSt, err = dcgmAPI.GetDeviceValues(device)
			if err != nil {
				err = fmt.Errorf("skipping device with uuid=%s, path=%s: %w", device.GetUUID(), device.GetUniquePath(), err)
				ll.Error("dcgm.DeviceStatus() failed", zap.Error(err))
				return nil, err
			}
		}

		// The cache key is the NVML bus id without its first four characters,
		// lower-cased. Guard the slice so a short or empty bus id skips the
		// device instead of panicking.
		nvBusID := device.GetPCIBusID()
		if len(nvBusID) < 4 {
			err = fmt.Errorf("unexpected pci bus id %q for nvml.uuid=%s: %w", nvBusID, device.GetUUID(), ErrNoent)
			ll.Error("pci lookup failed", zap.Error(err))
			return nil, err
		}

		id := strings.ToLower(nvBusID[4:])
		pdev, ok := pcache[id]
		if !ok {
			err = fmt.Errorf("not such pci bus=%s for nvml.uuid=%s: %w", id, device.GetUUID(), ErrNoent)
			ll.Error("pci lookup failed", zap.Error(err))
			return nil, err
		}

		nvd := NvmlDevice{
			Index:         uint(index),
			PciDev:        pdev,
			Device:        device,
			Status:        st,
			DcgmStatus:    dcgmSt,
			DriverVersion: driverVersion,
			CudaMajor:     uint32(cudaMajor),
			CudaMinor:     uint32(cudaMinor),
		}
		// A PCI entry without a memory size marks the device as not ready.
		if nvd.PciDev.MemoryGb == 0 {
			utils.SetCondition(&nvd.Ready, false, "Unknown device")
		} else {
			utils.SetCondition(&nvd.Ready, true, "")
		}
		nvd.TotalEcc.Init(totalEcc(st))

		return &nvd, nil
	}

	// DCGM values are refreshed once, before any device is queried.
	if dcgmAPI != nil {
		err := dcgmAPI.UpdateValues()
		if err != nil {
			err = fmt.Errorf("can't update dcgm values: %w", err)
			ll.Error("dcgm.UpdateValues() failed", zap.Error(err))
			return nil, 0, err
		}
	}

	for i := 0; i < cnt; i++ {
		d, err := api.NewDevice(i)
		ll.Debug("nvml.NewDevice()", zap.Int("idx", i), zap.Any("dev", d), zap.Error(err))
		if err != nil {
			badDevsCnt++
			err = fmt.Errorf("skipping device with idx=%d: %w", i, err)
			ll.Error("nvml.NewDevice() failed", zap.Error(err))
			continue
		}

		// The parent device keeps the NVML index and is used only for fixups.
		pdev, err := newNvmlDevice(i, d)
		if err != nil {
			badDevsCnt++
			continue
		}
		parentDevices = append(parentDevices, pdev)

		// With MIG enabled the MIG devices are reported instead of the parent.
		devicesToProcess := []NvmlAPIDevice{}
		if d.IsMigEnabled() {
			for _, dev := range d.GetMigDevices() {
				devicesToProcess = append(devicesToProcess, dev)
			}
		} else {
			devicesToProcess = append(devicesToProcess, d)
		}

		for _, device := range devicesToProcess {
			nvd, err := newNvmlDevice(currentIndex, device)
			if err != nil {
				badDevsCnt++
				continue
			}
			currentIndex++
			goodDevs = append(goodDevs, nvd)
		}
	}

	ll.Debug("newNvmlDevices return", zap.Any("goodDevs", goodDevs))

	// RESMAN-66: limit Graphics freq on a100_80g
	// RESMAN-70: limit GPU power on a100_80g (set freq to default 1410MHz)
	if conf.EnableLimitsFixup {
		fixupGpuParams(parentDevices)
	}

	// Mark devices whose utilization, power and temperature all match the
	// configured "stuck" thresholds.
	for _, d := range goodDevs {
		isStuck := d.Status.Utilization.GPU == config.StuckGPUUtil &&
			d.Status.Power < config.StuckGPUPower &&
			d.Status.Temperature < config.StuckGPUTemp

		if isStuck {
			utils.SetCondition(&d.Stuck, true, fmt.Sprintf("GPU is probably stuck because: gpu_util=%d, power<%d, temp<%d",
				config.StuckGPUUtil, config.StuckGPUPower, config.StuckGPUTemp))
		}
	}

	return goodDevs, badDevsCnt, nil
}

// fixupGpuParams enforces the configured clock range and power limit on every
// device in devs whose PCI model name is "gpu_tesla_a100_80g".
//
// On the first call the limits are written unconditionally. On every call,
// the first included, they are re-applied to devices whose current values
// differ. Failures are logged and do not stop the processing of other devices.
func fixupGpuParams(devs []*NvmlDevice) {
	const targetModel = "gpu_tesla_a100_80g" // pciids.nvidiaDevices["20b2"].Name

	ll := ilog.Log()

	fixupGpuParamsOnInit.Do(func() {
		for _, dev := range devs {
			if dev.PciDev.ModelName == targetModel {
				if err := fixupGpuFreq(dev.Device, config.A100_80FreqMin, config.A100_80FreqMax); err != nil {
					ll.Error("'fixupGpuFreq()' failed", zap.Error(err))
				}

				if err := fixupGpuPower(dev.Device, config.A100_80Power); err != nil {
					ll.Error("'fixupGpuPower()' failed", zap.Error(err))
				}
			}
		}
	})

	for _, dev := range devs {
		if dev.PciDev.ModelName == targetModel {
			// turnOffAutoBoost is deliberately not called here: it seems that
			// Auto Boost on new GPUs is automatically disabled after manually
			// setting the frequencies.

			if err := fixupGpuFreqIfNeeded(dev.Device, config.A100_80FreqMin, config.A100_80FreqMax); err != nil {
				ll.Error("'fixupGpuFreqIfNeeded()' failed", zap.Error(err))
			}

			if err := fixupGpuPowerIfNeeded(dev.Device, config.A100_80Power); err != nil {
				ll.Error("'fixupGpuPowerIfNeeded()' failed", zap.Error(err))
			}
		}
	}
}

// fixupGpuFreq locks the GPU clocks of device to the [minFreq, maxFreq] range
// via NVML and logs the change. It returns an error when NVML rejects the
// request.
func fixupGpuFreq(device NvmlAPIDevice, minFreq uint32, maxFreq uint32) error {
	if ret := device.GetDevice().SetGpuLockedClocks(minFreq, maxFreq); ret != nvml.SUCCESS {
		return fmt.Errorf("nvml.SetGpuLockedClocks() failed, err: %v", nvml.ErrorString(ret))
	}

	ilog.Log().Info("gpu clocks were changed to", zap.Any("min", minFreq), zap.Any("max", maxFreq), zap.Any("gpu", device))
	// TODO: maybe update dev.Status.Clocks.Cores

	return nil
}
// fixupGpuFreqIfNeeded reads the current SM clock of device and, when it lies
// outside [minFreq, maxFreq], logs the mismatch and re-locks the clocks with
// fixupGpuFreq. It returns an error when the clock cannot be read or set.
func fixupGpuFreqIfNeeded(device NvmlAPIDevice, minFreq uint32, maxFreq uint32) error {
	curFreq, ret := device.GetDevice().GetClockInfo(nvml.CLOCK_SM)
	if ret != nvml.SUCCESS {
		return fmt.Errorf("nvml.GetClockInfo() failed, err: %v", nvml.ErrorString(ret))
	}

	if curFreq < minFreq || curFreq > maxFreq {
		ilog.Log().Info("current gpu clocks differ from expected", zap.Uint32("current freq", curFreq), zap.Uint32("expected min freq", minFreq),
			zap.Uint32("expected max freq", maxFreq), zap.Any("gpu", device))
		// TODO: check if this freq is supported (supported according to mem freq too)
		return fixupGpuFreq(device, minFreq, maxFreq)
	}

	return nil
}

// fixupGpuPower sets the power management limit of device to targetPower via
// NVML and logs the change. It returns an error when NVML rejects the request.
func fixupGpuPower(device NvmlAPIDevice, targetPower uint32) error {
	if ret := device.GetDevice().SetPowerManagementLimit(targetPower); ret != nvml.SUCCESS {
		return fmt.Errorf("nvml.SetPowerManagementLimit() failed, err: %v", nvml.ErrorString(ret))
	}

	ilog.Log().Info("gpu power was changed to", zap.Any("power", targetPower), zap.Any("gpu", device))
	// TODO: maybe update dev.Device.Power

	return nil
}

// fixupGpuPowerIfNeeded reads the current power management limit of device
// and, when it differs from targetPower, logs the mismatch and applies
// targetPower with fixupGpuPower. It returns an error when the limit cannot
// be read or set.
func fixupGpuPowerIfNeeded(device NvmlAPIDevice, targetPower uint32) error {
	curPower, ret := device.GetDevice().GetPowerManagementLimit()
	if ret != nvml.SUCCESS {
		return fmt.Errorf("nvml.GetPowerManagementLimit() failed, err: %v", nvml.ErrorString(ret))
	}

	if curPower != targetPower {
		ilog.Log().Info("current gpu power differs from expected", zap.Uint32("current power", curPower), zap.Uint32("expected power", targetPower), zap.Any("gpu", device))
		// TODO: check if this power is supported
		return fixupGpuPower(device, targetPower)
	}

	return nil
}

// turnOffAutoBoost disables auto-boosted clocks on device: the current
// setting when it is enabled, and the default setting when that is enabled.
// Each successful change is logged. The first NVML failure is returned as an
// error and stops further processing.
func turnOffAutoBoost(device NvmlAPIDevice) error {
	enabled, defEnabled, ret := device.GetDevice().GetAutoBoostedClocksEnabled()
	if ret != nvml.SUCCESS {
		return fmt.Errorf("nvml.GetAutoBoostedClocksEnabled() failed, err: %v", nvml.ErrorString(ret))
	}

	if enabled == nvml.FEATURE_ENABLED {
		if ret := device.GetDevice().SetAutoBoostedClocksEnabled(nvml.FEATURE_DISABLED); ret != nvml.SUCCESS {
			return fmt.Errorf("nvml.SetAutoBoostedClocksEnabled(nvml.FEATURE_DISABLED) failed, err: %v", nvml.ErrorString(ret))
		}
		ilog.Log().Info("AutoBoosted clocks disabled for", zap.Any("gpu", device))
	}

	if defEnabled == nvml.FEATURE_ENABLED {
		if ret := device.GetDevice().SetDefaultAutoBoostedClocksEnabled(nvml.FEATURE_DISABLED, uint32(0)); ret != nvml.SUCCESS {
			return fmt.Errorf("nvml.SetDefaultAutoBoostedClocksEnabled(nvml.FEATURE_DISABLED) failed, err: %v", nvml.ErrorString(ret))
		}
		ilog.Log().Info("Default AutoBoosted clocks disabled for", zap.Any("gpu", device))
	}

	return nil
}
