package device

// based on:
// https://github.com/Mellanox/rdmamap/

import (
	"fmt"
	"io"
	"io/ioutil"
	"math"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"go.uber.org/zap"

	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/config"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/ilog"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/pkg/yasm"
)

// Sysfs locations used to discover rdma devices and to read their per-port
// counters. A counter file lives at
// <RdmaClassDir>/<device>/<RdmaPortsdir>/<port>/<counters dir>/<counter>.
const (
	// Directory whose entries are the rdma devices, and the name of the
	// per-device subdirectory that holds one directory per port.
	RdmaClassDir = "/sys/class/infiniband"
	RdmaPortsdir = "ports"

	// Names of the two per-port directories counters are read from.
	RdmaCountersDir   = "counters"
	RdmaHwCountersDir = "hw_counters"
)

// getRdmaDeviceList lists the names of the rdma devices found under
// RdmaClassDir.
//
// Entries reported as directories are skipped; every other entry name is
// returned, in the order the directory read yields them. The result is nil
// when no entry qualifies. Any error from opening or reading the directory is
// returned unchanged.
func getRdmaDeviceList() ([]string, error) {
	classDir, err := os.Open(RdmaClassDir)
	if err != nil {
		return nil, err
	}
	defer classDir.Close()

	entries, err := classDir.Readdir(-1)
	if err != nil {
		return nil, err
	}

	var names []string
	for _, entry := range entries {
		if !entry.IsDir() {
			names = append(names, entry.Name())
		}
	}
	return names, nil
}

// RdmaStatEntry is a single counter read from sysfs: Name is the counter file
// name and Value is the number parsed from the file content.
type RdmaStatEntry struct {
	Name  string
	Value uint64
}

// RdmaPortStats groups the counters read for one port of an rdma device.
type RdmaPortStats struct {
	HwStats []RdmaStatEntry // entries of /sys/class/infiniband/<dev>/ports/<port>/hw_counters
	Stats   []RdmaStatEntry // entries of /sys/class/infiniband/<dev>/ports/<port>/counters
	Port    int             // port number, taken from the name of the port directory
}

// RdmaStats holds the counters of one rdma device, one PortStats element per
// port directory that was read.
type RdmaStats struct {
	PortStats []RdmaPortStats
}

// readCounter reads the file at name and returns its content parsed as an
// unsigned base-10 integer, after stripping leading and trailing newlines.
//
// Errors from opening, seeking or reading the file are returned. A parse
// failure is deliberately NOT reported: the function then returns whatever
// strconv.ParseUint produced alongside its error (0 for malformed input)
// together with a nil error, so one odd file does not abort a whole scan.
func readCounter(name string) (uint64, error) {
	f, err := os.OpenFile(name, os.O_RDONLY, 0444)
	if err != nil {
		return 0, err
	}
	defer f.Close()

	// Position explicitly at the beginning before reading the whole content.
	if _, seekErr := f.Seek(0, io.SeekStart); seekErr != nil {
		return 0, seekErr
	}

	raw, err := ioutil.ReadAll(f)
	if err != nil {
		return 0, err
	}

	// Best effort: the parse error is intentionally dropped, see above.
	parsed, _ := strconv.ParseUint(strings.Trim(string(raw), "\n"), 10, 64)
	return parsed, nil
}

// getCountersFromDir reads every non-directory entry of the directory at path
// as a counter (see readCounter) and returns one RdmaStatEntry per file, named
// after the file, in directory-read order.
//
// If the directory cannot be opened or listed the result is nil plus the
// error. If reading a counter fails, the entries collected so far are
// returned together with that error.
func getCountersFromDir(path string) ([]RdmaStatEntry, error) {
	dir, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer dir.Close()

	entries, err := dir.Readdir(-1)
	if err != nil {
		return nil, err
	}

	var counters []RdmaStatEntry
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		counterName := entry.Name()
		value, err := readCounter(filepath.Join(path, counterName))
		if err != nil {
			return counters, err
		}
		counters = append(counters, RdmaStatEntry{Name: counterName, Value: value})
	}
	return counters, nil
}

// getRdmaSysfsStats returns the counters found in the "counters" directory of
// the given port of an rdma device, i.e.
// <RdmaClassDir>/<rdmaDevice>/<RdmaPortsdir>/<port>/<RdmaCountersDir>.
// Port numbers start from 1.
func getRdmaSysfsStats(rdmaDevice string, port int) ([]RdmaStatEntry, error) {
	countersDir := filepath.Join(
		RdmaClassDir,
		rdmaDevice,
		RdmaPortsdir,
		strconv.Itoa(port),
		RdmaCountersDir,
	)
	return getCountersFromDir(countersDir)
}

