package device

import (
	"fmt"
	"unsafe"

	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/config"

	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/ilog"
	"go.uber.org/zap"

	"github.com/NVIDIA/go-dcgm/pkg/dcgm"
)

var (
	// fieldsArray lists the DCGM profiling field ids collected for every
	// device. It is used to build the "sm_util_fields" field group in
	// DcgmLib.Init and to read the latest values in DcgmLib.GetDeviceValues.
	fieldsArray = []dcgm.Short{
		dcgm.DCGM_FI_PROF_SM_ACTIVE,
		dcgm.DCGM_FI_PROF_SM_OCCUPANCY,
	}

	// uuidFieldArray lists the single field id used to read a GPU's UUID.
	// It is used to build the "uuid_fields" field group in DcgmLib.Init and
	// to query the UUID of each device in generateFullDevsMap.
	uuidFieldArray = []dcgm.Short{
		dcgm.DCGM_FI_DEV_UUID,
	}
)

// DcgmValues holds the profiling values collected for one device.
// Fields are filled one at a time by InsertValue.
type DcgmValues struct {
	// SmUtilization is set from the DCGM_FI_PROF_SM_ACTIVE field.
	SmUtilization float64
	// SmOccupancy is set from the DCGM_FI_PROF_SM_OCCUPANCY field.
	SmOccupancy float64
}

// InsertValue stores a single DCGM field value into the matching member of
// status.
//
// The raw value is first validated and decoded by processValue. The field id
// then selects the destination: DCGM_FI_PROF_SM_ACTIVE goes to SmUtilization
// and DCGM_FI_PROF_SM_OCCUPANCY goes to SmOccupancy. Any other field id is
// rejected with an error.
//
// When decoding fails the destination member is still overwritten (with the
// zero value returned alongside the error) and the error is returned. Every
// returned error is also written to the log.
func (status *DcgmValues) InsertValue(val dcgm.FieldValue_v1) error {
	info := processValue(val)

	// Pick the destination member for this field id.
	var target *float64
	switch val.FieldId {
	case dcgm.DCGM_FI_PROF_SM_ACTIVE:
		target = &status.SmUtilization
	case dcgm.DCGM_FI_PROF_SM_OCCUPANCY:
		target = &status.SmOccupancy
	default:
		unsupported := fmt.Errorf("dcgm.FieldValue with id %d is not supporter by DcgmValues", val.FieldId)
		ilog.Log().Error("generate dcgm status problem", zap.Error(unsupported))
		return unsupported
	}

	parsed, err := info.Float64()
	*target = parsed
	if err == nil {
		return nil
	}

	failure := fmt.Errorf("can't process value with id %d, error: %v", val.FieldId, err)
	ilog.Log().Error("process value problem", zap.Error(failure))
	return failure
}

// DcgmInterface is the set of DCGM operations used by the rest of the
// package. *DcgmLib provides an implementation of every method listed here.
type DcgmInterface interface {
	// Init prepares the DCGM connection using the given configuration.
	Init(c *config.Configuration) error
	// Shutdown releases what Init and UpdateDevices created.
	Shutdown() error
	// UpdateDevices rebuilds the set of known devices.
	UpdateDevices() error
	// UpdateValues asks DCGM to refresh the watched field values.
	UpdateValues() error
	// GetDeviceValues returns the latest collected values for dev.
	// The DcgmLib implementation returns (nil, nil) for a device whose
	// profiling metrics are not enabled.
	GetDeviceValues(dev NvmlAPIDevice) (*DcgmValues, error)
}

// DcgmLib is the DcgmInterface implementation backed by the go-dcgm bindings.
// It keeps the DCGM handles it created together with flags telling whether
// each handle is currently valid.
type DcgmLib struct {
	// cleanLib is the cleanup function returned by dcgm.Init; it is called on
	// Init failure paths and at the end of Shutdown.
	cleanLib func()

	// uuidFieldGroup is the field group built from uuidFieldArray in Init.
	uuidFieldGroup dcgm.FieldHandle
	// uuidFieldCreatedFlag is true once uuidFieldGroup has been created.
	uuidFieldCreatedFlag bool

	// fieldGroup is the field group built from fieldsArray in Init.
	fieldGroup dcgm.FieldHandle
	// fieldCreatedFlag is true once fieldGroup has been created.
	// GetDeviceValues treats a false value as "Init failed".
	fieldCreatedFlag bool

	// deviceGroup is the entity group (re)created by UpdateDevices; it holds
	// every known entity for which profiling metrics are enabled.
	deviceGroup dcgm.GroupHandle
	// groupCreatedFlag is true while deviceGroup refers to a created group.
	groupCreatedFlag bool

	// devsMap maps a device UUID (full GPU or MIG instance) to its DCGM entity.
	devsMap map[string]dcgm.GroupEntityPair
	// devsProfEnable maps a device UUID to whether profiling metrics are
	// available for it; MIG instances inherit the value of their parent GPU.
	devsProfEnable map[string]bool
}

