package main

import (
	"bufio"
	"flag"
	"fmt"
	"io"
	"io/fs"
	"log"
	"os"
	"path/filepath"
	"strings"

	"google.golang.org/protobuf/proto"

	"a.yandex-team.ru/security/libs/go/archer"
	"a.yandex-team.ru/security/xray/internal/spring4shell-finder/spring4shellfinder_out"
	"a.yandex-team.ru/security/xray/internal/spring4shell-finder/vuln"
	"a.yandex-team.ru/security/xray/pkg/xrayrpc"
)

// Pkg holds the module name and version extracted from a jar manifest.
type Pkg struct {
	Name, Version string
}

var (
	// subDirExcludes lists directory base names that are never descended
	// into, wherever they appear in the walked tree.
	subDirExcludes = map[string]struct{}{
		".arc":         {},
		".git":         {},
		".hg":          {},
		".svn":         {},
		".venv":        {},
		".venv3":       {},
		"node_modules": {},
	}

	// dirExcludes lists directory paths that are skipped entirely; they are
	// compared against the full path produced by the walk.
	dirExcludes = map[string]struct{}{
		"/bin":  {},
		"/dev":  {},
		"/proc": {},
		"/run":  {},
		"/sys":  {},
	}

	// targetModules lists the manifest module names that identify a
	// spring-beans jar.
	targetModules = map[string]struct{}{
		"spring-beans": {},
		"spring.beans": {},
	}
)

// fatalf writes msg, formatted with a and prefixed with the tool name, plus a
// trailing newline to stderr, then terminates the process with exit status 1.
// It never returns to the caller.
func fatalf(msg string, a ...interface{}) {
	const prefix = "spring4shell-finder: "

	// Best effort: the process exits whether or not the write succeeds.
	_, _ = fmt.Fprintf(os.Stderr, prefix+msg+"\n", a...)
	os.Exit(1)
}

// main scans the -target directory tree for .jar and .war archives that carry
// Spring beans version information. It writes the findings to stdout as a
// marshaled spring4shellfinder_out.Result protobuf.
//
// By default only archives that vuln.IsVulnerable reports as vulnerable are
// included. With -collect-only, every archive with version information is
// included. Nothing is written when there are no findings.
//
// Directories are not descended into when their base name is in
// subDirExcludes or their full walk path is in dirExcludes. Unreadable
// directories below the target are logged and skipped, and so are archives
// that fail to parse. Failing to stat or read the target itself is fatal.
func main() {
	var (
		target      string
		collectOnly bool
	)

	flag.StringVar(&target, "target", "/", "target dir")
	flag.BoolVar(&collectOnly, "collect-only", false, "only collect list of potentially vulnerable jars")
	flag.Parse()

	var badJars []*xrayrpc.Spring4ShellFinderFindingDetail_Spring4ShellInfo
	err := filepath.WalkDir(target, func(osPathname string, de fs.DirEntry, err error) error {
		if err != nil {
			if de == nil || osPathname == target {
				// The target itself is inaccessible: there is nothing
				// meaningful to scan, so abort the walk.
				return err
			}
			// Keep scanning past unreadable entries (e.g. permission denied
			// when walking "/") instead of discarding every finding.
			log.Printf("skipping %q: %v", osPathname, err)
			return nil
		}

		if de.IsDir() {
			if _, skip := subDirExcludes[de.Name()]; skip {
				return filepath.SkipDir
			}
			if _, skip := dirExcludes[osPathname]; skip {
				return filepath.SkipDir
			}
			return nil
		}

		switch filepath.Ext(de.Name()) {
		case ".jar", ".war":
		default:
			return nil
		}

		javaInfo, err := checkJarFile(osPathname)
		if err != nil {
			log.Printf("failed to check jar %q: %v", osPathname, err)
		}

		if javaInfo == nil {
			return nil
		}
		if !collectOnly && !vuln.IsVulnerable(javaInfo) {
			return nil
		}
		badJars = append(badJars, javaInfo)
		return nil
	})
	if err != nil {
		fatalf("fail: %v", err)
	}

	if len(badJars) == 0 {
		// Nothing found: exit successfully without writing any output.
		return
	}

	out, err := proto.Marshal(&spring4shellfinder_out.Result{
		Issue: &xrayrpc.Spring4ShellFinderFindingDetail{
			Jars: badJars,
		},
	})
	if err != nil {
		fatalf("marshal failed: %v", err)
	}

	// A failed write would otherwise look like "no findings" to the consumer.
	if _, err := os.Stdout.Write(out); err != nil {
		fatalf("write result: %v", err)
	}
}

// checkJarFile opens the archive at jarPath with archer.OpenFile and inspects
// it with checkJar. The file is closed before returning; a close error is
// ignored.
func checkJarFile(jarPath string) (*xrayrpc.Spring4ShellFinderFindingDetail_Spring4ShellInfo, error) {
	file, openErr := archer.OpenFile(jarPath)
	if openErr != nil {
		return nil, openErr
	}
	defer func() {
		_ = file.Close()
	}()

	info, checkErr := checkJar(file, jarPath)
	return info, checkErr
}