// getRdmaSysfsHwStats returns the counters found in the "hw_counters"
// directory of the given port of an rdma device, i.e.
// <RdmaClassDir>/<rdmaDevice>/<RdmaPortsdir>/<port>/<RdmaHwCountersDir>.
// Port numbers start from 1.
func getRdmaSysfsHwStats(rdmaDevice string, port int) ([]RdmaStatEntry, error) {
	hwCountersDir := filepath.Join(
		RdmaClassDir,
		rdmaDevice,
		RdmaPortsdir,
		strconv.Itoa(port),
		RdmaHwCountersDir,
	)
	return getCountersFromDir(hwCountersDir)
}

// getRdmaSysfsAllStats reads both the hw_counters and the counters directory
// of the requested port of an rdma device.
//
// On success the returned RdmaPortStats carries both counter sets and the port
// number. If either directory cannot be read, a zero RdmaPortStats is returned
// together with an error that wraps the cause; failures used to be reported as
// success with an empty (or port-0 labelled, half-filled) result.
func getRdmaSysfsAllStats(rdmaDevice string, port int) (RdmaPortStats, error) {
	hwstats, err := getRdmaSysfsHwStats(rdmaDevice, port)
	if err != nil {
		return RdmaPortStats{}, fmt.Errorf("reading hw_counters of %s port %d: %w", rdmaDevice, port, err)
	}

	stats, err := getRdmaSysfsStats(rdmaDevice, port)
	if err != nil {
		return RdmaPortStats{}, fmt.Errorf("reading counters of %s port %d: %w", rdmaDevice, port, err)
	}

	return RdmaPortStats{
		HwStats: hwstats,
		Stats:   stats,
		Port:    port,
	}, nil
}

// getRdmaSysfsAllPortsStats reads the counters and hw_counters of every port
// of an rdma device.
//
// Every directory under <RdmaClassDir>/<rdmaDevice>/<RdmaPortsdir> whose name
// is a decimal number is treated as a port. Other entries are ignored rather
// than being reported as port 0. Ports are returned in directory-read order.
//
// An error from listing the ports directory or from reading any port aborts
// the scan; the ports collected so far are returned alongside the error.
func getRdmaSysfsAllPortsStats(rdmaDevice string) (RdmaStats, error) {
	var allstats RdmaStats

	portsDir := filepath.Join(RdmaClassDir, rdmaDevice, RdmaPortsdir)
	fd, err := os.Open(portsDir)
	if err != nil {
		return allstats, err
	}
	defer fd.Close()

	entries, err := fd.Readdir(-1)
	if err != nil {
		return allstats, err
	}

	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		port, err := strconv.Atoi(entry.Name())
		if err != nil {
			// Not a port directory: its name is not a number.
			continue
		}
		portstats, err := getRdmaSysfsAllStats(rdmaDevice, port)
		if err != nil {
			return allstats, err
		}
		allstats.PortStats = append(allstats.PortStats, portstats)
	}
	return allstats, nil
}

// Counter-name sets that select how a plain port counter is post-processed
// before being exported. Only the presence of a key is consulted by the code
// that uses them; the bool values carry no meaning.
var (
	// diffMetrics names the counters exported as the difference from the
	// previous sample instead of the raw value.
	diffMetrics = map[string]bool{
		"port_rcv_data":     true,
		"port_xmit_data":    true,
		"port_rcv_packets":  true,
		"port_xmit_packets": true,
	}

	// dwordsMetrics names the counters whose value is multiplied by 4 and
	// whose name gets the "_bytes" suffix.
	dwordsMetrics = map[string]bool{
		"port_rcv_data":  true,
		"port_xmit_data": true,
	}

	// perSecMetrics names the counters converted to a per-second rate. The
	// names are matched after the "_bytes" renaming has been applied, hence
	// the "_bytes" entries.
	perSecMetrics = map[string]bool{
		"port_rcv_data_bytes":  true,
		"port_xmit_data_bytes": true,
		"port_rcv_packets":     true,
		"port_xmit_packets":    true,
	}
)

// addBytesSuffix returns name with "_bytes" appended.
func addBytesSuffix(name string) string {
	const suffix = "_bytes"
	return name + suffix
}

