package server

import (
	"bytes"
	"context"
	"fmt"
	"net"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/gofrs/uuid"
	"github.com/golang/protobuf/proto"
	opentracing "github.com/opentracing/opentracing-go"
	"go.uber.org/zap"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	pb "a.yandex-team.ru/infra/rsm/nvgpumanager/api"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/client"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/config"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/device"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/ilog"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/internal/utils"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/pkg/juggler"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/pkg/modprobe"
	"a.yandex-team.ru/infra/rsm/nvgpumanager/pkg/yasm"
	"a.yandex-team.ru/library/go/core/buildinfo"
	"a.yandex-team.ru/library/go/core/xerrors"
)

// NVGPUManager implements the NvGpuManager gRPC service. It tracks the host's
// NVIDIA PCI devices (driven by the host driver through NVML, or bound to
// vfio-pci), manages helper daemons, allocates devices to containers and
// reports status to juggler and metrics to yasm.
type NVGPUManager struct {
	// Identity of this service instance, reported by Ping and in logs.
	id                   string
	pid                  int
	startTime            time.Time
	// gRPC server and the listener it serves on.
	grpcServer           *grpc.Server
	grpcListener         *net.Listener
	log                  *zap.Logger
	config               *config.Configuration
	// Symlink target of config.CUDARoot; empty when it could not be resolved.
	cudaRoot             string
	// Closed by Stop to terminate the background loops.
	stop                 chan interface{}
	// Juggler requests produced by cacheEventLoop and consumed by the juggler
	// client's push loop; closed by Stop.
	jugglerQueue         chan juggler.JugglerRequest
	yasmClient           *yasm.YasmClient
	pciProvider          device.PciInterface
	// nvmlMode: NVML requested by configuration.
	// nvmlEnabled: NVML initialized successfully (reset while switching drivers).
	nvmlMode             bool
	nvmlEnabled          bool
	nvmlProvider         device.NvmlInterface
	// vfioMode: VFIO requested by configuration (False/Optional/True).
	// vfioEnabled: VFIO API usable. vfioAPIErr: reason it is not usable.
	vfioMode             config.Tristate
	vfioEnabled          bool
	vfioAPIErr           error
	// Reply prepared at construction time; Ping returns a clone of it.
	pingReply            pb.PingResponse
	// mux guards the device caches and counters below and the init state.
	// The *Unlocked helpers expect it to be held by the caller.
	mux                  sync.Mutex
	// NOTE(review): not referenced in this file.
	cacheValid           pb.Condition
	// Device counters refreshed by unlockedUpdateCache.
	totalDevices         int
	errorDevices         int
	unknownDevices       int
	// nvmlDevices is keyed by PCI UUID, or by the NVML UUID for MIG devices.
	// vfioDevices is keyed by PCI UUID.
	nvmlDevices          map[string]*device.NvmlDevice
	vfioDevices          map[string]*device.VFioPciDevice
	// Helper daemons. The *Enabled flags are set only when the daemon was
	// started successfully by this instance.
	persistencedService  *utils.Service
	fabricmanagerService *utils.Service
	fabricmanagerEnabled bool
	dcgmService          *utils.Service
	dcgmEnabled          bool
	dcgmProvider         device.DcgmInterface
	// Set once the DCGM provider is initialized and has updated its devices.
	dcgmProviderInited   bool
	// Periodically launched hung test, polled by cacheEventLoop.
	periodicHungTest     *utils.PeriodicTaskInfo
}

