package main

import (
	"encoding/xml"
	"flag"
	"fmt"
	"io/fs"
	"io/ioutil"
	"log"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"google.golang.org/protobuf/proto"

	"a.yandex-team.ru/security/libs/go/archer"
	"a.yandex-team.ru/security/xray/internal/log4j-finder/log4jfinder_out"
	"a.yandex-team.ru/security/xray/pkg/xrayrpc"
	"a.yandex-team.ru/security/yadi/libs/versionarium"
)

// PomProject holds the only part of a Maven pom.xml this tool decodes: the
// version declared inside the <parent> element. The document's root element
// must be <project>; encoding/xml rejects any other root because of the
// XMLName tag.
type PomProject struct {
	XMLName xml.Name `xml:"project"`

	// Parent captures <project><parent>; only its <version> child is kept.
	Parent struct {
		XMLName xml.Name `xml:"parent"`
		Version string   `xml:"version"`
	} `xml:"parent"`
}

var (
	// subDirExcludes lists directory base names that the walk never descends
	// into, wherever they appear in the tree (VCS metadata, virtualenvs,
	// node_modules).
	subDirExcludes = map[string]struct{}{
		".arc":         {},
		".git":         {},
		".hg":          {},
		".svn":         {},
		".venv":        {},
		".venv3":       {},
		"node_modules": {},
	}

	// dirExcludes lists exact walk paths that are skipped. They are compared
	// verbatim against the path produced by filepath.WalkDir, so they only
	// match when the scan reaches them through these exact spellings (e.g.
	// when the target is "/").
	dirExcludes = map[string]struct{}{
		"/dev":  {},
		"/proc": {},
		"/sys":  {},
		"/run":  {},
		"/bin":  {},
	}
)

// versOnce guards the one-time, lazy parsing of vulnRange performed by
// isVulnerableVersion.
var versOnce sync.Once

// vulnRange is the parsed range of versions treated as vulnerable. It is
// only meaningful after versOnce has run.
var vulnRange versionarium.VersionRange

// fatalf writes a "log4j-finder: "-prefixed, newline-terminated message to
// stderr, formatting msg with a as fmt.Printf would, and then terminates the
// process with exit status 1. It never returns.
func fatalf(msg string, a ...interface{}) {
	format := "log4j-finder: " + msg + "\n"
	// Best effort: the process exits immediately afterwards, so a failed
	// write to stderr could not be reported anywhere.
	_, _ = fmt.Fprintf(os.Stderr, format, a...)
	os.Exit(1)
}

// main scans the -target directory tree for jars bundling log4j-core and
// writes the findings to stdout as a marshalled log4jfinder_out.Result
// protobuf. Nothing is written when no jar matches. With -collect-only, every
// jar recognised by checkJar is reported regardless of its version.
func main() {
	var (
		target      string
		collectOnly bool
	)

	flag.StringVar(&target, "target", "/", "target dir")
	flag.BoolVar(&collectOnly, "collect-only", false, "only collect list of potentially vulnerable jars")
	flag.Parse()

	badJars, err := findLog4jJars(target, collectOnly)
	if err != nil {
		fatalf("fail: %v", err)
	}

	if len(badJars) == 0 {
		// that's fine
		return
	}

	out, err := proto.Marshal(&log4jfinder_out.Result{
		Issue: &xrayrpc.Log4JFinderIssueDetail{
			Jars: badJars,
		},
	})
	if err != nil {
		// fatalf appends the trailing newline itself and never returns.
		fatalf("marshal failed: %v", err)
	}

	// The protobuf on stdout is the tool's only result: failing to emit it
	// must not look like a clean, empty scan.
	if _, err := os.Stdout.Write(out); err != nil {
		fatalf("write result: %v", err)
	}
}

