//go:build linux
// +build linux

package platform

import (
	"fmt"
	"log"
	"os"
	"runtime/debug"
	"strings"
	"sync"

	"golang.org/x/sys/unix"

	"a.yandex-team.ru/security/osquery/extensions/osquery-fim/internal/container"
)

// /proc/mounts monitoring functions

// getMountPoint returns the mount point recorded for path in the watcher's
// mount trie, as looked up by PathTrie.GetParent (presumably the closest
// enclosing mount point of path — confirm against container.PathTrie).
//
// When the lookup finds nothing, a synthetic "/" mount point with an empty
// filesystem is returned instead of an error.
func getMountPoint(path string) mountPoint {
	watcher.mu.Lock()
	defer watcher.mu.Unlock()

	if found, ok := watcher.mounts.GetParent(path); ok {
		return found.(mountPoint)
	}
	// NOTE(review): reaching this means no root mount point is known (e.g. the
	// trie is empty or /proc/mounts had no "/" entry); fall back to "/".
	return mountPoint{path: "/", filesystem: ""}
}

// getAllMountPoints returns every mountPoint currently stored in the
// watcher's mount trie, in the order PathTrie.Walk visits them. The result
// is nil when the trie is empty. The watcher lock is held for the whole walk.
func getAllMountPoints() []mountPoint {
	watcher.mu.Lock()
	defer watcher.mu.Unlock()

	var all []mountPoint
	collect := func(_ string, v interface{}) {
		all = append(all, v.(mountPoint))
	}
	watcher.mounts.Walk(collect)
	return all
}

// onMountChangeFn is a callback run after the mount table has been
// successfully re-read from /proc/mounts.
type onMountChangeFn func()

// onMountChange registers onChange with the package-level watcher. Callbacks
// are kept in registration order and are never removed.
func onMountChange(onChange onMountChangeFn) {
	watcher.mu.Lock()
	watcher.onChange = append(watcher.onChange, onChange)
	watcher.mu.Unlock()
}

// mountPoint is one entry of the mount table parsed from /proc/mounts.
//
// path is the mount point directory (second field of a /proc/mounts line) and
// filesystem is the filesystem type (third field). getMountPoint returns a
// synthetic {path: "/", filesystem: ""} entry when no mount point is found.
type mountPoint struct {
	path       string
	filesystem string
}

// mountWatcher holds the current mount table and the callbacks to run after
// it is re-read.
//
// mu is taken when mounts is read (getMountPoint, getAllMountPoints) or
// replaced (fillMountPoints), and when onChange is appended to
// (onMountChange). onChange holds callbacks in registration order; they are
// run after each successful re-read of /proc/mounts.
type mountWatcher struct {
	mu sync.Mutex
	// Mount point path -> mountPoint.
	mounts   container.PathTrie
	onChange []onMountChangeFn
}

// watch opens /proc/mounts for polling and then loops forever, running one
// watchIter per reported change. It only returns if /proc/mounts cannot be
// opened, in which case the error is logged. When verbose is set, the start
// of monitoring is logged.
func (w *mountWatcher) watch(verbose bool) {
	mountsFd, err := unix.Open("/proc/mounts", unix.O_RDONLY, 0)
	if err != nil {
		log.Printf("ERROR: could not open /proc/mounts: %v\n", err)
		return
	}
	if verbose {
		log.Printf("Starting monitoring /proc/mounts\n")
	}

	fds := []unix.PollFd{
		{Fd: int32(mountsFd), Events: unix.POLLPRI},
	}
	// Each iteration is a separate call so that its deferred recover() can
	// catch a panic and let the loop carry on.
	for {
		w.watchIter(mountsFd, fds, verbose)
	}
}

