package device

import (
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"path"
	"strconv"

	"go.uber.org/zap"

	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/ilog"
)

// IOMMUGroup identifies a kernel IOMMU group by its numeric ID, i.e. the
// directory name under /sys/kernel/iommu_groups.
type IOMMUGroup struct {
	// ID is the numeric group identifier; it also names the group's VFIO
	// device node /dev/vfio/<ID>.
	ID int
}

// IOMMUFeatureProbe checks that the host kernel supports the IOMMU API.
//
// It returns an error when /sys/kernel/iommu_groups cannot be read, which is
// treated as the kernel lacking CONFIG_IOMMU_API. The underlying read error
// is wrapped so that the real cause (e.g. a permission problem) is not lost.
// It also returns an error when the directory exists but lists no groups,
// which is treated as the IOMMU group feature not being enabled at boot.
func IOMMUFeatureProbe() error {
	groups, err := ioutil.ReadDir("/sys/kernel/iommu_groups")
	if err != nil {
		return fmt.Errorf("kernel has no CONFIG_IOMMU_API support: %w", err)
	}
	if len(groups) == 0 {
		return errors.New("kernel iommu group feature was not enabled on boot")
	}
	return nil
}

// NewIOMMUGroup returns the IOMMU group of the PCI device with the given bus ID.
//
// The group is resolved from the /sys/bus/pci/devices/<busID>/iommu_group
// symlink, whose final path element must parse as a 32-bit decimal group ID.
// A Readlink failure is returned unwrapped so callers can still inspect it
// with os.IsNotExist and friends. A parse failure is wrapped with the bus ID.
func NewIOMMUGroup(busID string) (*IOMMUGroup, error) {
	ll := ilog.Log()
	fname := fmt.Sprintf("/sys/bus/pci/devices/%s/iommu_group", busID)
	gPath, err := os.Readlink(fname)
	if err != nil {
		ll.Error("Fail to read iommu group", zap.String("bus_id", busID), zap.Error(err))
		return nil, err
	}

	gname := path.Base(gPath)
	gid, err := strconv.ParseInt(gname, 10, 32)
	if err != nil {
		ll.Error("Fail to parse iommu group", zap.String("bus_id", busID),
			zap.String("link", gPath), zap.Error(err))
		return nil, fmt.Errorf("parse iommu group of device %s: %w", busID, err)
	}

	return &IOMMUGroup{ID: int(gid)}, nil
}

// IOMMUGetCtlDevices returns the device nodes required for VFIO device
// management, currently only the VFIO container node /dev/vfio/vfio.
// A fresh slice is built on every call, so callers may modify the result.
func IOMMUGetCtlDevices() []string {
	return []string{
		"/dev/vfio/vfio",
	}
}

// walkFunc is invoked by IOMMUGroup.Walk with the PCI bus ID of one member
// device; returning a non-nil error stops the walk.
type walkFunc func(busID string) error

// Walk calls walkFn for each member device of the group, using the entry
// names of /sys/kernel/iommu_groups/<ID>/devices as bus IDs (in the sorted
// order ioutil.ReadDir returns them). It returns the directory read error if
// the group cannot be listed. Otherwise it stops at, and returns, the first
// error produced by walkFn.
func (g *IOMMUGroup) Walk(walkFn walkFunc) error {
	devices, err := ioutil.ReadDir(fmt.Sprintf("/sys/kernel/iommu_groups/%d/devices", g.ID))
	if err != nil {
		return err
	}

	for _, dev := range devices {
		if err := walkFn(dev.Name()); err != nil {
			return err
		}
	}
	return nil
}

// Probe triggers driver probing for every device in the IOMMU group by
// writing each device's bus ID to /sys/bus/pci/drivers_probe.
//
// Each attempt is logged. The walk stops at the first failed write (or a
// failure to list the group's devices), and that error is returned.
func (g *IOMMUGroup) Probe() error {
	ll := ilog.Log()
	ll.Info("Probing all devices in IOMMU group", zap.Int("group_id", g.ID))

	err := g.Walk(func(busID string) error {
		// An empty root makes probe target the host's live sysfs.
		err := probe("", busID)
		ll.Info("Probe device", zap.String("bus_id", busID), zap.Error(err))
		return err
	})
	if err != nil {
		ll.Info("Probing failed", zap.Int("group_id", g.ID), zap.Error(err))
		return err
	}
	return nil
}