// DcgmFieldInfo is the decoded form of one DCGM field value: either a Value
// of some concrete Go type, or the Error explaining why no value is available.
type DcgmFieldInfo struct {
	Value interface{}
	Error error
}

// String returns Value as a string.
//
// If Error is set it is returned unchanged. If Value does not hold a string,
// an error describing the mismatch is returned.
func (info DcgmFieldInfo) String() (string, error) {
	if info.Error != nil {
		return "", info.Error
	}

	text, isString := info.Value.(string)
	if !isString {
		return "", fmt.Errorf("you try use value \"%v\" as type string", info.Value)
	}
	return text, nil
}

// Float64 returns Value as a float64.
//
// If Error is set it is returned unchanged. If Value does not hold a float64,
// an error describing the mismatch is returned.
func (info DcgmFieldInfo) Float64() (float64, error) {
	if info.Error != nil {
		return 0, info.Error
	}

	number, isFloat := info.Value.(float64)
	if !isFloat {
		return 0, fmt.Errorf("you try use \"%v\" as type float64", info.Value)
	}
	return number, nil
}

// DcgmGpuFieldsInfo groups several decoded field values together with an
// overall error.
//
// NOTE(review): this type is not referenced in the visible part of the file;
// the Fields key is presumably a DCGM field id and the whole value presumably
// describes one GPU — confirm against the callers.
type DcgmGpuFieldsInfo struct {
	Fields map[uint]DcgmFieldInfo
	Error  error
}

// Init connects to DCGM in standalone mode and creates the two field groups
// used by the rest of DcgmLib.
//
// The connection address is c.DCGMSocketAddr; c.DCGMIsUnixSocket selects
// whether it is passed to dcgm.Init as a unix socket ("1") or not ("0").
// After a successful connection the method initialises the DCGM fields
// module and creates the "sm_util_fields" group (from fieldsArray) and the
// "uuid_fields" group (from uuidFieldArray).
//
// On any failure after the connection was established, everything created so
// far is rolled back: an already created field group is destroyed, its
// "created" flag is cleared and the library cleanup function is called. The
// error is logged and returned.
func (dcgmLib *DcgmLib) Init(c *config.Configuration) error {
	unixMode := "0"
	if c.DCGMIsUnixSocket {
		unixMode = "1"
	}
	cleanFunc, err := dcgm.Init(dcgm.Standalone, c.DCGMSocketAddr, unixMode)
	dcgmLib.cleanLib = cleanFunc
	if err != nil {
		err = fmt.Errorf("nvidia dcgm.Init() failed, error: %w", err)
		ilog.Log().Error("dcgm.Init()", zap.Error(err))
		return err
	}

	dcgmLib.fieldCreatedFlag = false
	dcgmLib.uuidFieldCreatedFlag = false
	dcgmLib.groupCreatedFlag = false

	if ierr := dcgm.FieldsInit(); ierr != 0 {
		dcgmLib.cleanLib()
		err = fmt.Errorf("nvidia dcgm.FieldsInit() failed, error: %v", ierr)
		ilog.Log().Error("dcgm.FieldsInit()", zap.Error(err))
		return err
	}

	metricsGroup, err := dcgm.FieldGroupCreate("sm_util_fields", fieldsArray)
	if err != nil {
		// Creation failed, so there is no field group to destroy here.
		dcgmLib.cleanLib()
		err = fmt.Errorf("nvidia dcgm.FieldGroupCreate() failed, error: %w", err)
		ilog.Log().Error("dcgm.FieldGroupCreate()", zap.Error(err))
		return err
	}

	dcgmLib.fieldGroup = metricsGroup
	dcgmLib.fieldCreatedFlag = true

	uuidGroup, err := dcgm.FieldGroupCreate("uuid_fields", uuidFieldArray)
	if err != nil {
		// Roll back the first field group so it does not leak, and clear its
		// flag so that GetDeviceValues keeps reporting the failed init.
		if derr := dcgm.FieldGroupDestroy(dcgmLib.fieldGroup); derr != nil {
			ilog.Log().Error("dcgm.FieldGroupDestroy()", zap.Error(derr))
		}
		dcgmLib.fieldCreatedFlag = false
		dcgmLib.cleanLib()
		err = fmt.Errorf("nvidia dcgm.FieldGroupCreate() failed, error: %w", err)
		ilog.Log().Error("dcgm.FieldGroupCreate()", zap.Error(err))
		return err
	}

	dcgmLib.uuidFieldGroup = uuidGroup
	dcgmLib.uuidFieldCreatedFlag = true

	return nil
}