// watchIter runs one iteration of the /proc/mounts watch loop. It performs
// these steps:
//   - block in poll(2) on pollFds until the kernel flags a change (POLLPRI);
//   - re-read the mount table with fillMountPoints;
//   - if that succeeds, run every registered onChange callback.
//
// A panic anywhere in the iteration, including inside a callback, is logged
// with its stack trace and swallowed so that watch can continue looping.
//
// fd is not used here; it is kept for signature compatibility with watch.
func (w *mountWatcher) watchIter(fd int, pollFds []unix.PollFd, verbose bool) {
	defer func() {
		if r := recover(); r != nil {
			log.Printf("ERROR: panic while watching mounts, recovering: %v\n%s", r, string(debug.Stack()))
		}
	}()

	// From https://man7.org/linux/man-pages/man5/proc.5.html
	//
	// Since kernel version 2.6.15, this file is pollable: after opening the file for reading, a change in
	// this file (i.e., a filesystem mount or unmount) causes select(2) to mark the file descriptor as
	// having an exceptional condition, and poll(2) and epoll_wait(2) mark the file as having a priority
	// event (POLLPRI).  (Before Linux 2.6.30, a change in this file was indicated by the file
	// descriptor being marked as readable for select(2), and being marked as having an error condition
	// for poll(2) and epoll_wait(2).)
	_, err := unix.Poll(pollFds, -1)
	if err == unix.EINTR {
		// A signal interrupted the wait. This is not a mount table change, so
		// skip the re-read and let the caller poll again. A change that is
		// still pending should be reported by that next poll.
		return
	}
	if err != nil {
		// Unknown poll failure: log it, but still re-read below in case
		// something changed.
		log.Printf("ERROR: could not select() /proc/mounts: %v\n", err)
	}
	if verbose {
		log.Printf("/proc/mounts got updated\n")
	}

	if err := w.fillMountPoints(); err != nil {
		log.Printf("ERROR: could not read /proc/mounts: %v\n", err)
		return
	}

	// onMountChange appends to w.onChange under w.mu, so snapshot the slice
	// under the same lock. The callbacks then run without the lock held, so
	// they can call back into code that takes it (e.g. getMountPoint).
	w.mu.Lock()
	callbacks := append([]onMountChangeFn(nil), w.onChange...)
	w.mu.Unlock()

	for _, onChange := range callbacks {
		onChange()
	}
}

// fillMountPoints re-reads /proc/mounts, parses it with parseMountPoints and
// replaces w.mounts with the result under w.mu.
//
// The file is read to EOF rather than with a single Read call. /proc/mounts
// is a seq_file, and one read(2) on it returns at most about one page (the
// seq_file buffer). A single fixed-size Read therefore silently drops every
// mount past the first few kilobytes on hosts with many mounts.
//
// Reading in several chunks can see an inconsistent table if mounts change
// mid-read (see
// https://stackoverflow.com/questions/5713451/is-it-safe-to-parse-a-proc-file).
// Such a change is also a mount-table change, which the polled descriptor
// should report again, so the next watch iteration re-reads the table.
//
// Errors:
//   - Reading or parsing failures are returned and w.mounts is left unchanged.
//   - If the new table has no "/" mount point, w.mounts IS still replaced
//     (partial information beats none, e.g. inside a chroot), but an error
//     is returned so callers can skip change notifications.
func (w *mountWatcher) fillMountPoints() error {
	raw, err := os.ReadFile("/proc/mounts")
	if err != nil {
		return err
	}
	content := string(raw)
	mounts, err := parseMountPoints(content)
	if err != nil {
		return err
	}

	w.mu.Lock()
	defer w.mu.Unlock()
	w.mounts = mounts

	// Sanity check: we've got a root mountpoint.
	if _, ok := w.mounts.GetParent("/"); !ok {
		// No "ERROR:" prefix here: the caller already adds one when logging.
		return fmt.Errorf("could not find / mount point in /proc/mounts:\n%s", content)
	}

	return nil
}

// parseMountPoints parses the contents of /proc/mounts into a PathTrie that
// maps each mount point path to its mountPoint.
//
// Empty lines are skipped. Every other line must have at least four
// whitespace-separated fields; only the second (mount point) and third
// (filesystem type) are used. The kernel escapes space, tab, newline and
// backslash in these fields as three-digit octal sequences (e.g. "\040" for
// a space). They are decoded so the trie is keyed by real filesystem paths.
// How duplicate mount points are resolved is up to PathTrie.Insert.
//
// On the first malformed line an empty trie and an error are returned.
func parseMountPoints(content string) (container.PathTrie, error) {
	ret := container.PathTrie{}
	for _, line := range strings.Split(content, "\n") {
		if line == "" {
			continue
		}
		fields := strings.Fields(line)
		if len(fields) < 4 {
			return container.PathTrie{}, fmt.Errorf("strange line in /proc/mounts: %s", line)
		}

		// Split first, then unescape: escaped whitespace must not split fields.
		path := unescapeMountsField(fields[1])
		filesystem := unescapeMountsField(fields[2])
		ret.Insert(path, mountPoint{path: path, filesystem: filesystem})
	}
	return ret, nil
}

// unescapeMountsField decodes the "\ooo" octal escapes (first digit 0-3)
// that the kernel uses in /proc/mounts fields. Any other backslash is kept
// literally.
func unescapeMountsField(s string) string {
	if !strings.Contains(s, `\`) {
		return s
	}
	var b strings.Builder
	b.Grow(len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c == '\\' && i+3 < len(s) {
			// Byte subtraction wraps for chars below '0', so the range checks
			// below reject non-digits too.
			d1, d2, d3 := s[i+1]-'0', s[i+2]-'0', s[i+3]-'0'
			if d1 <= 3 && d2 <= 7 && d3 <= 7 {
				b.WriteByte(d1<<6 | d2<<3 | d3)
				i += 3
				continue
			}
		}
		b.WriteByte(c)
	}
	return b.String()
}
