package nsenter

import (
	"runtime"

	"github.com/vishvananda/netns"
	"go.uber.org/zap"

	errors "a.yandex-team.ru/library/go/core/xerrors"
)

// WithSetNetNS runs job while the calling goroutine's OS thread is switched
// into the network namespace of the process with the given pid, then switches
// the thread back to the network namespace it was in before the call.
//
// The goroutine is locked to its OS thread for the duration of the call,
// because the namespace switch applies to that thread only. Goroutines that
// job starts are not pinned to this thread, so they are not guaranteed to run
// inside the target namespace.
//
// For pid != 1 it is an error if the target namespace equals the current one.
//
// The returned error is job's error, unless restoring the original namespace
// fails. In that case a restore error is returned (mentioning job's error, if
// any). The OS thread is then deliberately left locked, so the Go runtime
// never schedules other goroutines onto a thread stuck in the foreign
// namespace. The runtime terminates a thread whose goroutine exits while
// still locked.
func WithSetNetNS(pid int, job func() error) (err error) {
	runtime.LockOSThread()
	// threadClean reports whether the thread is in its original namespace and
	// may therefore be handed back to the scheduler.
	threadClean := true
	defer func() {
		if threadClean {
			runtime.UnlockOSThread()
		}
	}()

	oldNS, err := netns.Get()
	if err != nil {
		return errors.Errorf("unable to get current net namespace: %w", err)
	}
	defer oldNS.Close()

	zap.S().Debugf("original net namespace:\t%s", oldNS)

	pidNS, err := netns.GetFromPid(pid)
	if err != nil {
		return errors.Errorf("unable to get net namespace of pid %d: %w", pid, err)
	}
	defer pidNS.Close()

	// Comparing the new NS with the old NS is enough when pid != 1, because
	// the old NS should always be the root NS; if it is not, something went
	// wrong much earlier. pid 1 is exempt so that the root NS is not lost in
	// samples.
	if pidNS.Equal(oldNS) && pid != 1 {
		return errors.Errorf("oldNS == newNS: %s, %s", oldNS, pidNS)
	}

	if setErr := netns.Set(pidNS); setErr != nil {
		return errors.Errorf("unable to set namespace %s\tfrom PID %d: %w", pidNS, pid, setErr)
	}
	// From here on the thread is inside the target namespace until the
	// deferred restore below succeeds.
	threadClean = false
	defer func() {
		// err is the named result, so assigning it here changes what the
		// caller receives.
		if resetErr := netns.Set(oldNS); resetErr != nil {
			if err != nil {
				err = errors.Errorf("unable to reset original net namespace to %s: %w (job error: %v)", oldNS, resetErr, err)
			} else {
				err = errors.Errorf("unable to reset original net namespace to %s: %w", oldNS, resetErr)
			}
			// threadClean stays false: keep the thread locked so it is never
			// reused by other goroutines.
			return
		}
		threadClean = true
		zap.S().Debugf("namespace was reset to:\t%s", oldNS)
	}()
	zap.S().Debugf("namespace was set to:\t%s from PID %d", pidNS, pid)

	return job()
}