// Shutdown releases, in order, the metrics field group, the UUID field group
// and the device group (each only if its "created" flag is set), then
// terminates the DCGM fields module and calls the library cleanup function.
//
// The first failing step is logged and returned, and the remaining steps are
// not executed. Each "created" flag is cleared as soon as its handle has been
// destroyed, so calling Shutdown again after a failure resumes from the step
// that failed instead of destroying the same handle twice. The cleanup
// function is called at most once and is skipped when Init never stored one.
func (dcgmLib *DcgmLib) Shutdown() error {
	if dcgmLib.fieldCreatedFlag {
		if err := dcgm.FieldGroupDestroy(dcgmLib.fieldGroup); err != nil {
			err = fmt.Errorf("nvidia dcgm.FieldGroupDestroy() failed, error: %w", err)
			ilog.Log().Error("dcgm.FieldGroupDestroy()", zap.Error(err))
			return err
		}
		dcgmLib.fieldCreatedFlag = false
	}

	if dcgmLib.uuidFieldCreatedFlag {
		if err := dcgm.FieldGroupDestroy(dcgmLib.uuidFieldGroup); err != nil {
			err = fmt.Errorf("nvidia dcgm.FieldGroupDestroy() failed, error: %w", err)
			ilog.Log().Error("dcgm.FieldGroupDestroy()", zap.Error(err))
			return err
		}
		dcgmLib.uuidFieldCreatedFlag = false
	}

	if dcgmLib.groupCreatedFlag {
		if err := dcgm.DestroyGroup(dcgmLib.deviceGroup); err != nil {
			err = fmt.Errorf("nvidia dcgm.DestroyGroup() failed, error: %w", err)
			ilog.Log().Error("dcgm.DestroyGroup()", zap.Error(err))
			return err
		}
		dcgmLib.deviceGroup = dcgm.GroupHandle{}
		dcgmLib.groupCreatedFlag = false
	}

	if ierr := dcgm.FieldsTerm(); ierr != 0 {
		err := fmt.Errorf("nvidia dcgm.FieldsTerm() failed, error: %v", ierr)
		ilog.Log().Error("dcgm.FieldsTerm()", zap.Error(err))
		return err
	}

	// cleanLib is nil when Init was never called; drop it after use so that a
	// repeated Shutdown does not run the library cleanup twice.
	if dcgmLib.cleanLib != nil {
		dcgmLib.cleanLib()
		dcgmLib.cleanLib = nil
	}
	return nil
}