// NewNVGPUManager creates the GPU manager service described by c.
//
// Construction steps, in order:
//  1. run "udevadm settle -t 60" so GPU device nodes are present; a failure is
//     logged and ignored;
//  2. resolve c.CUDARoot if it is a symlink (s.cudaRoot stays empty otherwise);
//  3. pick real or mocked PCI/NVML/DCGM providers (c.MockDevices);
//  4. start nvidia-persistenced, nvidia-fabricmanager (A100-family GPUs only)
//     and nvidia-dcgm as configured, then initialize NVML/DCGM and VFIO;
//  5. open the gRPC listener: the socket-activated fd 3 (LISTEN_FDS) when
//     c.EnableSocketActivation is set, otherwise a unix socket at
//     c.GrpcEndpoint (a stale socket nobody answers on is removed first).
//
// Helper-service and device-API failures are logged and tolerated. Only
// endpoint/listener problems are returned as errors. c.GrpcEndpoint is
// rewritten in place when it carries a "unix:" prefix.
func NewNVGPUManager(log *zap.Logger, c *config.Configuration) (*NVGPUManager, error) {
	log.Info("Starting service with", zap.Any("configuration", c))

	s := &NVGPUManager{
		id:           uuid.Must(uuid.NewV4()).String(),
		log:          log.With(zap.String("service", "nvgpumgr")),
		config:       c,
		startTime:    time.Now(),
		pid:          os.Getpid(),
		stop:         make(chan interface{}),
		jugglerQueue: make(chan juggler.JugglerRequest, 10),
		yasmClient:   yasm.NewYasmClient(),
		nvmlMode:     c.NvmlMode,
		vfioMode:     c.VfioMode,
		persistencedService: utils.NewService(config.NvidiaPersistencedServiceName, config.NvidiaPersistencedDefaultArgs, config.NvidiaPersistencedPidFilePath,
			utils.DefaultServiceRestartsNrLimit, utils.DefaultServiceRestartsExpiration),
		fabricmanagerService: utils.NewService(config.NvidiaFabricmanagerServiceName, getNvidiaFabricmanagerArgs(log), config.NvidiaFabricmanagerPidFilePath,
			utils.DefaultServiceRestartsNrLimit, utils.DefaultServiceRestartsExpiration),
		dcgmService: utils.NewService(config.NvidiaDCGMServiceName, config.NvidiaDCGMDefaultArgs, config.NvidiaDCGMPidFilePath,
			utils.DefaultServiceRestartsNrLimit, utils.DefaultServiceRestartsExpiration),
		periodicHungTest: utils.NewPeriodicTaskInfo(config.PeriodicHungTestBinName, config.PeriodicHungTestBinDefaultArgs, config.PeriodicHungTestPeriod, config.PeriodicHungTestTimeout, []string{}),
	}

	// run 'udevadm settle' to wait until all gpu devices are ready
	settleCmd := exec.Command("udevadm", "settle", "-t", "60")
	settleStdout, settleStderr := new(bytes.Buffer), new(bytes.Buffer)
	settleCmd.Stdout = settleStdout
	settleCmd.Stderr = settleStderr
	ilog.Log().Info("started waiting for", zap.String("command", settleCmd.String()))
	err := settleCmd.Run()
	if err != nil {
		ilog.Log().Error("command failed",
			zap.String("command", settleCmd.String()),
			zap.Error(err),
			zap.ByteString("stdout", settleStdout.Bytes()),
			zap.ByteString("stderr", settleStderr.Bytes()))
		ilog.Log().Error("try to move on anyway")
	} else {
		ilog.Log().Info("successfully finished", zap.String("command", settleCmd.String()))
	}

	// Evaluate external resources: only a symlinked CUDA root is resolved.
	// NOTE(review): a plain directory at c.CUDARoot is ignored; confirm this is intended.
	if info, lerr := os.Lstat(c.CUDARoot); lerr == nil && info.Mode()&os.ModeSymlink != 0 {
		if p, perr := filepath.EvalSymlinks(c.CUDARoot); perr == nil {
			s.cudaRoot = p
		}
	}
	if s.cudaRoot == "" {
		s.log.Warn("driver root path not found, ignore it", zap.String("cuda_root", c.CUDARoot))
	}

	// Setup Device API
	if c.MockDevices != "disabled" {
		s.log.Info("Server operate in mocked devices mode")
		device.MockHWNamespace()
		s.pciProvider = &device.PciLibQemuMock{}
		s.nvmlProvider = device.NewNvmlQemuMock(s.pciProvider, c.MockDevices)

		s.dcgmProvider = device.NewDcgmMock(c.MockDevices)
		s.dcgmProviderInited = s.dcgmProvider != nil
	} else {
		s.pciProvider = &device.PciLibNvidia{}
		s.nvmlProvider = &device.NvmlLib{}
		s.dcgmProvider = &device.DcgmLib{}
	}

	if s.config.NvidiaPersistenced {
		if err := s.persistencedService.Start(); err != nil {
			s.log.Error("failed to start nvidia-persistenced service", zap.Error(err))
		} else {
			s.log.Info("started nvidia-persistenced service")
		}
	}

	// for now there is no difference in Optional and True as upon errors
	// we are trying to live on and log error instead of crashing
	if s.config.NvidiaFabricmanager != config.False {
		// Fabricmanager is only started for GPUs of the A100 family.
		a100Models := map[string]bool{"gpu_tesla_a100": true, "gpu_tesla_a100_80g": true}
		dl, err := s.pciProvider.NewPciDevices()
		switch {
		case err != nil:
			s.log.Error("failed to get pci device list", zap.Error(err))
		case len(dl) == 0:
			// No device to inspect; dl[0] must not be touched on this path.
			s.log.Error("no nvidia pci device found")
		case !a100Models[dl[0].ModelName]:
			s.log.Info("not starting nvidia-fabricmanager service for detected GPU model", zap.String("ModelName", dl[0].ModelName))
		default:
			if err := s.fabricmanagerService.Start(); err != nil {
				s.log.Error("failed to start nvidia-fabricmanager service", zap.Error(err))
			} else {
				s.log.Info("started nvidia-fabricmanager service")
				s.fabricmanagerEnabled = true
			}
		}
	}

	if s.config.NvidiaDCGM != config.False {
		if err := s.dcgmService.Start(); err != nil {
			s.log.Error("failed to start nvidia-dcgm service", zap.Error(err))
		} else {
			s.log.Info("started nvidia-dcgm service")
			s.dcgmEnabled = true
		}
	}

	// Init failures are logged by the helpers and retried later by checkInit.
	if s.nvmlMode {
		_ = s.nvmlModInit()
		if s.dcgmEnabled && s.dcgmProvider != nil {
			_ = s.dcgmProviderInit()
		}
	}

	if s.vfioMode != config.False {
		_ = s.vfioModInit()
	}

	if strings.ContainsRune(c.GrpcEndpoint, ':') {
		ap := strings.Split(c.GrpcEndpoint, ":")
		if len(ap) != 2 {
			return nil, xerrors.Errorf("Bad address pair: %v", ap)
		}
		if ap[0] != "unix" {
			return nil, xerrors.Errorf("Unsupported proto: %v", ap[0])
		}
		c.GrpcEndpoint = ap[1]
	}
	if c.GrpcEndpoint == "" {
		return nil, xerrors.Errorf("Empty grpc endpoint")
	}

	var lis net.Listener
	if c.EnableSocketActivation {
		// LISTEN_PID is deliberately not compared with our own pid: it holds
		// the pid of the wrapper process that launched us, so the check fails.
		envFDs := os.Getenv("LISTEN_FDS")
		fds, err := strconv.Atoi(envFDs)
		if err != nil {
			return nil, xerrors.Errorf("Atoi(%s) failed: %w", envFDs, err)
		}
		if fds != 1 {
			return nil, xerrors.Errorf("socket activation failed: %s=%d is invalid, should be 1", "LISTEN_FDS", fds)
		}

		fileName := os.Getenv("LISTEN_FDNAMES")
		if fileName == "" {
			fileName = "LISTEN_FD_3"
		}
		// The passed socket is fd 3. net.FileListener dups the descriptor, so
		// the *os.File wrapper is closed right away instead of being leaked.
		lisFile := os.NewFile(3, fileName)
		lis, err = net.FileListener(lisFile)
		_ = lisFile.Close()
		if err != nil {
			return nil, err
		}
	} else {
		lis, err = net.Listen("unix", c.GrpcEndpoint)
		if err != nil {
			// Stale socket suspected
			if !xerrors.Is(err, syscall.EADDRINUSE) {
				return nil, err
			}
			if err = staleSocketCleanup(s.log, c.GrpcEndpoint); err != nil {
				return nil, err
			}
			lis, err = net.Listen("unix", c.GrpcEndpoint)
			if err != nil {
				return nil, err
			}
		}
		if c.Secure {
			// Close the listener on failure so the socket is not leaked.
			if err := os.Chown(c.GrpcEndpoint, config.DefaultSockOwner, config.DefaultSockGroup); err != nil {
				_ = lis.Close()
				return nil, err
			}
			if err := os.Chmod(c.GrpcEndpoint, config.DefaultSockPermissions); err != nil {
				_ = lis.Close()
				return nil, err
			}
		}
	}

	s.grpcListener = &lis
	s.grpcServer = NewGRPCServer()

	// Dummy span to announce service existence
	sp := opentracing.StartSpan("NVGPUManager.Init")
	sp.SetTag("component", "nvgpu-manager")
	sp.SetTag("span.kind", "server")
	defer sp.Finish()

	s.pingReply = pb.PingResponse{
		Id:          s.id,
		StartTs:     utils.TimeFormatPB(s.startTime),
		HostId:      device.HWNamespace().String(),
		VersionInfo: buildinfo.Info.ProgramVersion,
	}
	s.log.Info("NVGPUManager started",
		zap.Int("pid", s.pid),
		zap.String("id", s.id),
		zap.String("host_id", device.HWNamespace().String()),
		zap.String("version", buildinfo.Info.ProgramVersion),
		zap.String("cuda_root", s.cudaRoot))

	pb.RegisterNvGpuManagerServer(s.grpcServer, s)
	return s, nil
}

