package kernel

import (
	"fmt"
	"regexp"
	"strconv"
	"sync"

	"a.yandex-team.ru/security/yadi/yadi-os/pkg/kernel"
)

// Version is a kernel version split into its three numeric components,
// e.g. release "5.4.0-42-generic" corresponds to {Major: 5, Minor: 4, Patch: 0}.
type Version struct {
	Major, Minor, Patch uint32
}

var (
	// versionRegex extracts the "major.minor.patch" prefix of a kernel release
	// string such as "5.4.0-42-generic". Both separators must be literal dots;
	// an unescaped "." would accept any character and let malformed releases
	// (e.g. "4.19x5") parse as bogus versions.
	versionRegex = regexp.MustCompile(`^(\d+)\.(\d+)\.(\d+).*$`)

	// versionOnce guards the one-time lookup performed by RawCurrentVersion;
	// version and versionErr hold its cached result.
	versionOnce sync.Once
	version     uint32
	versionErr  error
)

// CurrentVersion returns the running kernel's version, decoded from the
// LINUX_VERSION_CODE-style value reported by RawCurrentVersion. Any error from
// RawCurrentVersion is returned unchanged.
func CurrentVersion() (*Version, error) {
	code, err := RawCurrentVersion()
	if err != nil {
		return nil, err
	}

	// Each component occupies one byte: major<<16 | minor<<8 | patch.
	v := new(Version)
	v.Major = (code >> 16) & 0xff
	v.Minor = (code >> 8) & 0xff
	v.Patch = code & 0xff
	return v, nil
}

// HasRawTracePoints reports whether this kernel version supports BPF raw
// tracepoints, which were introduced in 4.17.0.
func (v *Version) HasRawTracePoints() bool {
	return v.Major > 4 || (v.Major == 4 && v.Minor >= 17)
}

// CheckPrerequisites returns nil when this kernel version is new enough for
// gideon (>= 4.19.0) and a descriptive error otherwise.
func (v *Version) CheckPrerequisites() error {
	const minMajor, minMinor = 4, 19

	supported := v.Major > minMajor || (v.Major == minMajor && v.Minor >= minMinor)
	if supported {
		return nil
	}
	return fmt.Errorf("unsupported kernel version: %s", v)
}

// String renders the version in dotted "major.minor.patch" form, e.g. "5.4.0".
func (v *Version) String() string {
	// fmt.Sprint inserts no spaces here because every other operand is a string.
	return fmt.Sprint(v.Major, ".", v.Minor, ".", v.Patch)
}

// parseKernelVersion converts a kernel release string such as "4.4.2",
// "4.4.2-1" or "5.4.0-42-generic" into a version number in LINUX_VERSION_CODE
// format: for release "a.b.c" the result is (a<<16 + b<<8 + c).
//
// Major and minor each occupy a single byte of the encoding, so values above
// 255 are rejected with an error. As with the kernel's own KERNEL_VERSION
// macro, the patch level is clamped to 255: long-lived stable branches (4.9.x,
// 4.14.x, 4.19.x) have passed 255, and without the clamp the overflow would
// spill into the minor byte (4.9.337 would decode as 4.10.81).
func parseKernelVersion(releaseString string) (uint32, error) {
	versionParts := versionRegex.FindStringSubmatch(releaseString)
	if len(versionParts) != 4 {
		return 0, fmt.Errorf("got invalid release version %q (expected format '4.3.2-1')", releaseString)
	}

	// bitSize 8 makes ParseUint reject components that cannot fit in one byte.
	major, err := strconv.ParseUint(versionParts[1], 10, 8)
	if err != nil {
		return 0, fmt.Errorf("invalid major version in release %q: %w", releaseString, err)
	}

	minor, err := strconv.ParseUint(versionParts[2], 10, 8)
	if err != nil {
		return 0, fmt.Errorf("invalid minor version in release %q: %w", releaseString, err)
	}

	patch, err := strconv.ParseUint(versionParts[3], 10, 64)
	if err != nil {
		return 0, fmt.Errorf("invalid patch version in release %q: %w", releaseString, err)
	}
	// Saturate instead of overflowing into the minor byte.
	if patch > 255 {
		patch = 255
	}

	return uint32(major<<16 | minor<<8 | patch), nil
}

// RawCurrentVersion returns the running kernel's version encoded in
// LINUX_VERSION_CODE form (major<<16 + minor<<8 + patch).
//
// The release string is obtained from kernel.Current and parsed only on the
// first call; the resulting value and error are cached by versionOnce and
// returned unchanged on every later call, including when the first attempt failed.
func RawCurrentVersion() (uint32, error) {
	versionOnce.Do(func() {
		info, infoErr := kernel.Current()
		if infoErr != nil {
			versionErr = fmt.Errorf("failed to get kernel version: %w", infoErr)
			return
		}

		parsed, parseErr := parseKernelVersion(info.Release.String())
		if parseErr != nil {
			// version keeps its zero value, matching what the parser returns on error.
			versionErr = fmt.Errorf("failed to parse kernel version: %w", parseErr)
			return
		}
		version = parsed
	})

	return version, versionErr
}