// checkGpuProfileFieldsSupport reports whether profiling metric groups can be
// queried for the GPU with the given DCGM id.
//
// It creates a temporary group containing only that GPU and asks DCGM for the
// supported metric groups. The result is:
//   - (true, nil) when the query succeeds;
//   - (false, nil) when DCGM answers that the serving module is not loaded;
//   - (false, err) for any other failure.
//
// The temporary group is destroyed before returning; a failure to destroy it
// is logged and does not change the result.
func checkGpuProfileFieldsSupport(gpuid uint) (bool, error) {
	gh, err := dcgm.CreateGroup(fmt.Sprintf("check [gpu:%d] profile fields support", gpuid))
	if err != nil {
		return false, err
	}
	defer func() {
		if derr := dcgm.DestroyGroup(gh); derr != nil {
			ilog.Log().Error("dcgm.DestroyGroup()", zap.Error(derr))
		}
	}()

	err = dcgm.AddToGroup(gh, gpuid)
	if err != nil {
		return false, err
	}

	// The memory of the group handle is reinterpreted as a uint64 to obtain a
	// numeric id to pass on.
	// NOTE(review): this relies on the in-memory layout of dcgm.GroupHandle —
	// re-check when the go-dcgm dependency is upgraded.
	_, err = dcgm.GetSupportedMetricGroups(uint(*(*uint64)(unsafe.Pointer(&gh))))
	if err != nil {
		// The "module not loaded" answer is matched by its exact text and is
		// reported as "not supported" rather than as an error.
		if err.Error() == "Error getting supported metrics: This request is serviced by a module of DCGM that is not currently loaded" {
			return false, nil
		}
		return false, err
	}
	return true, nil
}

// getMigUUID builds the identifier used as the map key for a MIG entity, in
// the form "MIG-<gpu uuid>/<instance id>/<compute instance id>".
//
// It returns an empty string when the entity has no parent GPU UUID or when
// its compute instance id equals 0xFFFFFFFF (presumably the "no compute
// instance" marker — confirm against the DCGM documentation).
func getMigUUID(ent dcgm.MigEntityInfo) string {
	hasComputeInstance := ent.NvmlComputeInstanceId != 0xFFFFFFFF
	if hasComputeInstance && ent.GpuUuid != "" {
		return fmt.Sprintf("MIG-%s/%d/%d", ent.GpuUuid, ent.NvmlInstanceId, ent.NvmlComputeInstanceId)
	}
	return ""
}

// generateMigDevsMap adds the MIG entities reported by DCGM to devsMap and
// devsProfEnable.
//
// Entities for which getMigUUID yields an empty string are skipped. Every
// other entity is stored in devsMap under its MIG identifier and inherits the
// profiling flag of its parent GPU, which must already be present in
// devsProfEnable; a missing parent is an error. Errors are logged and
// returned.
func (dcgmLib *DcgmLib) generateMigDevsMap() error {
	hierarchy, err := dcgm.GetGpuInstanceHierarchy()
	if err != nil {
		wrapped := fmt.Errorf("nvidia dcgm.GetGpuInstanceHierarchy() failed, error: %v", err)
		ilog.Log().Error("dcgm.GetGpuInstanceHierarchy()", zap.Error(wrapped))
		return wrapped
	}

	for _, item := range hierarchy.EntityList {
		migUUID := getMigUUID(item.Info)
		if migUUID == "" {
			continue
		}

		// The entity is registered before the parent lookup, so it stays in
		// devsMap even when the lookup below fails.
		dcgmLib.devsMap[migUUID] = item.Entity

		parentUUID := item.Info.GpuUuid
		enabled, known := dcgmLib.devsProfEnable[parentUUID]
		if !known {
			missing := fmt.Errorf("can't find parent [%s] device for \"%s\" in devsProfEnable", parentUUID, migUUID)
			ilog.Log().Error("DcgmLib.UpdateDevices", zap.Error(missing))
			return missing
		}

		if !enabled {
			ilog.Log().Info(fmt.Sprintf("\"%s\" gpu doesn't support profile metrics: dcgm status will be set to nil", migUUID))
		}

		dcgmLib.devsProfEnable[migUUID] = enabled
	}

	return nil
}