// Serve primes the device cache and, when vfio mode is True and
// ForceVfioInit is set, rebinds every VFIO-tracked device to vfio-pci. It then
// starts the background loops (juggler push, yasm metrics, cache refresh) and
// blocks serving gRPC on the configured listener.
//
// Errors from the initial cache build or the forced driver switch are
// returned before any background goroutine is started.
func (s *NVGPUManager) Serve(ctx context.Context) error {
	// Without an initial device cache there is nothing meaningful to serve.
	if err := s.doUpdateCache(); err != nil {
		return err
	}

	forceVfio := s.vfioMode == config.True && s.config.ForceVfioInit
	if forceVfio {
		pciDevs := []*device.PciDevice{}
		for _, vdev := range s.vfioDevices {
			pciDevs = append(pciDevs, vdev.PciDev)
		}

		s.log.Info("Try to force init all devices to vfio")
		if err := s.switchDriver(ctx, "vfio-pci", pciDevs); err != nil {
			return err
		}
	}

	jc := juggler.NewJugglerClient()
	go jc.PushEventLoop(ctx, s.log, s.jugglerQueue)
	go s.pushMetricsLoop(ctx)
	go s.cacheEventLoop()

	return s.grpcServer.Serve(*s.grpcListener)
}

// Stop shuts the service down:
//   - it stops the gRPC server;
//   - it stops the helper daemons this instance started (nvidia-persistenced,
//     nvidia-fabricmanager, and nvidia-dcgm together with the DCGM provider);
//   - it closes s.stop and s.jugglerQueue to end the background loops.
//
// reason is only logged.
//
// Calls after the first one are no-ops, so the channels are never closed
// twice and the (then nil) gRPC server is never dereferenced.
func (s *NVGPUManager) Stop(reason string) {
	if s.grpcServer == nil {
		// Already stopped.
		return
	}
	s.grpcServer.Stop()
	s.grpcServer = nil

	// The flags torn down below are read under s.mux by checkInitUnlocked and
	// switchDriver. Closing the channels under the same lock also lets a
	// sender that checks s.stop while holding s.mux avoid sending on a
	// closed s.jugglerQueue.
	s.mux.Lock()
	defer s.mux.Unlock()

	if s.config.NvidiaPersistenced {
		if err := s.persistencedService.Stop(); err != nil {
			s.log.Error("failed to stop nvidia-persistenced service", zap.Error(err))
		}
	}

	if s.fabricmanagerEnabled {
		if err := s.fabricmanagerService.Stop(); err != nil {
			s.log.Error("failed to stop nvidia-fabricmanager service", zap.Error(err))
		} else {
			s.fabricmanagerEnabled = false
		}
	}

	if s.dcgmEnabled {
		// Shut the provider down before stopping the daemon it talks to.
		if s.dcgmProviderInited {
			if err := s.dcgmProvider.Shutdown(); err != nil {
				s.log.Error("failed to shutdown dcgmProvider", zap.Error(err))
			} else {
				s.dcgmProviderInited = false
			}
		}

		if err := s.dcgmService.Stop(); err != nil {
			s.log.Error("failed to stop nvidia-dcgm service", zap.Error(err))
		} else {
			s.dcgmEnabled = false
		}
	}

	close(s.stop)
	close(s.jugglerQueue)
	s.log.Info("NVGPUManager service stopped",
		zap.Int("pid", s.pid),
		zap.String("id", s.id),
		zap.String("host_id", device.HWNamespace().String()),
		zap.String("version", buildinfo.Info.ProgramVersion),
		zap.String("reason", reason))
}

// unlockedUpdateCache rebuilds the device caches from a fresh PCI scan.
// The caller must hold s.mux.
//
// Every scanned PCI device (deduplicated by bus id) counts toward
// s.totalDevices. Two caches are rebuilt:
//   - s.nvmlDevices, only while NVML is enabled, keyed by PCI UUID (or by the
//     NVML UUID for MIG devices);
//   - s.vfioDevices, keyed by PCI UUID.
//
// Devices present in the previous generation inherit their status via
// TransferStatus; new ones are logged. Not-ready devices, plus the bad-device
// count reported by NewNvmlDevices, make up s.errorDevices.
//
// NOTE(review): s.unknownDevices subtracts map sizes from the PCI count. If a
// MIG-enabled GPU yields several nvmlDevices entries, the value can become
// negative; verify.
//
// On error the function returns early and the caches may be partially updated.
func (s *NVGPUManager) unlockedUpdateCache() error {
	pciDevs, err := s.pciProvider.NewPciDevices()
	if err != nil {
		return err
	}

	byBusID := make(map[string]*device.PciDevice)
	for _, pd := range pciDevs {
		byBusID[pd.BusID] = pd
	}
	s.totalDevices = len(byBusID)

	faulty := 0
	freshNvml := make(map[string]*device.NvmlDevice)
	if s.nvmlEnabled {
		// DCGM is handed to the NVML layer only once it is fully initialized.
		var dcgm device.DcgmInterface
		if s.dcgmProviderInited {
			dcgm = s.dcgmProvider
		}

		found, broken, err := device.NewNvmlDevices(s.nvmlProvider, dcgm, byBusID, s.config)
		if err != nil {
			return err
		}
		for _, nd := range found {
			if !nd.Ready.Status {
				faulty++
			}

			key := nd.Device.GetUUID()
			if !nd.Device.IsMigDevice() {
				key = nd.PciDev.UUID
			}

			freshNvml[key] = nd
			if prev, known := s.nvmlDevices[key]; known {
				nd.TransferStatus(prev)
			} else {
				s.log.Info("New nvgpu discovered", zap.Any("dev", nd))
			}
		}
		faulty += broken
	}
	s.nvmlDevices = freshNvml

	vfioList, err := device.NewVFioPciDevices(byBusID)
	if err != nil {
		return err
	}
	freshVfio := make(map[string]*device.VFioPciDevice)
	for _, vd := range vfioList {
		if !vd.Ready.Status {
			faulty++
		}

		key := vd.PciDev.UUID
		freshVfio[key] = vd
		if prev, known := s.vfioDevices[key]; known {
			vd.TransferStatus(prev)
		} else {
			s.log.Info("New vfio discovered", zap.Any("dev", vd))
		}
	}
	s.vfioDevices = freshVfio

	s.errorDevices = faulty
	s.unknownDevices = s.totalDevices - len(s.nvmlDevices) - len(s.vfioDevices)
	return nil
}

