package device

import (
	"fmt"
	"os"
	"path"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"a.yandex-team.ru/infra/rsm/nvgpumanager/pkg/modprobe"
)

// init loads the kernel modules that the IOMMU tests bind devices to
// (pci-stub and vfio-pci). The tests cannot run without them, so a load
// failure aborts the test binary, naming the module that could not be loaded.
func init() {
	for _, mod := range []string{"pci-stub", "vfio-pci"} {
		if err := modprobe.LoadModule(mod); err != nil {
			panic(fmt.Sprintf("modprobe.LoadModule(%q) fail, err: %v", mod, err))
		}
	}
}

// TestIOMMUFeatureProbe verifies that IOMMUFeatureProbe reports no error on
// the host running the tests, i.e. the environment provides the IOMMU group
// feature the rest of this file relies on.
func TestIOMMUFeatureProbe(t *testing.T) {
	probeErr := IOMMUFeatureProbe()
	assert.NoError(t, probeErr, "Test env must have iommu_group feature enabled")
}

func TestIOMMUGroupCreate(t *testing.T) {
	busID := "0000:00:05.0"

	grp, err := NewIOMMUGroup(busID)
	require.NoError(t, err)
	require.Equal(t, &IOMMUGroup{ID: 5}, grp)
}

// TestIOMMUBind moves the IOMMU group of the test PCI device onto vfio-pci,
// checks via sysfs that the device is now driven by vfio-pci, and finally
// hands the group back to pci-stub.
//
// The hand-back is deferred so it also runs when a require assertion aborts
// the test midway (FailNow runs deferred calls); otherwise a failed run would
// leave the device bound to vfio-pci for whatever runs next on the host.
func TestIOMMUBind(t *testing.T) {
	const (
		busID          = "0000:00:05.0"
		targetDriver   = "vfio-pci"
		originalDriver = "pci-stub" // module the device falls back to after the test
	)

	grp, err := NewIOMMUGroup(busID)
	require.NoError(t, err)
	require.Equal(t, &IOMMUGroup{ID: 5}, grp)

	// rebind runs the override -> unbind -> probe sequence that moves the
	// group onto drv, stopping at the first step that fails.
	rebind := func(drv string) error {
		if err := grp.Override(drv); err != nil {
			return fmt.Errorf("override to %s: %w", drv, err)
		}
		if err := grp.Unbind(drv); err != nil {
			return fmt.Errorf("unbind for %s: %w", drv, err)
		}
		if err := grp.Probe(); err != nil {
			return fmt.Errorf("probe after switching to %s: %w", drv, err)
		}
		return nil
	}

	// Registered before the first rebind so a partially applied switch is
	// also rolled back. assert (not require) is used because this runs
	// during test teardown.
	defer func() {
		assert.NoError(t, rebind(originalDriver), "restoring original driver")
	}()

	require.NoError(t, rebind(targetDriver))

	// The sysfs "driver" entry is a symlink; its last path element is the
	// name of the driver currently bound to the device.
	drvLink := fmt.Sprintf("/sys/bus/pci/devices/%s/driver", busID)
	drvPath, err := os.Readlink(drvLink)
	require.NoErrorf(t, err, "reading %s", drvLink)
	assert.Equal(t, targetDriver, path.Base(drvPath))
}
