package device

import (
	"context"
	"errors"
	"fmt"
	"io/ioutil"
	"os"
	"path"
	"strconv"
	"strings"

	"github.com/gofrs/uuid"
	opentracing "github.com/opentracing/opentracing-go"
	"go.uber.org/zap"

	pb "a.yandex-team.ru/infra/rsm/nvgpumanager/api"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/ilog"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/pciids"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/utils"
)

// PciDevice describes one PCI device discovered under sysfs.
//
// The per-field comments describe how newPciDev populates each field.
type PciDevice struct {
	UUID         string  // v5 UUID set by GenUUID / GenUUIDWithNamespace
	BusID        string  // base name of the device's sysfs directory
	Class        string  // hex ID from sysfs "class", without the "0x" prefix
	Vendor       string  // hex ID from sysfs "vendor", without the "0x" prefix
	Device       string  // hex ID from sysfs "device", without the "0x" prefix
	SubSysVendor string  // hex ID from sysfs "subsystem_vendor", without "0x"
	SubSysDevice string  // hex ID from sysfs "subsystem_device", without "0x"
	ModelName    string  // model name from pciDB, "unknown" when not listed
	MemoryGb     uint32  // memory size from pciDB, 0 when not listed
	Driver       string  // base name of the "driver" symlink target, "" if absent
	NumaNode     int32   // value of sysfs "numa_node"
	CurLinkSpeed float32 // numeric part of "current_link_speed", 0 if absent/unparseable
}

// PciInterface represents interactions with PCI devices.
// NOTE(review): presumably exists so sysfs access can be swapped out
// (e.g. for tests); implementations are not visible here — confirm.
type PciInterface interface {
	// NewPciDevices enumerates PCI devices; presumably backed by
	// newPciDevices — confirm against the implementation.
	NewPciDevices() ([]*PciDevice, error)
	// NewPciDev builds a single PciDevice for devpath; presumably backed by
	// newPciDev (devpath being a sysfs device directory) — confirm.
	NewPciDev(devpath string) (*PciDevice, error)
}

// pciDB maps a lowercase hex vendor ID (e.g. "10de") to its entry in the
// PCI ID database; it is filled once in init from pciids.NewIDs.
var pciDB map[string]pciids.Vendor

// init loads the PCI ID database and panics if any vendor this package
// depends on (NVIDIA "10de", Red Hat/virtio "1af4") is missing from it.
func init() {
	pciDB = pciids.NewIDs()

	// Checked in order; the first missing vendor aborts start-up.
	required := []struct {
		vendorID string
		panicMsg string
	}{
		{vendorID: "10de", panicMsg: "Can not initalize Nvidia device DB"},
		{vendorID: "1af4", panicMsg: "Can not initalize Redhat device DB"},
	}
	for _, r := range required {
		if _, found := pciDB[r.vendorID]; !found {
			panic(r.panicMsg)
		}
	}
}

// GenUUIDWithNamespace sets d.UUID to a name-based (v5) UUID in namespace ns,
// derived from the bus ID and the vendor/device/subsystem IDs.
//
// The ID fields are already hex strings, so the %x verbs hex-encode their
// bytes a second time (e.g. "10de" -> "31306465"). This is intentionally
// left as-is: changing the name format would change every generated UUID.
func (d *PciDevice) GenUUIDWithNamespace(ns uuid.UUID) {
	hwIDs := fmt.Sprintf("%x-%x-%x-%x", d.Vendor, d.Device, d.SubSysVendor, d.SubSysDevice)
	d.UUID = uuid.NewV5(ns, d.BusID+"-"+hwIDs).String()
}

// GenUUID sets d.UUID using the package-level machineUUID as the namespace.
func (d *PciDevice) GenUUID() {
	ns := machineUUID
	d.GenUUIDWithNamespace(ns)
}

// ProtoMarshal converts d into its protobuf PciDeviceSpec form.
// UUID, BusID and Class are not copied into the spec.
func (d *PciDevice) ProtoMarshal() *pb.PciDeviceSpec {
	spec := new(pb.PciDeviceSpec)

	// Hardware identification.
	spec.VendorId = d.Vendor
	spec.DeviceId = d.Device
	spec.SubsysVendorId = d.SubSysVendor
	spec.SubsysDeviceId = d.SubSysDevice

	// Database-derived description.
	spec.ModelName = d.ModelName
	spec.MemorySizeGb = d.MemoryGb

	// Runtime placement and state.
	spec.NumaNode = d.NumaNode
	spec.DriverName = d.Driver
	spec.CurrentLinkSpeed = d.CurLinkSpeed

	return spec
}