// doUpdateCache refreshes the device list cache. It acquires s.mux itself, so
// the caller must NOT hold it; calling it with the lock held would deadlock.
// Code that already holds the lock uses unlockedUpdateCache instead.
//
// If the refresh fails, both device maps are cleared so no stale device is
// handed out, and the error is returned to the caller for reporting.
func (s *NVGPUManager) doUpdateCache() error {
	s.mux.Lock()
	defer s.mux.Unlock()

	if err := s.unlockedUpdateCache(); err != nil {
		// It is not clear what else to do when the cache update fails:
		// drop all known devices and let the caller report the error.
		ilog.Log().Error("fail to update cache, drop cache", zap.Error(err))
		s.nvmlDevices = make(map[string]*device.NvmlDevice)
		s.vfioDevices = make(map[string]*device.VFioPciDevice)
		return err
	}
	ilog.Log().Debug("cache updated")
	return nil
}

// newEvent builds a juggler event for service with description desc and status
// st. An empty service name falls back to "gpumanager".
func newEvent(service string, desc string, st string) juggler.JugglerEvent {
	ev := juggler.JugglerEvent{
		Service:     "gpumanager",
		Description: desc,
		Status:      st,
	}
	if service != "" {
		ev.Service = service
	}
	return ev
}

// newRequest wraps events into a juggler request with source "gpumanager".
func newRequest(events []juggler.JugglerEvent) juggler.JugglerRequest {
	var req juggler.JugglerRequest
	req.Source = "gpumanager"
	req.Events = events
	return req
}

// nvmlModInit initializes the NVML provider with the service configuration.
// On success it sets s.nvmlEnabled and returns nil. On failure it logs and
// returns a wrapped error, leaving s.nvmlEnabled untouched.
func (s *NVGPUManager) nvmlModInit() error {
	s.log.Info("Enabling NVML")

	if initErr := s.nvmlProvider.Init(s.config); initErr != nil {
		wrapped := fmt.Errorf("failed to enable NVML, %w", initErr)
		s.log.Error("nvmlModInit()", zap.Error(wrapped))
		return wrapped
	}

	s.log.Info("NVML is enabled")
	s.nvmlEnabled = true
	return nil
}

// vfioModInit checks whether the VFIO API is usable. It ensures the vfio-pci
// kernel module is loaded and runs the IOMMU feature probe.
//
// On success it sets s.vfioEnabled and returns nil. On failure:
//   - the cause is stored in s.vfioAPIErr, which Allocate/SetDriver report to
//     clients; if both checks fail, the IOMMU probe error is the one kept;
//   - a wrapped error is returned;
//   - the error is logged at Debug level when vfio mode is Optional, and at
//     Error level otherwise.
func (s *NVGPUManager) vfioModInit() error {
	s.vfioAPIErr = nil

	if err := modprobe.LoadModuleIfUnloaded("vfio-pci"); err != nil {
		s.vfioAPIErr = fmt.Errorf("failed to load vfio-pci module, %w", err)
	}
	if err := device.IOMMUFeatureProbe(); err != nil {
		s.vfioAPIErr = err
	}

	if s.vfioAPIErr == nil {
		s.log.Info("VFIO API is enabled")
		s.vfioEnabled = true
		return nil
	}

	// Wrap the recorded cause rather than the last probe's result. When only
	// the module load failed, the probe result is nil, and wrapping it would
	// yield "%!w(<nil>)" and lose the real reason.
	err := fmt.Errorf("failed to enable VFIO API, %w", s.vfioAPIErr)
	if s.vfioMode == config.Optional {
		s.log.Debug("vfioModInit()", zap.Error(err))
	} else {
		s.log.Error("vfioModInit()", zap.Error(err))
	}
	return err
}

// dcgmProviderInit initializes the DCGM provider and performs an initial device
// update. s.dcgmProviderInited is set only when both steps succeed. The first
// failing step's error is returned unchanged.
func (s *NVGPUManager) dcgmProviderInit() error {
	if initErr := s.dcgmProvider.Init(s.config); initErr != nil {
		return initErr
	}
	if updErr := s.dcgmProvider.UpdateDevices(); updErr != nil {
		return updErr
	}

	s.dcgmProviderInited = true
	s.log.Info("DcgmProvider is inited")
	return nil
}

// checkInitUnlocked re-checks the helper daemons and (re)initializes device
// APIs that are configured but not yet up. The caller must hold s.mux.
//
// Checks run in this order:
//  1. nvidia-persistenced service health;
//  2. nvidia-fabricmanager service health;
//  3. nvidia-dcgm service health;
//  4. DCGM provider init;
//  5. NVML init;
//  6. VFIO init.
//
// Only one error is returned. Each failing check overwrites the previous
// one, so the result reflects the last failure. A VFIO failure counts only
// when vfio mode is True (not Optional).
func (s *NVGPUManager) checkInitUnlocked() error {
	var last error
	record := func(e error) {
		if e != nil {
			last = e
		}
	}
	serviceErr := func(name string, svc *utils.Service) error {
		if e := svc.GetLastError(); e != nil {
			return fmt.Errorf("%s service failed, err: %v", name, e)
		}
		return nil
	}

	if s.config.NvidiaPersistenced {
		record(serviceErr("nvidia-persistenced", s.persistencedService))
	}
	if s.fabricmanagerEnabled {
		record(serviceErr("nvidia-fabricmanager", s.fabricmanagerService))
	}
	if s.dcgmEnabled {
		record(serviceErr("nvidia-dcgm", s.dcgmService))
	}
	if s.dcgmEnabled && s.dcgmProvider != nil && !s.dcgmProviderInited {
		record(s.dcgmProviderInit())
	}
	if s.nvmlMode && !s.nvmlEnabled {
		record(s.nvmlModInit())
	}
	if s.vfioMode != config.False && !s.vfioEnabled {
		vfioErr := s.vfioModInit()
		if s.vfioMode == config.True {
			record(vfioErr)
		}
	}
	return last
}

// checkInit runs checkInitUnlocked while holding s.mux.
func (s *NVGPUManager) checkInit() error {
	s.mux.Lock()
	defer s.mux.Unlock()
	err := s.checkInitUnlocked()
	return err
}