// generateFullDevsMap fills devsMap and devsProfEnable with one entry per
// full GPU reported by dcgm.GetSupportedDevices.
//
// All devices are put into a temporary group that is watched with the UUID
// field group. For each device the UUID is then read and used as the key
// under which the device is stored as a dcgm.FE_GPU entity. Profiling
// support is probed with checkGpuProfileFieldsSupport; a probe error is
// logged and treated as "not supported".
//
// The temporary group is destroyed on every path, including error paths. If
// both the fill step and the destroy step fail, the fill error is returned.
// All errors are logged.
func (dcgmLib *DcgmLib) generateFullDevsMap() error {
	devList, err := dcgm.GetSupportedDevices()
	if err != nil {
		err = fmt.Errorf("nvidia dcgm.GetSupportedDevices() failed, error: %w", err)
		ilog.Log().Error("dcgm.GetSupportedDevices()", zap.Error(err))
		return err
	}

	tempGroup, err := dcgm.CreateGroup("get_uuid_group")
	if err != nil {
		err = fmt.Errorf("nvidia dcgm.CreateGroup() failed, error: %w", err)
		ilog.Log().Error("dcgm.CreateGroup()", zap.Error(err))
		return err
	}

	// fill does all the work that needs tempGroup; it is a closure so that
	// the group can be destroyed in a single place regardless of the outcome.
	fill := func() error {
		for _, dev := range devList {
			if addErr := dcgm.AddToGroup(tempGroup, dev); addErr != nil {
				addErr = fmt.Errorf("nvidia dcgm.AddToGroup() failed, error: %w", addErr)
				ilog.Log().Error("dcgm.AddToGroup()", zap.Error(addErr))
				return addErr
			}
		}

		if watchErr := dcgm.WatchFieldsWithGroup(dcgmLib.uuidFieldGroup, tempGroup); watchErr != nil {
			watchErr = fmt.Errorf("nvidia dcgm.WatchFieldsWithGroup() failed, error: %w", watchErr)
			ilog.Log().Error("dcgm.WatchFieldsWithGroup()", zap.Error(watchErr))
			return watchErr
		}

		for _, dev := range devList {
			values, getErr := dcgm.GetLatestValuesForFields(dev, uuidFieldArray)
			if getErr != nil {
				getErr = fmt.Errorf("nvidia dcgm.GetLatestValuesForFields() failed, error: %w", getErr)
				ilog.Log().Error("dcgm.GetLatestValuesForFields()", zap.Error(getErr))
				return getErr
			}
			// Guard the index below: an empty answer would otherwise panic.
			if len(values) == 0 {
				emptyErr := fmt.Errorf("nvidia dcgm.GetLatestValuesForFields() returned no values for [gpu:%d]", dev)
				ilog.Log().Error("dcgm.GetLatestValuesForFields()", zap.Error(emptyErr))
				return emptyErr
			}

			uuid, uuidErr := processValue(values[0]).String()
			if uuidErr != nil {
				uuidErr = fmt.Errorf("cant get uuid value, error: %w", uuidErr)
				ilog.Log().Error("full gpu get uuid problem", zap.Error(uuidErr))
				return uuidErr
			}

			dcgmLib.devsMap[uuid] = dcgm.GroupEntityPair{EntityGroupId: dcgm.FE_GPU, EntityId: dev}

			profEnable, probeErr := checkGpuProfileFieldsSupport(dev)
			if probeErr != nil {
				// A failed probe is not fatal: the device is kept and simply
				// marked as not supporting profiling metrics.
				probeErr = fmt.Errorf("error on checking profile fields support for [gpu:%d]: %w", dev, probeErr)
				ilog.Log().Error("DcgmLib.UpdateDevices", zap.Error(probeErr))
				profEnable = false
			}

			if !profEnable {
				ilog.Log().Info(fmt.Sprintf("\"%s\" gpu doesn't support profile metrics: dcgm status will be set to nil", uuid))
			}

			dcgmLib.devsProfEnable[uuid] = profEnable
		}

		return nil
	}

	fillErr := fill()

	if err = dcgm.DestroyGroup(tempGroup); err != nil {
		err = fmt.Errorf("nvidia dcgm.DestroyGroup() failed, error: %w", err)
		ilog.Log().Error("dcgm.DestroyGroup()", zap.Error(err))
		if fillErr == nil {
			return err
		}
	}

	return fillErr
}