// findLog4jJars walks root and returns every jar that checkJar recognises as
// bundling log4j-core. Unless collectOnly is set, only jars whose version is
// accepted by isVulnerableVersion are kept.
//
// Directories whose base name is in subDirExcludes, or whose walk path is in
// dirExcludes, are not descended into. Jars that cannot be inspected and
// directories that cannot be read are logged and skipped instead of aborting
// the scan; the only error returned is a failure to stat root itself.
func findLog4jJars(root string, collectOnly bool) ([]*xrayrpc.Log4JFinderIssueDetail_Log4JInfo, error) {
	var found []*xrayrpc.Log4JFinderIssueDetail_Log4JInfo
	err := filepath.WalkDir(root, func(osPathname string, de fs.DirEntry, err error) error {
		if err != nil {
			if de == nil {
				// WalkDir passes a nil entry only when root itself could not
				// be stat'ed: there is nothing to scan, so give up.
				return err
			}
			// A directory could not be read (e.g. permission denied while
			// scanning "/" as an unprivileged user). Report it and keep
			// walking rather than discarding everything found so far.
			log.Printf("failed to read %q: %v\n", osPathname, err)
			return nil
		}

		if de.IsDir() {
			if _, skip := subDirExcludes[de.Name()]; skip {
				return filepath.SkipDir
			}
			if _, skip := dirExcludes[osPathname]; skip {
				return filepath.SkipDir
			}
			return nil
		}

		if !strings.HasSuffix(de.Name(), ".jar") {
			return nil
		}

		info, err := checkJar(osPathname)
		if err != nil {
			// One broken jar must not stop the scan.
			log.Printf("failed to check jar %q: %v\n", osPathname, err)
		}
		if info == nil {
			return nil
		}
		if !collectOnly && !isVulnerableVersion(info.Version) {
			return nil
		}
		found = append(found, info)
		return nil
	})
	if err != nil {
		return nil, err
	}
	return found, nil
}

// checkJar inspects the jar at jarPath for an embedded log4j-core.
//
// It looks for two entries inside the archive:
//   - org/apache/logging/log4j/core/lookup/JndiLookup.class: the marker
//     class; a jar without it is never reported;
//   - META-INF/maven/org.apache.logging.log4j/log4j-core/pom.xml: the source
//     of the reported version, read from its <parent><version> element.
//
// It returns (nil, nil) when the marker class is absent. A non-nil error means
// the jar could not be walked or its pom.xml could not be read or parsed.
//
// NOTE(review): when the class is present but the pom.xml is not (as in some
// shaded jars), this also returns (nil, nil), so such jars are silently not
// reported. Confirm that is intended.
func checkJar(jarPath string) (*xrayrpc.Log4JFinderIssueDetail_Log4JInfo, error) {
	var out *xrayrpc.Log4JFinderIssueDetail_Log4JInfo
	haveClass := false
	err := archer.Walkers[".jar"].FileWalk(
		jarPath,
		archer.FileWalkOpts{
			// NOTE(review): presumably limits each pattern to a single match.
			// Confirm against archer.FileWalkOpts.
			Once: true,
			Patterns: []archer.WalkPattern{
				{
					ID:      0,
					Marker:  "JndiLookup.class",
					Pattern: "org/apache/logging/log4j/core/lookup/JndiLookup.class",
				},
				{
					ID:      1,
					Marker:  "pom.xml",
					Pattern: "META-INF/maven/org.apache.logging.log4j/log4j-core/pom.xml",
				},
			},
		},
		func(targetPath string, id int, reader archer.SizeReader) error {
			if id == 0 {
				haveClass = true
				return nil
			}

			data, err := ioutil.ReadAll(reader)
			if err != nil {
				return fmt.Errorf("can't read pom.xml: %w", err)
			}

			var prj PomProject
			if err := xml.Unmarshal(data, &prj); err != nil {
				return fmt.Errorf("can't parse pom.xml: %w", err)
			}

			out = &xrayrpc.Log4JFinderIssueDetail_Log4JInfo{
				Path: jarPath,
				// encoding/xml keeps the element's text verbatim, so a
				// pretty-printed <version> carries surrounding whitespace that
				// would otherwise break version parsing downstream.
				Version: strings.TrimSpace(prj.Parent.Version),
			}
			return nil
		},
	)

	if err != nil {
		return nil, err
	}

	if !haveClass {
		return nil, nil
	}

	return out, nil
}

// isVulnerableVersion reports whether ver, parsed as a "java" version by
// versionarium, falls inside the range [2.0.0, 2.15.0). A version that cannot
// be parsed is reported as not vulnerable.
//
// The range is parsed once, on first use, and cached in vulnRange; a failure
// to parse it is a programming error and panics.
//
// NOTE(review): the upper bound 2.15.0 leaves out later log4j advisories
// (e.g. CVE-2021-45046, fixed in 2.16.0). Confirm the intended policy.
func isVulnerableVersion(ver string) bool {
	parsed, parseErr := versionarium.NewVersion("java", ver)
	if parseErr != nil {
		return false
	}

	versOnce.Do(func() {
		var rangeErr error
		if vulnRange, rangeErr = versionarium.NewRange("java", "[2.0.0, 2.15.0)"); rangeErr != nil {
			panic(fmt.Sprintf("can't parse vulns range: %v", rangeErr))
		}
	})

	return vulnRange.Check(parsed)
}