// newPciDev builds a PciDevice from the sysfs device directory devpath.
//
// It reads the hex ID attributes (vendor, class, device, subsystem_vendor,
// subsystem_device), numa_node and the optional current_link_speed. It
// resolves the bound driver from the "driver" symlink and looks up the model
// name and memory size in pciDB. A missing or non-numeric current_link_speed
// gives a speed of 0. Any other read failure is returned unchanged.
func newPciDev(devpath string) (*PciDevice, error) {
	var vendor, devClass, device, subSysVendor, subSysDevice string

	// Mandatory ID attributes, read in this fixed order.
	idAttrs := []struct {
		file string
		dst  *string
	}{
		{"vendor", &vendor},
		{"class", &devClass},
		{"device", &device},
		{"subsystem_vendor", &subSysVendor},
		{"subsystem_device", &subSysDevice},
	}
	for _, attr := range idAttrs {
		id, err := sysfsReadID(path.Join(devpath, attr.file))
		if err != nil {
			return nil, err
		}
		*attr.dst = id
	}

	numa, err := sysfsReadInt(path.Join(devpath, "numa_node"))
	if err != nil {
		return nil, err
	}

	linkSpeed, err := sysfsReadFloatWoUnits(path.Join(devpath, "current_link_speed"))
	if err != nil {
		// The file may be missing, or hold a placeholder such as
		// "Unknown speed"; both are reported as a speed of 0.
		if !os.IsNotExist(err) && !errors.Is(err, strconv.ErrSyntax) {
			return nil, err
		}
		linkSpeed = 0
	}

	var driverName string
	driverLink := path.Join(devpath, "driver")
	if _, statErr := os.Stat(driverLink); statErr == nil {
		target, err := os.Readlink(driverLink)
		if err != nil {
			return nil, err
		}
		driverName = path.Base(target)
	}

	modelName, memGb := "unknown", uint32(0)
	if v, ok := pciDB[vendor]; ok {
		if info, ok := v.Devices[device]; ok {
			modelName, memGb = info.Name, info.MemoryGb
		}
	}

	dev := &PciDevice{
		BusID:        path.Base(devpath),
		Class:        devClass,
		Vendor:       vendor,
		Device:       device,
		SubSysVendor: subSysVendor,
		SubSysDevice: subSysDevice,
		NumaNode:     int32(numa),
		MemoryGb:     memGb,
		ModelName:    modelName,
		Driver:       driverName,
		CurLinkSpeed: linkSpeed,
	}
	dev.GenUUID()
	return dev, nil
}

// newPciDevices builds a PciDevice for every entry of basePath that passes
// the filters. Entries are visited in the order ioutil.ReadDir returns them.
//
// If filterVendor is non-empty, only entries whose "vendor" ID equals it are
// kept. If filterClass is non-empty, only entries whose "class" ID is a key
// of filterClass are kept. The first read or build error aborts the scan.
func newPciDevices(basePath string, filterVendor string, filterClass map[string]string) ([]*PciDevice, error) {
	logger := ilog.Log()

	logger.Debug("readdir", zap.String("dirpath", basePath))
	entries, err := ioutil.ReadDir(basePath)
	if err != nil {
		logger.Error("fail to read sysfs path", zap.Error(err))
		return nil, err
	}

	var result []*PciDevice
	for _, entry := range entries {
		name := entry.Name()
		logger.Debug("walk", zap.String("name", name))
		dir := path.Join(basePath, name)

		if filterVendor != "" {
			v, err := sysfsReadID(path.Join(dir, "vendor"))
			if err != nil {
				return nil, err
			}
			if v != filterVendor {
				logger.Debug("skip", zap.String("name", name), zap.String("vendor", v))
				continue
			}
		}

		if len(filterClass) > 0 {
			c, err := sysfsReadID(path.Join(dir, "class"))
			if err != nil {
				return nil, err
			}
			if _, wanted := filterClass[c]; !wanted {
				logger.Debug("skip", zap.String("name", name), zap.String("class", c))
				continue
			}
		}

		dev, err := newPciDev(dir)
		if err != nil {
			return nil, err
		}
		result = append(result, dev)
	}
	return result, nil
}