// UpdateDevices rebuilds the device bookkeeping from scratch.
//
// It resets devsMap and devsProfEnable, destroys the previous device group
// (if one exists), refills the maps with full GPUs and then MIG entities, and
// creates a new "all_nvgpumanager_devices" group containing every entity
// whose profiling flag is set. If at least one entity was added, the metrics
// field group is watched on the new group and DCGM is asked to update all
// fields.
//
// If adding an entity to the new group fails, the new group is destroyed
// before returning so that it does not leak. All errors are logged and
// returned.
func (dcgmLib *DcgmLib) UpdateDevices() error {
	dcgmLib.devsMap = map[string]dcgm.GroupEntityPair{}
	dcgmLib.devsProfEnable = map[string]bool{}

	// Clean previous info.

	var err error

	if dcgmLib.groupCreatedFlag {
		if err = dcgm.DestroyGroup(dcgmLib.deviceGroup); err != nil {
			err = fmt.Errorf("nvidia dcgm.DestroyGroup() failed, error: %w", err)
			ilog.Log().Error("dcgm.DestroyGroup()", zap.Error(err))
			return err
		}
		dcgmLib.deviceGroup = dcgm.GroupHandle{}
		dcgmLib.groupCreatedFlag = false
	}

	// Full GPUs must be collected first: MIG entities look up their parent
	// GPU in devsProfEnable.
	if err = dcgmLib.generateFullDevsMap(); err != nil {
		return err
	}

	if err = dcgmLib.generateMigDevsMap(); err != nil {
		return err
	}

	// Create the new device group.

	gh, err := dcgm.CreateGroup("all_nvgpumanager_devices")
	if err != nil {
		err = fmt.Errorf("nvidia dcgm.CreateGroup() failed, error: %w", err)
		ilog.Log().Error("dcgm.CreateGroup()", zap.Error(err))
		return err
	}

	var count uint

	for uuid, entity := range dcgmLib.devsMap {
		if !dcgmLib.devsProfEnable[uuid] {
			continue
		}
		if err = dcgm.AddEntityToGroup(gh, entity.EntityGroupId, entity.EntityId); err != nil {
			err = fmt.Errorf("nvidia dcgm.AddEntityToGroup() failed, error: %w", err)
			ilog.Log().Error("dcgm.AddEntityToGroup()", zap.Error(err))
			// The group is not stored in dcgmLib yet, so nothing else would
			// ever destroy it.
			if derr := dcgm.DestroyGroup(gh); derr != nil {
				ilog.Log().Error("dcgm.DestroyGroup()", zap.Error(derr))
			}
			return err
		}
		count++
	}

	dcgmLib.deviceGroup = gh
	dcgmLib.groupCreatedFlag = true

	if count == 0 {
		// The (empty) group is kept, but there is nothing to watch.
		ilog.Log().Info("Host has not profile fields supported devices")
		return nil
	}

	err = dcgm.WatchFieldsWithGroup(dcgmLib.fieldGroup, dcgmLib.deviceGroup)
	if err != nil {
		err = fmt.Errorf("nvidia dcgm.WatchFieldsWithGroup() failed, error: %w", err)
		ilog.Log().Error("dcgm.WatchFieldsWithGroup()", zap.Error(err))
		return err
	}

	if err = dcgm.UpdateAllFields(); err != nil {
		err = fmt.Errorf("nvidia dcgm.UpdateAllFields() failed, error: %w", err)
		ilog.Log().Error("dcgm.UpdateAllFields()", zap.Error(err))
		return err
	}

	return nil
}

// UpdateValues asks DCGM to update all fields. A failure is wrapped, logged
// and returned; on success nil is returned.
func (dcgmLib *DcgmLib) UpdateValues() error {
	updateErr := dcgm.UpdateAllFields()
	if updateErr == nil {
		return nil
	}

	wrapped := fmt.Errorf("nvidia dcgm.UpdateAllFields() failed, error: %v", updateErr)
	ilog.Log().Error("dcgm.UpdateAllFields()", zap.Error(wrapped))
	return wrapped
}

