package modprobe_test

import (
	"os"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"a.yandex-team.ru/infra/rsm/nvgpumanager/pkg/modprobe"
)

// TestPciLoadUnload checks that the pci-stub module can be loaded with
// modprobe.LoadModule and removed again with modprobe.UnloadModule, and that
// modprobe.IsModuleLoaded reflects each transition.
//
// NOTE(review): this mutates host kernel state, so it presumably needs root
// and a kernel that provides pci-stub as a loadable module — confirm in CI.
func TestPciLoadUnload(t *testing.T) {
	const m = "pci-stub"

	// require.* aborts the test on failure; without this guard a failed
	// assertion after a successful load would leave the module loaded on the
	// host for every test that runs afterwards.
	defer func() {
		if modprobe.IsModuleLoaded(m) {
			_ = modprobe.UnloadModule(m) // best-effort cleanup; outcome already asserted below
		}
	}()

	err := modprobe.LoadModule(m)
	require.NoError(t, err)

	require.Truef(t, modprobe.IsModuleLoaded(m),
		"Module should be loaded, lsmod:%s",
		strings.Join(modprobe.LoadedModules(), " "))

	err = modprobe.UnloadModule(m)
	require.NoError(t, err)

	require.Falsef(t, modprobe.IsModuleLoaded(m),
		"Module should not be loaded, lsmod:%s",
		strings.Join(modprobe.LoadedModules(), " "))
}

// TestModBlacklist checks that a module blacklisted through
// modprobe.BlacklistModule cannot be loaded, and that it becomes loadable
// again once the generated config file (named after the test) is removed.
//
// NOTE(review): this mutates host kernel state and writes a modprobe config,
// so it presumably needs root — confirm in CI.
func TestModBlacklist(t *testing.T) {
	const m = "pci-stub"
	cfgName := modprobe.GenConfigPath(t.Name())

	// Register cleanups before anything can fail. Deferred calls run LIFO:
	// the module is unloaded first, then the config is removed. Removing an
	// already deleted config only yields an ignored error.
	defer os.Remove(cfgName)
	defer func() {
		if modprobe.IsModuleLoaded(m) {
			_ = modprobe.UnloadModule(m) // best-effort; never leave the module loaded
		}
	}()

	err := modprobe.BlacklistModule(m, cfgName)
	require.NoError(t, err)
	require.FileExists(t, cfgName)

	err = modprobe.LoadModule(m)
	assert.Errorf(t, err, "LoadModule should fail because module is blacklisted")
	assert.Falsef(t, modprobe.IsModuleLoaded(m),
		"Module should not be loaded, lsmod:%s",
		strings.Join(modprobe.LoadedModules(), " "))

	// Every step below relies on the blacklist being gone, so stop here if
	// the config could not be removed.
	err = os.Remove(cfgName)
	require.NoError(t, err)

	err = modprobe.LoadModule(m)
	assert.NoErrorf(t, err, "LoadModule should succeed, because module no longer in blacklist")
	assert.Truef(t, modprobe.IsModuleLoaded(m),
		"Module should be loaded, lsmod:%s",
		strings.Join(modprobe.LoadedModules(), " "))

	err = modprobe.UnloadModule(m)
	assert.NoError(t, err)
}