// Override writes driver to the driver_override sysfs attribute of every
// device in the IOMMU group. This restricts which driver may bind to those
// devices when they are next probed.
//
// Setting the override does not by itself unbind the current driver or bind
// the new one; callers presumably combine this with Unbind and Probe. Each
// attempt is logged. The walk stops at the first failed write (or a failure
// to list the group's devices), and that error is returned.
func (g *IOMMUGroup) Override(driver string) error {
	ll := ilog.Log()
	ll.Info("Overriding all devices from IOMMU group", zap.Int("group_id", g.ID), zap.String("driver", driver))

	err := g.Walk(func(busID string) error {
		err := driverOverride(busID, driver)
		ll.Info("Overriding device driver", zap.String("bus_id", busID),
			zap.String("driver", driver), zap.Error(err))
		return err
	})
	if err != nil {
		ll.Info("Overriding failed", zap.Int("group_id", g.ID), zap.Error(err))
		return err
	}
	return nil
}

// Unbind detaches every device in the IOMMU group from its current driver by
// writing the device's bus ID to /sys/bus/pci/devices/<busID>/driver/unbind.
//
// Devices whose driver symlink cannot be read are treated as already unbound
// and skipped. If driver is non-empty, devices already bound to a driver of
// that name are left alone. Failures of the unbind write itself are logged
// and deliberately ignored, because another actor may be unbinding the same
// device concurrently. Consequently the only error returned is a failure to
// list the group's devices.
func (g *IOMMUGroup) Unbind(driver string) error {
	ll := ilog.Log()
	ll.Info("Unbinding all devices in IOMMU", zap.Int("group_id", g.ID))

	err := g.Walk(func(busID string) error {
		oldDriverPath, err := os.Readlink(fmt.Sprintf("/sys/bus/pci/devices/%s/driver", busID))
		if err != nil {
			// No driver link means the device is already unbound, not a failure.
			ll.Info("Device has no driver", zap.String("bus_id", busID), zap.Error(err))
			return nil
		}

		origDriverName := path.Base(oldDriverPath)
		if driver != "" && driver == origDriverName {
			ll.Info("Device already has target driver", zap.String("bus_id", busID), zap.String("current driver", origDriverName))
			return nil
		}

		fname := fmt.Sprintf("/sys/bus/pci/devices/%s/driver/unbind", busID)
		err = safeWrite(fname, []byte(busID), 0400)
		ll.Info("Unbound device", zap.String("bus_id", busID),
			zap.String("old_driver", origDriverName), zap.Error(err))
		// The error is intentionally dropped: it can be a benign race with a
		// concurrent unbind, so it must not abort the walk.
		return nil
	})
	if err != nil {
		ll.Info("Unbound failed", zap.Int("group_id", g.ID), zap.Error(err))
		return err
	}

	return nil
}

// VFioPath returns the path of the group's VFIO device node, /dev/vfio/<ID>.
func (g *IOMMUGroup) VFioPath() string {
	return "/dev/vfio/" + strconv.Itoa(g.ID)
}

// String implements fmt.Stringer, rendering the group as "vfio<ID>".
func (g *IOMMUGroup) String() string {
	return "vfio" + strconv.Itoa(g.ID)
}

// driverOverride writes driver to the driver_override sysfs attribute of the
// PCI device busID. The override restricts which driver may bind to the
// device on subsequent probes. It does not itself unbind the current driver
// or bind the new one, and it persists until it is overwritten.
func driverOverride(busID string, driver string) error {
	fname := fmt.Sprintf("/sys/bus/pci/devices/%s/driver_override", busID)

	return safeWrite(fname, []byte(driver), 0400)
}

// probe probes the device driver.
func probe(root string, busID string) error {
	fname := fmt.Sprintf("%s/sys/bus/pci/drivers_probe", root)
	return safeWrite(fname, []byte(busID), 0400)
}

// safeWrite writes data to an already existing file; in this file the
// targets are sysfs attributes. Unlike ioutil.WriteFile it never creates or
// truncates the file, and it opens it with O_SYNC. Because O_CREATE is not
// set, perm has no effect.
//
// A short write is reported as io.ErrShortWrite. The file is always closed,
// but a Close error is returned only when the write itself succeeded. Once
// the file was opened, the outcome is logged at debug level.
func safeWrite(filename string, data []byte, perm os.FileMode) error {
	file, openErr := os.OpenFile(filename, os.O_WRONLY|os.O_SYNC, perm)
	if openErr != nil {
		return openErr
	}

	written, writeErr := file.Write(data)
	if writeErr == nil && written < len(data) {
		writeErr = io.ErrShortWrite
	}

	closeErr := file.Close()
	if writeErr == nil {
		writeErr = closeErr
	}

	ilog.Log().Debug("safeWrite", zap.String("path", filename), zap.ByteString("data", data), zap.Error(writeErr))
	return writeErr
}