// sysfsRead returns the first line of the file at p with surrounding
// whitespace trimmed. Read failures are logged and returned unchanged.
func sysfsRead(p string) (string, error) {
	raw, err := ioutil.ReadFile(p)
	if err != nil {
		ilog.Log().Error("Could not read", zap.String("path", p), zap.Error(err))
		return "", err
	}

	// Keep only the text before the first newline.
	line := string(raw)
	if nl := strings.IndexByte(line, '\n'); nl >= 0 {
		line = line[:nl]
	}

	ilog.Log().Debug("read sysfs", zap.String("path", p), zap.String("data", line))
	return strings.TrimSpace(line), nil
}

// sysfsReadID reads a "0x"-prefixed hex ID from the sysfs file at p and
// returns it without the prefix (e.g. "0x10de" -> "10de"). A value without
// the prefix is reported as a parse error.
func sysfsReadID(p string) (string, error) {
	raw, err := sysfsRead(p)
	if err != nil {
		return "", err
	}
	id := strings.TrimPrefix(raw, "0x")
	if id == raw {
		return "", fmt.Errorf("fail to parse ID, path:%s idStr:%s", p, raw)
	}
	return id, nil
}

// sysfsReadUint64 reads the sysfs file at p and parses it as a base-10
// unsigned 64-bit integer.
func sysfsReadUint64(p string) (value uint64, err error) {
	var text string
	if text, err = sysfsRead(p); err != nil {
		return 0, err
	}
	return strconv.ParseUint(text, 10, 64)
}

// sysfsReadInt reads the sysfs file at p and parses it as a base-10 int
// (negative values such as -1 are accepted).
func sysfsReadInt(p string) (value int, err error) {
	var text string
	if text, err = sysfsRead(p); err != nil {
		return 0, err
	}
	return strconv.Atoi(text)
}

// sysfsReadWithoutUnits reads the sysfs file at p and returns its first
// whitespace-separated token, dropping trailing unit text
// (e.g. "2.5 GT/s" -> "2.5", "Unknown speed" -> "Unknown").
//
// An empty or whitespace-only file yields "" with a nil error instead of
// panicking on an out-of-range index. Callers that parse the token with
// strconv then get a syntax error they can handle.
func sysfsReadWithoutUnits(p string) (string, error) {
	s, err := sysfsRead(p)
	if err != nil {
		return "", err
	}
	fields := strings.Fields(s)
	if len(fields) == 0 {
		return "", nil
	}
	return fields[0], nil
}

// sysfsReadFloatWoUnits reads a "<number> <unit>" sysfs value at p and parses
// the number as a 32-bit float. On a parse failure the strconv error is
// returned together with whatever value strconv.ParseFloat produced.
func sysfsReadFloatWoUnits(p string) (float32, error) {
	token, readErr := sysfsReadWithoutUnits(p)
	if readErr != nil {
		return 0, readErr
	}

	var parsed float64
	var parseErr error
	parsed, parseErr = strconv.ParseFloat(token, 32)
	return float32(parsed), parseErr
}

// doSetDriver rebinds the IOMMU group containing d to driver. It runs
// Override, Unbind and Probe in that order, stopping at the first error.
// On success it returns the group.
// NOTE(review): ctx is currently unused here.
func (d *PciDevice) doSetDriver(ctx context.Context, driver string) (*IOMMUGroup, error) {
	group, err := NewIOMMUGroup(d.BusID)
	if err != nil {
		return nil, err
	}

	steps := []func() error{
		func() error { return group.Override(driver) },
		func() error { return group.Unbind(driver) },
		func() error { return group.Probe() },
	}
	for _, step := range steps {
		if stepErr := step(); stepErr != nil {
			return nil, stepErr
		}
	}
	return group, nil
}

// SetDriver rebinds d's IOMMU group to driver (see doSetDriver) inside an
// OpenTracing span named "pcidev.SetDriver/<driver>".
//
// The span is started from ctx, so it becomes a child of any span the caller
// already carries rather than an unrelated root span. The derived context
// is passed on to doSetDriver. A nil ctx is treated as context.Background().
func (d *PciDevice) SetDriver(ctx context.Context, driver string) (*IOMMUGroup, error) {
	if ctx == nil {
		// StartSpanFromContext calls ctx.Value and would panic on nil.
		ctx = context.Background()
	}
	sp, ctx := opentracing.StartSpanFromContext(ctx, "pcidev.SetDriver/"+driver)
	sp.SetTag("component", "nvgpu-manager")
	sp.SetTag("span.kind", "server")

	// NOTE(review): sp.Finish() is not called in this function; presumably
	// utils.SpanCheckError finishes the span — confirm, otherwise add
	// `defer sp.Finish()` right after the span is created.
	grp, err := d.doSetDriver(ctx, driver)
	utils.SpanCheckError(sp, err)
	return grp, err
}