// checkJar walks the archive read from r, looking for Spring beans version
// information. It returns that information, or (nil, nil) when none is found.
//
// jarPath is used as the Path of every result, and its extension selects the
// archer walker. Embedded spring-beans jars are checked recursively with the
// same jarPath, so they are reported under the outer archive's path. They are
// also walked with the walker chosen for that path.
//
// Sources, in order of precedence:
//  1. an embedded BOOT-INF/lib or WEB-INF/lib spring-beans-*.jar;
//  2. META-INF/notice.txt, but only when the archive contains
//     org/springframework/beans/package-info.class and the manifest did not
//     identify a target module;
//  3. META-INF/MANIFEST.MF, when it names one of targetModules.
func checkJar(r archer.SizeReader, jarPath string) (*xrayrpc.Spring4ShellFinderFindingDetail_Spring4ShellInfo, error) {
	ext := filepath.Ext(jarPath)
	walker, ok := archer.Walkers[ext]
	if !ok {
		// A missing key would yield the map's zero value, which presumably
		// cannot Walk; report an error instead of risking a panic.
		return nil, fmt.Errorf("unsupported archive type %q", ext)
	}

	var (
		notice, manifest, embed *xrayrpc.Spring4ShellFinderFindingDetail_Spring4ShellInfo
		haveBeans               bool
	)

	err := walker.Walk(r,
		archer.FileWalkOpts{
			// NOTE(review): Once presumably stops after the first match per
			// pattern — confirm against archer.
			Once: true,
			Patterns: []archer.WalkPattern{
				{
					ID:      0,
					Marker:  "notice.txt",
					Pattern: "META-INF/notice.txt",
				},
				{
					ID:      1,
					Marker:  "MANIFEST.MF",
					Pattern: "META-INF/MANIFEST.MF",
				},
				{
					ID:      2,
					Marker:  "package-info.class",
					Pattern: "org/springframework/beans/package-info.class",
				},
				{
					ID:      3,
					Marker:  "spring-beans",
					Pattern: "BOOT-INF/lib/spring-beans-*.jar",
				},
				{
					ID:      4,
					Marker:  "spring-beans",
					Pattern: "WEB-INF/lib/spring-beans-*.jar",
				},
			},
		},
		func(targetPath string, id int, reader archer.SizeReader) (err error) {
			switch id {
			case 0:
				notice, err = checkNotice(jarPath, reader)
			case 1:
				manifest, err = checkManifest(jarPath, reader)
			case 2:
				// The beans package is present in this archive itself; this
				// makes the notice.txt version trustworthy below.
				haveBeans = true
			case 3, 4:
				embed, err = checkJar(reader, jarPath)
			default:
				err = fmt.Errorf("unexpected pattern id: %d", id)
			}
			return err
		},
	)
	if err != nil {
		return nil, err
	}

	if embed != nil {
		return embed, nil
	}

	// Spring's notice.txt is not specific to spring-beans, so only use it when
	// the beans package is known to be in this archive.
	if haveBeans && manifest == nil && notice != nil {
		return notice, nil
	}

	return manifest, nil
}

// checkNotice scans a notice.txt stream for the first line that starts with
// "Spring Framework ". It reports the rest of that line, whitespace-trimmed, as
// the version for jarPath. It returns (nil, nil) when no such line exists. If
// reading the stream fails, it returns the scanner's error.
func checkNotice(jarPath string, reader io.Reader) (*xrayrpc.Spring4ShellFinderFindingDetail_Spring4ShellInfo, error) {
	const prefix = "Spring Framework "

	scanner := bufio.NewScanner(reader)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, prefix) {
			continue
		}

		return &xrayrpc.Spring4ShellFinderFindingDetail_Spring4ShellInfo{
			Path:    jarPath,
			Type:    "notice",
			Name:    "Spring Framework",
			Version: strings.TrimSpace(strings.TrimPrefix(line, prefix)),
		}, nil
	}

	// Surface read failures (e.g. bufio.ErrTooLong) instead of reporting them
	// as "no notice found", matching parseManifest.
	return nil, scanner.Err()
}

// checkManifest reads a MANIFEST.MF stream with parseManifest. When the
// declared module name is listed in targetModules, it reports that name and
// version for jarPath. Any other module yields (nil, nil). Parse failures are
// returned wrapped.
func checkManifest(jarPath string, reader io.Reader) (*xrayrpc.Spring4ShellFinderFindingDetail_Spring4ShellInfo, error) {
	pkg, parseErr := parseManifest(reader)
	if parseErr != nil {
		return nil, fmt.Errorf("unable to parse manifest: %w", parseErr)
	}

	if _, wanted := targetModules[pkg.Name]; wanted {
		return &xrayrpc.Spring4ShellFinderFindingDetail_Spring4ShellInfo{
			Path:    jarPath,
			Type:    "manifest",
			Name:    pkg.Name,
			Version: pkg.Version,
		}, nil
	}

	return nil, nil
}

// parseManifest extracts the module name and version from a JAR manifest.
//
// The name comes from Implementation-Title or Automatic-Module-Name; whichever
// appears last wins. The version comes from Implementation-Version.
//
// Under the JAR File Specification, a line beginning with a single space
// continues the previous header, because headers are wrapped at 72 bytes.
// Such lines are joined back onto their header before parsing, so wrapped
// values are read in full. Logical lines without a ':' are ignored.
//
// If reading fails, the scanner's error is returned together with whatever
// was parsed so far.
func parseManifest(r io.Reader) (Pkg, error) {
	var pkg Pkg

	// apply records one logical "Key: Value" header into pkg.
	apply := func(header string) {
		parts := strings.SplitN(header, ":", 2)
		if len(parts) != 2 {
			return
		}

		key, value := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
		switch key {
		case "Implementation-Title", "Automatic-Module-Name":
			pkg.Name = value
		case "Implementation-Version":
			pkg.Version = value
		}
	}

	var header string
	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, " ") {
			// Continuation: drop the single leading space and append.
			header += line[1:]
			continue
		}
		apply(header)
		header = line
	}
	apply(header)

	return pkg, scanner.Err()
}