// processValue validates a raw DCGM field value and decodes it into a
// DcgmFieldInfo.
//
// A non-zero Status yields an error. Otherwise the value is decoded according
// to FieldType (double, int64 or string) and compared with that type's
// "blank", "not found", "not supported" and "not permissioned" sentinel
// constants; a match yields the corresponding error. Any other field type
// yields an "is not supported value type" error. In every error case the
// returned Value is nil.
func processValue(val dcgm.FieldValue_v1) DcgmFieldInfo {
	// Validate the value.
	if val.Status != 0 {
		return DcgmFieldInfo{Error: fmt.Errorf("error on %d field: return stauts is %d", val.FieldId, val.Status)}
	}

	// Decode the value and pick the sentinel constants for its type. All of
	// them are kept as interface values so that the comparison below is the
	// same interface comparison regardless of the field type.
	var parsed, blank, notFound, notSupported, notPermissioned interface{}

	switch val.FieldType {
	case dcgm.DCGM_FT_DOUBLE:
		parsed = val.Float64()
		blank = dcgm.DCGM_FT_FP64_BLANK
		notFound = dcgm.DCGM_FT_FP64_NOT_FOUND
		notSupported = dcgm.DCGM_FT_FP64_NOT_SUPPORTED
		notPermissioned = dcgm.DCGM_FT_FP64_NOT_PERMISSIONED

	case dcgm.DCGM_FT_INT64:
		parsed = val.Int64()
		blank = dcgm.DCGM_FT_INT64_BLANK
		notFound = dcgm.DCGM_FT_INT64_NOT_FOUND
		notSupported = dcgm.DCGM_FT_INT64_NOT_SUPPORTED
		notPermissioned = dcgm.DCGM_FT_INT64_NOT_PERMISSIONED

	case dcgm.DCGM_FT_STRING:
		parsed = val.String()
		blank = dcgm.DCGM_FT_STR_BLANK
		notFound = dcgm.DCGM_FT_STR_NOT_FOUND
		notSupported = dcgm.DCGM_FT_STR_NOT_SUPPORTED
		notPermissioned = dcgm.DCGM_FT_STR_NOT_PERMISSIONED

	// Timestamp and binary values are not decoded and fall through to the
	// default branch:
	// case dcgm.DCGM_FT_TIMESTAMP:
	// 	value = "<<<timestamp>>>"
	// case dcgm.DCGM_FT_BINARY:
	// 	value = "<<<binary>>>"

	default:
		return DcgmFieldInfo{Error: fmt.Errorf("error on %d field: %d is not supported value type", val.FieldId, val.FieldType)}
	}

	// Reject sentinel values.
	switch parsed {
	case blank:
		return DcgmFieldInfo{Error: fmt.Errorf("error on %d field: value is blank", val.FieldId)}
	case notFound:
		return DcgmFieldInfo{Error: fmt.Errorf("error on %d field: value is not found", val.FieldId)}
	case notSupported:
		return DcgmFieldInfo{Error: fmt.Errorf("error on %d field: value is not supported", val.FieldId)}
	case notPermissioned:
		return DcgmFieldInfo{Error: fmt.Errorf("error on %d field: value is not permissioned", val.FieldId)}
	}

	return DcgmFieldInfo{Value: parsed}
}

// GetDeviceValues returns the latest profiling values collected for dev.
//
// The device is looked up by the UUID returned from dev.GetUUID(). The
// result is:
//   - an error when the metrics field group or the device group has not been
//     created, or when the UUID is missing from devsMap or devsProfEnable;
//   - (nil, nil) when profiling metrics are not enabled for the device;
//   - otherwise a DcgmValues filled from the latest values of the fields in
//     fieldsArray, or an error if reading or decoding any of them fails.
func (dcgmLib *DcgmLib) GetDeviceValues(dev NvmlAPIDevice) (*DcgmValues, error) {
	if !dcgmLib.fieldCreatedFlag {
		return nil, fmt.Errorf("dcgm lib failed on init()")
	}

	if !dcgmLib.groupCreatedFlag {
		return nil, fmt.Errorf("dcgm lib failed on DeviceUpdate()")
	}

	// Ask the device for its UUID once and reuse it for both lookups and for
	// the error messages.
	uuid := dev.GetUUID()

	entity, ok := dcgmLib.devsMap[uuid]
	if !ok {
		return nil, fmt.Errorf("can't get info for \"%s\" device: was not rigestred for data", uuid)
	}

	profEnable, ok := dcgmLib.devsProfEnable[uuid]
	if !ok {
		return nil, fmt.Errorf("can't get info for \"%s\" device: was not rigestred for data enabling", uuid)
	}

	if !profEnable {
		// No profiling metrics for this device: report "no values", not an error.
		return nil, nil
	}

	values, err := dcgm.EntityGetLatestValues(entity.EntityGroupId, entity.EntityId, fieldsArray)
	if err != nil {
		err = fmt.Errorf("nvidia dcgm.EntityGetLatestValues() failed, error: %w", err)
		ilog.Log().Error("dcgm.EntityGetLatestValues()", zap.Error(err))
		return nil, err
	}

	result := DcgmValues{}

	for _, val := range values {
		// InsertValue logs its own failures.
		if err := result.InsertValue(val); err != nil {
			return nil, err
		}
	}

	return &result, nil
}