// addPerSecSuffix returns name with "_per_sec" appended.
func addPerSecSuffix(name string) string {
	const suffix = "_per_sec"
	return name + suffix
}

// dwordsAsBytesValue converts a count of 4-byte words into a count of bytes.
//
// The multiplication is not overflow-checked: the function is applied to the
// difference between two samples, which is expected to be small.
func dwordsAsBytesValue(val uint64) uint64 {
	const bytesPerDword = 4
	return val * bytesPerDword
}

// perSecondValue returns a function that converts a counter delta accumulated
// between prevTS and curTS into a per-second rate, rounded to the nearest
// integer.
//
// When the interval is zero or negative no meaningful rate exists, so the
// returned function yields 0. Dividing anyway would produce Inf, NaN or a
// negative float, and converting those to uint64 is implementation-defined.
func perSecondValue(prevTS time.Time, curTS time.Time) func(uint64) uint64 {
	// The interval is fixed for the lifetime of the closure; compute it once.
	elapsed := curTS.Sub(prevTS).Seconds()
	return func(diff uint64) uint64 {
		if elapsed <= 0 {
			return 0
		}
		return uint64(math.Round(float64(diff) / elapsed))
	}
}

// postProcName rewrites a metric name; postProcValue rewrites a metric value.
type (
	postProcName  func(string) string
	postProcValue func(uint64) uint64
)

// postProcMetrics applies modName and modVal to the metric when name is a key
// of applicableMetrics, and returns name and val untouched otherwise.
//
// Only key presence matters: the bool stored under the key is not inspected.
func postProcMetrics(name string, val uint64, applicableMetrics map[string]bool, modName postProcName, modVal postProcValue) (string, uint64) {
	if _, applies := applicableMetrics[name]; !applies {
		return name, val
	}
	newName := modName(name)
	newVal := modVal(val)
	return newName, newVal
}

// RdmaStatsRecord is a snapshot of a device's counters (Stats) together with
// the time it was taken (TS).
type RdmaStatsRecord struct {
	TS    time.Time
	Stats RdmaStats
}

// devsStatsPrevRecords keeps, per rdma device name, the snapshot stored by the
// most recent GetIbYasmMetrics call that read the device successfully; it is
// the baseline for the counters exported as differences.
// NOTE(review): the map is read and written without any locking — presumably
// GetIbYasmMetrics is only ever called from one goroutine; confirm.
var devsStatsPrevRecords = map[string]RdmaStatsRecord{}

func GetIbYasmMetrics() ([]yasm.YasmMetrics, error) {
	ll := ilog.Log()

	devs, err := getRdmaDeviceList()
	if err != nil {
		ll.Error("'getRdmaDeviceList()' failed", zap.Error(err))
		return nil, err
	}
	ll.Debug("Found", zap.Int("numb", len(devs)), zap.Any("rdma devs", devs))

	var devsMetrics []yasm.YasmMetrics
	for _, dev := range devs {
		stats, err := getRdmaSysfsAllPortsStats(dev)
		if err != nil {
			ll.Error("'getRdmaSysfsAllPortsStats()' failed for", zap.String("dev", dev), zap.Error(err))
			continue
		}
		ll.Debug("get stats for", zap.String("dev", dev), zap.Any("stats", stats))
		curTS := time.Now()

		m := yasm.YasmMetrics{
			Tags: map[string]string{
				"itype":   "runtimecloud",
				"ib_path": dev,
			},
			TTL:    30,
			Values: []yasm.YasmValue{},
		}

		var val yasm.YasmValue
		for i, portStat := range stats.PortStats {
			for _, hwStat := range portStat.HwStats {
				val.Name = fmt.Sprintf("ibstat-port%d_hw_%s_tmmv", portStat.Port, hwStat.Name)
				val.Value = hwStat.Value
				m.Values = append(m.Values, val)
			}

			for j, stat := range portStat.Stats {
				statName := stat.Name
				if _, found := diffMetrics[stat.Name]; found {
					if prevRecord, exist := devsStatsPrevRecords[dev]; exist {
						prevStatVal := prevRecord.Stats.PortStats[i].Stats[j].Value
						if stat.Value < prevStatVal {
							// counter probably overflowed and it's better to
							// skip this iteration
							continue
						}
						statVal := stat.Value - prevStatVal

						statName, statVal = postProcMetrics(statName, statVal, dwordsMetrics, addBytesSuffix, dwordsAsBytesValue)
						statName, statVal = postProcMetrics(statName, statVal, perSecMetrics, addPerSecSuffix, perSecondValue(prevRecord.TS, curTS))

						val.Value = statVal
					} else {
						// this is probably first iteration and there is no 'prevVal',
						// so skip this iteration to avoid spikes on graphs
						continue
					}
				} else {
					val.Value = stat.Value
				}
				val.Name = fmt.Sprintf("ibstat-port%d_%s_tmmv", portStat.Port, statName)

				m.Values = append(m.Values, val)
			}
		}
		devsStatsPrevRecords[dev] = RdmaStatsRecord{curTS, stats}

		devsMetrics = append(devsMetrics, m)
		ll.Debug("parsed metrics for", zap.String("dev", dev), zap.Any("metrics", m))
	}

	return devsMetrics, nil
}