// cacheEventLoop runs until s.stop is closed. Every 5 seconds it re-checks
// initialization (checkInit) and refreshes the device cache (doUpdateCache),
// then reports to juggler through s.jugglerQueue:
//
//   - "gpumanager-init": Ok, or CRIT with the init error;
//   - "gpumanager-rescan": Ok, or CRIT with the rescan error;
//   - "gpumanager-device_err": CRIT when faulty devices exist, else WARN when
//     unknown (unclassified) devices exist, else Ok;
//   - "gpumanager-hung-test": result of a newly finished periodic hung test
//     (only when that test is enabled);
//   - "gpumanager": aggregate — the first CRIT event, else the first WARN,
//     else Ok.
//
// While rescans succeed, a report is sent only every 6th iteration (about
// every 30 seconds). A failed rescan is reported immediately. Reports that do
// not fit into the queue are dropped and logged.
func (s *NVGPUManager) cacheEventLoop() {
	const (
		pollInterval = 5 * time.Second
		// With a 5s poll interval, successful rescans are reported every ~30s.
		reportEvery = 6
	)
	// Start above the limit so the first successful iteration reports.
	skippedEvents := reportEvery + 1

	ll := ilog.Log()
	ll.Debug("cacherLoop start")

	lastCheckedTime := time.Now()

	for {
		var errInit, errUpdate error
		select {
		case <-s.stop:
			ll.Debug("cacherLoop stop")
			return
		case <-time.After(pollInterval):
			errInit = s.checkInit()
			errUpdate = s.doUpdateCache()
		}

		var events []juggler.JugglerEvent
		if errInit == nil {
			events = append(events, newEvent("gpumanager-init", "Success", "Ok"))
		} else {
			events = append(events, newEvent("gpumanager-init", "Fail, err: "+errInit.Error(), "CRIT"))
		}

		if errUpdate == nil {
			skippedEvents++
			if skippedEvents < reportEvery {
				continue
			}
			skippedEvents = 0
			events = append(events, newEvent("gpumanager-rescan", "Success", "Ok"))
		} else {
			skippedEvents = 0
			events = append(events, newEvent("gpumanager-rescan", "Fail, err: "+errUpdate.Error(), "CRIT"))
		}

		// The counters are written under s.mux (by doUpdateCache and by
		// SetDriver), so snapshot them under the same lock.
		s.mux.Lock()
		errorDevices, unknownDevices := s.errorDevices, s.unknownDevices
		s.mux.Unlock()

		// Emit exactly one device_err event: faulty devices take precedence
		// over unknown ones.
		switch {
		case errorDevices != 0:
			events = append(events, newEvent("gpumanager-device_err", "GPU faulty device found", "CRIT"))
		case unknownDevices != 0:
			events = append(events, newEvent("gpumanager-device_err", "GPU unknown device found", "WARN"))
		default:
			events = append(events, newEvent("gpumanager-device_err", "Success", "Ok"))
		}

		if s.config.PeriodicHungTest != config.False {
			hungErr := s.periodicHungTest.Update()
			if hungErr != nil {
				ll.Error("Error on periodic hung test update: " + hungErr.Error())
			} else if s.periodicHungTest.LaunchState == utils.PeriodicTaskFinished &&
				!s.periodicHungTest.LastLaunchTime.Equal(lastCheckedTime) {
				// Report each finished launch once.
				lastCheckedTime = s.periodicHungTest.LastLaunchTime
				if s.periodicHungTest.LastStatus != 0 {
					events = append(events, newEvent("gpumanager-hung-test", "Periodic hung test failed: "+s.periodicHungTest.LastStderr, "CRIT"))
				} else {
					events = append(events, newEvent("gpumanager-hung-test", "Success", "Ok"))
				}
			}
		}

		genEvent := newEvent("gpumanager", "Success", "Ok")
		hasWarn := false
		for _, event := range events {
			if event.Status == "CRIT" {
				genEvent.Description = event.Description
				genEvent.Status = event.Status
				break
			}
			if event.Status == "WARN" && !hasWarn {
				genEvent.Description = event.Description
				genEvent.Status = event.Status
				hasWarn = true
			}
		}
		events = append(events, genEvent)

		// s.jugglerQueue is closed together with s.stop on shutdown. Check
		// s.stop right before sending, and do the check and the send under
		// s.mux, so they cannot interleave with shutdown code that closes the
		// channels while holding s.mux.
		s.mux.Lock()
		select {
		case <-s.stop:
			s.mux.Unlock()
			ll.Debug("cacherLoop stop")
			return
		default:
		}
		select {
		case s.jugglerQueue <- newRequest(events):
		default:
			ll.Error("juggler queue overflow, drop events", zap.Any("events", events))
		}
		s.mux.Unlock()
	}
}

// fillMetricsUnlocked collects yasm metrics for the current cache. The caller
// must hold s.mux. The result contains, in order:
//   - the per-device metrics of every cached NVML and VFIO device;
//   - the InfiniBand metrics, when enabled (a failure is logged and skipped);
//   - one summary entry with the device/error/unknown counters.
func (s *NVGPUManager) fillMetricsUnlocked() []yasm.YasmMetrics {
	out := []yasm.YasmMetrics{}
	for _, nd := range s.nvmlDevices {
		out = append(out, nd.YasmMetrics())
	}
	for _, vd := range s.vfioDevices {
		out = append(out, vd.YasmMetrics())
	}

	if s.config.IbMetrics {
		if ib, err := device.GetIbYasmMetrics(); err == nil {
			out = append(out, ib...)
		} else {
			s.log.Error("'GetIbYasmMetrics()' failed", zap.Error(err))
		}
	}

	s.log.Debug("fill stats",
		zap.Int("total", s.totalDevices),
		zap.Int("nvml", len(s.nvmlDevices)),
		zap.Int("vfio", len(s.vfioDevices)))

	summary := yasm.YasmMetrics{
		Tags: map[string]string{"itype": "runtimecloud"},
		TTL:  30,
		Values: []yasm.YasmValue{
			{Name: "gpustat-device_count_tmmv", Value: s.totalDevices - s.unknownDevices},
			{Name: "gpustat-error_count_tmmv", Value: s.errorDevices},
			{Name: "gpustat-unknown_count_tmmv", Value: s.unknownDevices},
		},
	}
	return append(out, summary)
}

// doPushMetrics takes a metrics snapshot under s.mux and sends it to yasm
// outside the lock. The send result is logged at Debug level and returned.
func (s *NVGPUManager) doPushMetrics(ctx context.Context) error {
	s.mux.Lock()
	snapshot := s.fillMetricsUnlocked()
	s.mux.Unlock()

	sendErr := s.yasmClient.SendMetrics(ctx, snapshot)
	s.log.Debug("pushMetrics", zap.Any("metrics", snapshot), zap.Error(sendErr))
	return sendErr
}

