package tracer

import (
	"fmt"

	"github.com/cilium/ebpf"
	"github.com/cilium/ebpf/link"

	"a.yandex-team.ru/security/gideon/gideon/bpf"
	"a.yandex-team.ru/security/gideon/gideon/internal/ebpfutil"
)

// rawProgram describes a BPF program that attachRawBPF attaches to a raw
// tracepoint.
type rawProgram struct {
	// name is the key used to look the program up in the BPF collection.
	name string
	// tracepoint is the raw tracepoint name passed to link.AttachRawTracepoint.
	tracepoint string
}

// kprobeProgram describes a BPF program that attachKprobeBPF binds to a
// kprobe.
type kprobeProgram struct {
	// name is the key used to look the program up in the BPF collection.
	name string
	// funcName is the function name passed to ebpfutil.PerfEventOpenKprobe.
	funcName string
}

// usedPrograms returns the names of every BPF program the tracer refers to:
// the entries of portoTracepoints, then one name per raw syscall (formatted
// with rawSyscallTracepointFmt), one per kprobe syscall (formatted with
// seccompSyscallTracepointFmt), and finally the names of the configured raw
// and kprobe programs.
//
// The result is always a freshly allocated slice. Appending directly onto
// portoTracepoints could write into that shared slice's backing array whenever
// it has spare capacity, so its contents are copied first.
func (t *Tracer) usedPrograms() []string {
	total := len(portoTracepoints) +
		len(t.bpfProbes.rawSyscalls) +
		len(t.bpfProbes.kprobeSyscalls) +
		len(t.bpfProbes.rawProgs) +
		len(t.bpfProbes.kprobeProgs)

	out := make([]string, 0, total)
	out = append(out, portoTracepoints...)

	for _, sys := range t.bpfProbes.rawSyscalls {
		out = append(out, fmt.Sprintf(rawSyscallTracepointFmt, sys.Name()))
	}

	for _, sys := range t.bpfProbes.kprobeSyscalls {
		out = append(out, fmt.Sprintf(seccompSyscallTracepointFmt, sys.Name()))
	}

	for _, prog := range t.bpfProbes.rawProgs {
		out = append(out, prog.name)
	}

	for _, prog := range t.bpfProbes.kprobeProgs {
		out = append(out, prog.name)
	}

	return out
}

// attachRawSyscall registers, for every given syscall, its raw-tracepoint tail
// program in the map named rawSyscallTailMapName. The map key is the numeric
// value of the syscall kind and the value is the program's file descriptor.
//
// The program for a syscall is the one whose name equals
// rawSyscallTracepointFmt formatted with the syscall's Name(). An error is
// returned when the tail map is missing, when no program of type
// ebpf.RawTracepoint has that name, or when the map update fails. Entries
// written before a failure are left in place.
func (t *Tracer) attachRawSyscall(syscalls ...bpf.SyscallKind) error {
	syscallsMap, ok := t.bpfCollection.Maps[rawSyscallTailMapName]
	if !ok {
		return fmt.Errorf("failed to find raw syscall tail call map: %s", rawSyscallTailMapName)
	}

	for _, sys := range syscalls {
		sysName := sys.Name()
		progName := fmt.Sprintf(rawSyscallTracepointFmt, sysName)

		// Programs are keyed by name, so look the target up directly instead
		// of scanning the whole collection.
		prog, ok := t.bpfCollection.Programs[progName]
		if !ok || prog.Type() != ebpf.RawTracepoint {
			return fmt.Errorf("can't find tail raw program for syscall: %s", sysName)
		}

		syscallID := uint32(sys)
		if err := syscallsMap.Put(syscallID, uint32(prog.FD())); err != nil {
			return fmt.Errorf("failed to link raw syscall %s (%d) trace program %q: %w",
				sysName, syscallID, progName, err,
			)
		}
	}

	return nil
}