const (
	// IbDevsDir is the directory that GetIbDevices lists to find the
	// InfiniBand device nodes.
	IbDevsDir = "/dev/infiniband"

	// IbVerbsLinkLayer is the format of the sysfs file that holds the link
	// layer of a uverbs device. GetIbDevices fills both %s verbs with the same
	// uverbs index N, i.e. the path assumes that uverbsN is backed by mlx5_N,
	// and only port 1 is consulted.
	IbVerbsLinkLayer = "/sys/class/infiniband_verbs/uverbs%s/device/infiniband/mlx5_%s/ports/1/link_layer"
)

// uverbsIndex extracts N from a "uverbsN" device name. The second result is
// false when name is not "uverbs" followed by an unsigned decimal number of
// any length.
func uverbsIndex(name string) (string, bool) {
	const prefix = "uverbs"
	if !strings.HasPrefix(name, prefix) {
		return "", false
	}
	idx := name[len(prefix):]
	if _, err := strconv.ParseUint(idx, 10, 32); err != nil {
		return "", false
	}
	return idx, true
}

// GetIbDevices returns devices required for GPU allocation.
//
// When c.AllocIbDevs is false the result is an empty, non-nil slice.
// Otherwise IbDevsDir is listed and the following device paths are returned:
//   - "rdma_cm", always;
//   - "uverbsN", when its link layer (read from the IbVerbsLinkLayer file) is
//     "InfiniBand" and c.AllocIncludeIbUverbs is set, or is "Ethernet" and
//     c.AllocIncludeRoceUverbs is set.
//
// Every other entry is ignored. The uverbs index may have any number of
// digits; previously only single-digit indexes were recognised, so uverbs10
// and above were silently dropped.
//
// An error is returned when IbDevsDir cannot be listed, when it is empty, or
// when the link layer of a uverbs device cannot be read; the devices collected
// up to that point are returned alongside it.
func GetIbDevices(c *config.Configuration) ([]string, error) {
	devs := []string{}

	if !c.AllocIbDevs {
		return devs, nil
	}

	ibDevs, err := ioutil.ReadDir(IbDevsDir)
	if err != nil {
		return devs, fmt.Errorf("failed to 'ReadDir(%s)', err: %w", IbDevsDir, err)
	}

	if len(ibDevs) == 0 {
		return devs, fmt.Errorf("ib directory '%s' is empty", IbDevsDir)
	}

	// RESMAN-67: append only 'uverbsN' (considering their 'rdmaType') and 'rdma_cm' devices
	for _, ibDev := range ibDevs {
		ibDevName := ibDev.Name()

		if ibDevName == "rdma_cm" {
			devs = append(devs, path.Join(IbDevsDir, ibDevName))
			continue
		}

		// try to get N from probably 'uverbsN' device
		idx, isUverbs := uverbsIndex(ibDevName)
		if !isUverbs {
			continue
		}

		ibVerbLinkLayer := fmt.Sprintf(IbVerbsLinkLayer, idx, idx)
		rdmaType, err := ioutil.ReadFile(ibVerbLinkLayer)
		if err != nil {
			return devs, fmt.Errorf("unable to read rdmaType from '%s', for '%s', err: %w", ibVerbLinkLayer, ibDevName, err)
		}

		// append uverbs depending on their 'rdmaType'; a 'uverbsN' device
		// that matches none of the enabled types is skipped
		rdmaTypeStr := strings.TrimSpace(string(rdmaType))
		isAppend := c.AllocIncludeIbUverbs && rdmaTypeStr == "InfiniBand" ||
			c.AllocIncludeRoceUverbs && rdmaTypeStr == "Ethernet"
		if isAppend {
			devs = append(devs, path.Join(IbDevsDir, ibDevName))
		}
	}

	return devs, nil
}