// pushMetricsLoop pushes metrics once right away, logging the outcome at Info
// level on success or Error level on failure. It then pushes every 5 seconds
// until s.stop is closed or ctx is cancelled. Failures of the periodic pushes
// are not logged here beyond doPushMetrics' own Debug log.
func (s *NVGPUManager) pushMetricsLoop(ctx context.Context) {
	ll := s.log

	if err := s.doPushMetrics(ctx); err != nil {
		ll.Error("pushMetrics", zap.Error(err))
	} else {
		ll.Info("pushMetrics", zap.Error(err))
	}

	for {
		select {
		case <-s.stop:
		case <-ctx.Done():
		case <-time.After(5 * time.Second):
			_ = s.doPushMetrics(ctx) // already logged at Debug level by doPushMetrics
			continue
		}
		ll.Info("stop pushMetricsLoop")
		return
	}
}

// Ping returns a deep copy of the reply prepared at construction time, so
// callers can never mutate the shared message.
func (s *NVGPUManager) Ping(ctx context.Context, in *pb.Empty) (*pb.PingResponse, error) {
	return proto.Clone(&s.pingReply).(*pb.PingResponse), nil
}

// ListDevices returns the protobuf form of every cached NVML device followed by
// every cached VFIO device. Order within each group follows map iteration and
// is therefore unspecified.
func (s *NVGPUManager) ListDevices(ctx context.Context, in *pb.Empty) (*pb.ListResponse, error) {
	s.mux.Lock()
	defer s.mux.Unlock()

	var out []*pb.GpuDevice
	for _, nd := range s.nvmlDevices {
		out = append(out, nd.ProtoMarshal())
	}
	for _, vd := range s.vfioDevices {
		out = append(out, vd.ProtoMarshal())
	}
	return &pb.ListResponse{Devices: out}, nil
}

// Allocate serves the Allocate RPC by running doAllocate while holding s.mux.
func (s *NVGPUManager) Allocate(ctx context.Context, in *pb.AllocateRequest) (*pb.AllocateResponse, error) {
	s.mux.Lock()
	defer s.mux.Unlock()
	return s.doAllocate(ctx, in)
}

// SetDriver serves the SetDriver RPC by running doSetDriver while holding s.mux.
func (s *NVGPUManager) SetDriver(ctx context.Context, in *pb.SetDriverRequest) (*pb.SetDriverResponse, error) {
	s.mux.Lock()
	defer s.mux.Unlock()
	return s.doSetDriver(ctx, in)
}

// doAllocate validates an allocation request and dispatches it by driver name.
// The caller must hold s.mux.
//   - "host" goes to doAllocateNVGPU.
//   - "vfio" goes to doAllocateVFio, but only when vfio mode is on and the
//     VFIO API came up; otherwise the request fails with Unimplemented.
//   - An empty or unknown driver name yields InvalidArgument.
func (s *NVGPUManager) doAllocate(ctx context.Context, in *pb.AllocateRequest) (*pb.AllocateResponse, error) {
	creq := in.ContainerRequests
	switch {
	case creq == nil:
		return nil, status.Error(codes.InvalidArgument, "invalid allocation request, requst is empty")
	case len(creq.DevicesIDs) == 0:
		return nil, status.Error(codes.InvalidArgument, "invalid allocation request, device list is empty")
	}

	switch creq.DriverName {
	case "host":
		return s.doAllocateNVGPU(ctx, creq.DevicesIDs)
	case "vfio":
		if s.vfioMode == config.False {
			return nil, status.Error(codes.Unimplemented, "vfio api is not available: vfio mode is turned off")
		}
		if !s.vfioEnabled {
			return nil, status.Error(codes.Unimplemented, "vfio api is not available, err: "+s.vfioAPIErr.Error())
		}
		return s.doAllocateVFio(ctx, creq.DevicesIDs)
	case "":
		return nil, status.Error(codes.InvalidArgument, "No driver specified")
	default:
		return nil, status.Error(codes.InvalidArgument, "Unsupported driver "+creq.DriverName)
	}
}

// doAllocateNVGPU builds the container allocation for GPUs driven by the host
// (NVIDIA) driver. The caller must hold s.mux.
//
// Devices are exposed at the same path with "rw" access, in this order:
//   - the NVML control device nodes;
//   - the device nodes of every requested GPU;
//   - the InfiniBand device nodes returned by device.GetIbDevices.
//
// Envs and mounts:
//   - NVIDIA_VISIBLE_DEVICES is set to the comma-separated GPU indices, using
//     the UUID for MIG devices;
//   - when a CUDA root was resolved, it is bind-mounted read-only and
//     CUDA_LIB_PATH points at it;
//   - with AllocNvgpuUnixSocket the manager's own socket is bind-mounted too
//     (a failure there is only logged).
//
// Errors:
//   - Internal when control or InfiniBand devices cannot be listed;
//   - NotFound for unknown ids;
//   - FailedPrecondition for ids that belong to a VFIO device, or for devices
//     that are not ready.
func (s *NVGPUManager) doAllocateNVGPU(ctx context.Context, devList []string) (*pb.AllocateResponse, error) {
	// rwSpec exposes a host device node at the same path inside the container.
	rwSpec := func(p string) *pb.DeviceSpec {
		return &pb.DeviceSpec{ContainerPath: p, HostPath: p, Permissions: "rw"}
	}

	ctlDevs, err := s.nvmlProvider.GetCtlDevices(s.config)
	if err != nil {
		return nil, status.Error(codes.Internal, "can not get control devices err: "+err.Error())
	}
	specList := []*pb.DeviceSpec{}
	for _, p := range ctlDevs {
		specList = append(specList, rwSpec(p))
	}

	// NVIDIA_VISIBLE_DEVICES value: GPU index for whole GPUs, UUID for MIG.
	var visible strings.Builder
	for _, id := range devList {
		dev, found := s.nvmlDevices[id]
		if !found {
			if _, isVfio := s.vfioDevices[id]; isVfio {
				return nil, status.Error(codes.FailedPrecondition,
					fmt.Sprintf("invalid allocation request: wrong driver mode: %s", id))
			}
			return nil, status.Error(codes.NotFound,
				fmt.Sprintf("invalid allocation request: unknown device: %s", id))
		}
		if !dev.Ready.Status {
			return nil, status.Error(codes.FailedPrecondition,
				fmt.Sprintf("invalid allocation request: device: %s is not ready", id))
		}
		for _, p := range dev.Device.GetPathes() {
			specList = append(specList, rwSpec(p))
		}

		if visible.Len() > 0 {
			visible.WriteString(",")
		}
		if dev.Device.IsMigDevice() {
			visible.WriteString(dev.Device.GetUUID())
		} else {
			fmt.Fprintf(&visible, "%d", dev.Index)
		}
	}

	// Attach infiniband devices (if needed)
	ibDevs, err := device.GetIbDevices(s.config)
	if err != nil {
		return nil, status.Error(codes.Internal, "failed to get infiniband devices, err: "+err.Error())
	}
	for _, p := range ibDevs {
		specList = append(specList, rwSpec(p))
	}

	envs := make(map[string]string)
	bmounts := []*pb.BindMount{}
	if visible.Len() > 0 {
		envs["NVIDIA_VISIBLE_DEVICES"] = visible.String()
	}
	if s.cudaRoot != "" {
		cudaMount := &pb.BindMount{
			HostPath:      s.cudaRoot,
			ContainerPath: "/" + filepath.Base(s.cudaRoot),
			ReadOnly:      true,
		}
		envs["CUDA_LIB_PATH"] = cudaMount.ContainerPath
		bmounts = append(bmounts, cudaMount)
	}
	if s.config.AllocNvgpuUnixSocket {
		// Best effort: a missing socket mount does not fail the allocation.
		if sockBind, sockErr := s.allocNvgpuUnixSocket(); sockErr != nil {
			s.log.Error("failed to alloc nvgpumanager unix socket", zap.Error(sockErr))
		} else {
			bmounts = append(bmounts, sockBind)
		}
	}

	return &pb.AllocateResponse{
		ContainerResponse: &pb.ContainerAllocateResponse{
			Envs:    envs,
			Mounts:  bmounts,
			Devices: specList,
		},
	}, nil
}