// attachSeccompSyscall registers, for every given syscall, its seccomp tail
// program in the map named seccompSyscallTailMapName. The map key is the
// numeric value of the syscall kind and the value is the program's file
// descriptor.
//
// The program for a syscall is the one whose name equals
// seccompSyscallTracepointFmt formatted with the syscall's Name(). An error is
// returned when the tail map is missing, when no program of type ebpf.Kprobe
// has that name, or when the map update fails. Entries written before a
// failure are left in place.
func (t *Tracer) attachSeccompSyscall(syscalls ...bpf.SyscallKind) error {
	syscallsMap, ok := t.bpfCollection.Maps[seccompSyscallTailMapName]
	if !ok {
		return fmt.Errorf("failed to find seccomp syscall tail call map: %s", seccompSyscallTailMapName)
	}

	for _, sys := range syscalls {
		sysName := sys.Name()
		progName := fmt.Sprintf(seccompSyscallTracepointFmt, sysName)

		// Programs are keyed by name, so look the target up directly instead
		// of scanning the whole collection.
		prog, ok := t.bpfCollection.Programs[progName]
		if !ok || prog.Type() != ebpf.Kprobe {
			return fmt.Errorf("can't find tail seccomp program for syscall: %s", sysName)
		}

		syscallID := uint32(sys)
		if err := syscallsMap.Put(syscallID, uint32(prog.FD())); err != nil {
			return fmt.Errorf("failed to link seccomp syscall %s (%d) trace program %q: %w",
				sysName, syscallID, progName, err,
			)
		}
	}

	return nil
}

// attachRawBPF attaches each required program to its raw tracepoint and
// records the resulting link in t.links.
//
// For every entry the program is looked up by name in the BPF collection and
// must be of type ebpf.RawTracepoint. The first lookup, type or attach failure
// aborts the loop and is returned; links created before the failure stay in
// t.links. Calling it with no programs is a no-op.
func (t *Tracer) attachRawBPF(requiredPrograms ...rawProgram) error {
	for _, req := range requiredPrograms {
		prog, ok := t.bpfCollection.Programs[req.name]
		if !ok {
			return fmt.Errorf("failed to find BPF program: %s", req.name)
		}

		if prog.Type() != ebpf.RawTracepoint {
			// %s prints the program type names, matching attachKprobeBPF.
			return fmt.Errorf("invalid BPF program %q type: %s != %s", req.name, prog.Type(), ebpf.RawTracepoint)
		}

		l, err := link.AttachRawTracepoint(link.RawTracepointOptions{
			Name:    req.tracepoint,
			Program: prog,
		})
		if err != nil {
			// Name the program and tracepoint explicitly rather than
			// formatting the whole request struct.
			return fmt.Errorf("fail to attach program %q to syscalls tracepoint(%q): %w",
				req.name, req.tracepoint, err,
			)
		}
		t.links = append(t.links, l)
	}
	return nil
}

// attachKprobeBPF binds each given program to its kernel function and records
// the descriptor returned for it in t.fds.
//
// Every entry is resolved by name in the BPF collection and must be of type
// ebpf.Kprobe. The program's file descriptor and the entry's funcName are then
// passed to ebpfutil.PerfEventOpenKprobe. The first failure stops processing
// and is returned; descriptors collected before it remain in t.fds.
func (t *Tracer) attachKprobeBPF(programs ...kprobeProgram) error {
	for _, req := range programs {
		bpfProg, found := t.bpfCollection.Programs[req.name]
		if !found {
			return fmt.Errorf("failed to find BPF program: %s", req.name)
		}

		if progType := bpfProg.Type(); progType != ebpf.Kprobe {
			return fmt.Errorf("invalid BPF program %q type: %s != %s", req.name, progType, ebpf.Kprobe)
		}

		perfFD, err := ebpfutil.PerfEventOpenKprobe(req.funcName, bpfProg.FD())
		if err != nil {
			return fmt.Errorf("invalid kprobe tracepoint fd %q: %w", req.name, err)
		}
		t.fds = append(t.fds, perfFD)
	}

	return nil
}