// allocNvgpuUnixSocket returns a read-write bind mount that exposes the
// manager's own gRPC unix socket inside the container at the same path.
// It fails when:
//   - there is no listener;
//   - the listener is not a unix socket;
//   - the socket path cannot be stat'ed.
func (s *NVGPUManager) allocNvgpuUnixSocket() (*pb.BindMount, error) {
	if s.grpcListener == nil {
		return nil, fmt.Errorf("s.grpcListener == nil")
	}

	lnAddr := (*s.grpcListener).Addr()
	if lnAddr.Network() != "unix" {
		return nil, fmt.Errorf("nvgpumanager socket isn't a unix one")
	}

	sockPath := lnAddr.String()
	if _, statErr := os.Stat(sockPath); statErr != nil {
		return nil, fmt.Errorf("failed to stat nvgpumanager unix socket \"%s\", stat err: %w", sockPath, statErr)
	}

	mount := &pb.BindMount{HostPath: sockPath, ContainerPath: sockPath}
	mount.ReadOnly = false
	return mount, nil
}

// doAllocateVFio builds the container allocation for GPUs bound to vfio-pci.
// The caller must hold s.mux.
//
// The response lists, with "rw" access:
//   - the IOMMU control device nodes;
//   - the VFIO group node of each requested GPU;
//   - when exactly 8 GPUs are requested (a whole-host VM), also the VFIO
//     NVSwitch devices, plus the Mellanox InfiniBand VFIO devices if the model
//     of the last requested GPU is "gpu_tesla_a100_80g".
//
// Errors:
//   - NotFound for unknown ids;
//   - FailedPrecondition for ids driven by NVML, or for devices that are not
//     ready;
//   - Internal when the extra devices cannot be listed.
func (s *NVGPUManager) doAllocateVFio(ctx context.Context, devList []string) (*pb.AllocateResponse, error) {
	const fullHostGpuCount = 8

	// rwSpec exposes a host device node at the same path inside the container.
	rwSpec := func(p string) *pb.DeviceSpec {
		return &pb.DeviceSpec{ContainerPath: p, HostPath: p, Permissions: "rw"}
	}

	specList := []*pb.DeviceSpec{}
	for _, p := range device.IOMMUGetCtlDevices() {
		specList = append(specList, rwSpec(p))
	}

	gpuCount, gpuModel := 0, ""
	for _, id := range devList {
		dev, found := s.vfioDevices[id]
		if !found {
			if _, isNvml := s.nvmlDevices[id]; isNvml {
				return nil, status.Error(codes.FailedPrecondition,
					fmt.Sprintf("invalid allocation request: wrong driver mode: %s", id))
			}
			return nil, status.Error(codes.NotFound,
				fmt.Sprintf("invalid allocation request: unknown device: %s", id))
		}
		if !dev.Ready.Status {
			return nil, status.Error(codes.FailedPrecondition,
				fmt.Sprintf("invalid allocation request: device: %s is not ready", id))
		}
		specList = append(specList, &pb.DeviceSpec{
			ContainerPath: dev.Group.VFioPath(),
			HostPath:      dev.Group.VFioPath(),
			Permissions:   "rw",
		})

		gpuCount++
		gpuModel = dev.PciDev.ModelName
	}

	if gpuCount == fullHostGpuCount {
		extra, err := device.GetVfioNvswitchDevices()
		if err != nil {
			return nil, status.Error(codes.Internal, "failed to get vfio nvswitch devices, err: "+err.Error())
		}
		if gpuModel == "gpu_tesla_a100_80g" { // pciids.nvidiaDevices["20b2"].Name
			ibDevs, ibErr := device.GetVfioMlxIbDevices()
			if ibErr != nil {
				return nil, status.Error(codes.Internal, "failed to get vfio mlx infiniband devices, err: "+ibErr.Error())
			}
			extra = append(extra, ibDevs...)
		}
		for _, p := range extra {
			specList = append(specList, rwSpec(p))
		}
	}

	return &pb.AllocateResponse{
		ContainerResponse: &pb.ContainerAllocateResponse{
			Devices: specList,
		},
	}, nil
}

// doSetDriver rebinds the requested PCI devices to the driver named by
// in.DriverName: "host" selects the NVML provider's driver and "vfio" selects
// vfio-pci. The caller must hold s.mux.
//
// All ids are validated before anything is switched:
//   - MIG devices are skipped for "host" (RESMAN-104) and rejected for any
//     other driver;
//   - devices already on the target driver are skipped.
//
// After a switch, initialization is re-checked and the device cache is
// rebuilt. The switch error takes precedence over the cache-update error.
func (s *NVGPUManager) doSetDriver(ctx context.Context, in *pb.SetDriverRequest) (*pb.SetDriverResponse, error) {
	if len(in.DeviceId) == 0 {
		return nil, status.Error(codes.InvalidArgument, "invalid set driver request, device list is empty")
	}

	var targetDriver string
	switch in.DriverName {
	case "host":
		targetDriver = s.nvmlProvider.GetDriverName()
	case "vfio":
		if !s.vfioEnabled && s.vfioMode != config.False {
			return nil, status.Error(codes.Unimplemented, "vfio api is not available, err: "+s.vfioAPIErr.Error())
		}
		if !s.vfioEnabled {
			return nil, status.Error(codes.Unimplemented, "vfio api is not available: vfio mode is turned off")
		}
		targetDriver = "vfio-pci"
	case "":
		return nil, status.Error(codes.InvalidArgument, "No driver specified")
	default:
		return nil, status.Error(codes.InvalidArgument, "Unsupported driver "+in.DriverName)
	}

	pciDevs, err := s.pciProvider.NewPciDevices()
	if err != nil {
		return nil, status.Error(codes.Internal, "Fail to get device list, err: "+err.Error())
	}
	byUUID := make(map[string]*device.PciDevice, len(pciDevs))
	for _, pd := range pciDevs {
		byUUID[pd.UUID] = pd
	}

	var toSwitch []*device.PciDevice
	for idx, id := range in.DeviceId {
		if id == "" {
			return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("Empty device id at: %d", idx))
		}

		// RESMAN-104: quick hack to prevent SetDriver from failing on MIG gpus
		if nd, isNvml := s.nvmlDevices[id]; isNvml && nd.Device.IsMigDevice() {
			if in.DriverName != "host" {
				return nil, status.Error(codes.Unimplemented, fmt.Sprintf("Unsupported driver for MIG gpu: %s", in.DriverName))
			}
			continue
		}

		pd, known := byUUID[id]
		if !known {
			return nil, status.Error(codes.NotFound, "Device not found, id: "+id)
		}
		if pd.Driver != targetDriver {
			toSwitch = append(toSwitch, pd)
		}
	}

	if len(toSwitch) == 0 {
		return &pb.SetDriverResponse{}, nil
	}

	switchErr := s.switchDriver(ctx, targetDriver, toSwitch)
	// Devices have likely changed drivers: refresh init state and the cache
	// whatever the switch result was.
	_ = s.checkInitUnlocked()
	cacheErr := s.unlockedUpdateCache()
	if switchErr == nil {
		switchErr = cacheErr
	}
	return &pb.SetDriverResponse{}, switchErr
}
// switchDriver rebinds every device in devList to targetDriver. Callers own
// the manager state exclusively: doSetDriver holds s.mux, and Serve calls it
// before gRPC serving and the background loops start.
//
// Steps:
//  1. Collect the helper daemons that are in use; they are passed to the
//     NVML provider's DisableDriver/EnableDriver.
//  2. Shut NVML down if it is enabled. It is re-enabled later by
//     checkInitUnlocked.
//  3. When targeting vfio-pci, disable the nvidia driver (RESMAN-8 workaround).
//  4. Call SetDriver on each device, stopping at the first failure.
//  5. Always re-enable the nvidia driver. Its error counts only in NVML mode
//     and only when nothing failed before.
//  6. Probe the IOMMU groups of the devices switched successfully.
//
// The first error encountered is returned.
func (s *NVGPUManager) switchDriver(ctx context.Context, targetDriver string, devList []*device.PciDevice) error {
	var services []*utils.Service
	if s.config.NvidiaPersistenced {
		services = append(services, s.persistencedService)
	}
	if s.fabricmanagerEnabled {
		services = append(services, s.fabricmanagerService)
	}
	if s.dcgmEnabled {
		services = append(services, s.dcgmService)
	}

	// FIXME:  Think about optional nvml disable here
	// Disable NVML just as a precaution, which helps us to detect existing users
	if s.nvmlEnabled {
		s.log.Info("Disabling NVML")
		if err := s.nvmlProvider.Shutdown(); err != nil {
			return status.Error(codes.Internal, "Fail do shutdown NVML, err: "+err.Error())
		}
		// NVML is re-enabled by the next checkInitUnlocked pass.
		s.nvmlEnabled = false
	}

	// Workaround for https://st.yandex-team.ru/RESMAN-8
	if targetDriver == "vfio-pci" {
		s.log.Info("Disabling nvidia driver")
		err := s.nvmlProvider.DisableDriver(services)
		s.log.Debug("nvmlProvider.DisableDriver", zap.String("driver", "nvidia"), zap.Error(err))
		if err != nil {
			return status.Error(codes.Internal, "Fail to disable nvidia driver err: "+err.Error())
		}
	}

	var err error
	grpList := []*device.IOMMUGroup{}
	for _, d := range devList {
		grp, e := d.SetDriver(ctx, targetDriver)
		// Log this device's own outcome (e), not the accumulated error.
		s.log.Info("setDriver", zap.String("driver", targetDriver), zap.String("dev_id", d.UUID),
			zap.Error(e))

		if e != nil {
			err = status.Error(codes.Internal, fmt.Sprintf("SetDriver fail for id:%s err: %s", d.UUID, e.Error()))
			break
		}
		grpList = append(grpList, grp)
	}

	enableErr := s.nvmlProvider.EnableDriver(services)
	s.log.Debug("nvmlProvider.EnableDriver", zap.Error(enableErr))
	if err == nil && s.nvmlMode {
		err = enableErr
	}

	// TODO: second probe is unnecessary in some cases
	// After all drivers are enabled again we should probe devices to attach to new driver
	for _, g := range grpList {
		probeErr := g.Probe()
		if err == nil {
			err = probeErr
		}
	}
	return err
}

// staleSocketCleanup removes the unix socket at ep unless another live
// instance answers a Ping on it. It also refuses to remove the socket when the
// client cannot be created for a reason other than ECONNREFUSED; that error is
// returned.
func staleSocketCleanup(log *zap.Logger, ep string) error {
	cl, err := client.NewClient(log, "unix:"+ep)
	switch {
	case err == nil:
		defer cl.Close()
		if cl.Ping(context.Background()) {
			return xerrors.Errorf("unix socket %s is busy by other instance", ep)
		}
	case !xerrors.Is(err, syscall.ECONNREFUSED):
		return xerrors.Errorf("Unexpected error during socketCleanup probe %w", err)
	}

	log.Info("Remove stale unix socket", zap.String("endpoint", ep))
	return os.Remove(ep)
}

// getNvidiaFabricmanagerArgs returns the nvidia-fabricmanager command line.
// If the NVIDIA library directory can be determined, it is
// "-c <libdir>/nvidia-fabricmanager.cfg". Otherwise the failure is logged and
// config.NvidiaFabricmanagerDefaultArgs is returned.
func getNvidiaFabricmanagerArgs(log *zap.Logger) []string {
	libPath, err := utils.GetNvidiaLibraryPath()
	if err != nil {
		log.Error("failed to get lib path", zap.Error(err))
		return config.NvidiaFabricmanagerDefaultArgs
	}
	return []string{"-c", path.Join(libPath, "nvidia-fabricmanager.cfg")}
}
