package main

import (
	"context"
	"flag"
	"fmt"
	"go/build"
	"go/types"
	"path/filepath"
	"runtime"
	"strings"

	"github.com/golang/glog"
	"github.com/google/subcommands"

	"code.justin.tv/tshadwell/nice"
	"code.justin.tv/tshadwell/nice/naiive"
)

// TestList is the flag.Value behind -test-only: the tests the user explicitly
// selected. Its string form is the space-separated list of the tests' short
// names.
type TestList naiive.Tests

// String returns the ShortName of every test in t, separated by single spaces.
func (t TestList) String() string {
	var b strings.Builder
	for i, test := range t {
		if i > 0 {
			b.WriteByte(' ')
		}
		b.WriteString(test.ShortName)
	}
	return b.String()
}

// Set parses names as a whitespace-separated list of test short names,
// resolves each one with naiive.TestByShortName and appends it to tl, so
// repeated -test-only flags accumulate.
//
// Runs of spaces, tabs or leading/trailing whitespace are tolerated rather than
// being treated as empty test names. An argument that contains no names at all
// is an error.
func (tl *TestList) Set(names string) (err error) {
	shortNames := strings.Fields(names)
	if len(shortNames) == 0 {
		return fmt.Errorf("no test names given")
	}

	for _, shortName := range shortNames {
		var t naiive.Test
		if t, err = naiive.TestByShortName(shortName); err != nil {
			return err
		}

		*tl = append(*tl, t)
	}

	return nil
}

// ScanCommand implements the "scan" subcommand. It loads the target packages,
// runs the selected naiive tests over them and prints the findings.
//
// Fields (populated by SetFlags/setFlags and Execute):
//   - Targets: positional package arguments (import paths or filesystem
//     paths); execute substitutes "." when none are given.
//   - LongDescriptions: -explain; passed to findings.Format.
//   - IncludeStdLib: -stdlib; forced on by execute when the first target
//     resolves to a directory inside GOROOT.
//   - TestOnly: -test-only; when nil, execute selects the default tests
//     filtered by MinConfidence and MinSeverity.
//   - FailOnError: -error-averse; the program config's AllowErrors is set to
//     its negation.
//   - Get / GetLatest: -get / -get-latest; fetch targets via program.Get,
//     adding "-u" for GetLatest.
//   - MinSeverity / MinConfidence: -sev-min / -conf-min; initialised to the
//     medium levels in SetFlags.
type ScanCommand struct {
	Targets          []string
	LongDescriptions bool
	IncludeStdLib    bool
	TestOnly         TestList
	FailOnError      bool
	Get              bool
	GetLatest        bool

	MinSeverity   nice.Severity
	MinConfidence nice.Confidence
}

func init() {
	subcommands.Register(&ScanCommand{}, "")
}

// Name returns the word that selects this subcommand on the command line.
func (ScanCommand) Name() string {
	const commandName = "scan"
	return commandName
}
// Synopsis returns the one-line summary shown in the subcommand listing.
func (s *ScanCommand) Synopsis() string {
	const summary = "search packages and their dependencies for security issues"
	return summary
}

// Invocation returns the usage line for the scan command, whose positional
// arguments are package patterns.
func (s *ScanCommand) Invocation() string {
	const positional = "[packages]"
	return invocationReference(s, "", positional)
}

// Usage returns the long help text: the invocation line followed by a short
// description of what is scanned and how targets are interpreted.
func (s *ScanCommand) Usage() string {
	const description = `
   Recursively scans packages for common security vulnerabilities. When unspecified, target defaults to the current working directory.

   Targets use standard package argument syntax; they can be import paths or filesystem paths.

`
	return s.Invocation() + description
}

// setFlags registers the boolean flags that control error tolerance, stdlib
// results and fetching targets with 'go get'. All default to false.
func (s *ScanCommand) setFlags(fs *flag.FlagSet) {
	boolFlags := []struct {
		dst   *bool
		name  string
		usage string
	}{
		{&s.FailOnError, "error-averse", "halt and print errors if files could not be acquired due to errors"},
		{&s.IncludeStdLib, "stdlib", "also look at files that are part of the standard Go libraries"},
		{&s.Get, "get", "attempt to acquire arguments via 'go get' beforehand"},
		{&s.GetLatest, "get-latest", "attempt to acquire arguments via 'go get -u' beforehand"},
	}

	for _, bf := range boolFlags {
		fs.BoolVar(bf.dst, bf.name, false, bf.usage)
	}
}

// SetFlags registers every scan flag on fs: the boolean flags from setFlags,
// -explain, and the -sev-min, -conf-min and -test-only filters.
//
// MinSeverity and MinConfidence are set to their medium levels before they are
// registered. Those are therefore the values used when the flags are omitted,
// and the defaults quoted in the help text.
func (s *ScanCommand) SetFlags(fs *flag.FlagSet) {
	s.setFlags(fs)
	fs.BoolVar(&s.LongDescriptions, "explain", false, "print out additional information on what the output means")

	// The back-quoted word in each usage string becomes the flag's argument
	// name in -help output (flag.UnquoteUsage). The list separators are padded
	// with spaces, matching the other quotedListify calls.
	s.MinSeverity = nice.MediumSeverity
	fs.Var(
		&s.MinSeverity,
		"sev-min",
		fmt.Sprintf(
			"exclude tests not meeting at least this `severity`, where severity can be %s (default: %s)",
			quotedListify(" or ", s.MinSeverity.AllowedValues()),
			s.MinSeverity.String(),
		),
	)

	s.MinConfidence = nice.MediumConfidence
	fs.Var(
		&s.MinConfidence,
		"conf-min",
		fmt.Sprintf(
			"exclude tests not meeting at least this `confidence`, where confidence can be %s (default: %s)",
			quotedListify(" or ", s.MinConfidence.AllowedValues()),
			s.MinConfidence.String(),
		),
	)

	// Without -test-only, execute runs the default tests filtered by
	// -sev-min and -conf-min, not every test, so the help says exactly that.
	fs.Var(
		&s.TestOnly,
		"test-only",
		fmt.Sprintf(
			"run only the specified space-separated tests from the default set, where tests can be %s (default: every default test meeting -sev-min and -conf-min)",
			quotedListify(" and / or ", naiive.TestShortNames()),
		),
	)
}

// Execute is the entry point invoked by the subcommands dispatcher. The
// positional arguments become the scan targets. Any error from execute is
// printed to standard output and reported as a failure exit status. Errors
// that implement fmt.Stringer are printed with String(); all others get an
// "error: " prefix.
func (s *ScanCommand) Execute(ctx context.Context, fs *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
	s.Targets = fs.Args()

	err := s.execute(ctx, fs, args...)
	if err == nil {
		return subcommands.ExitSuccess
	}

	if stringer, ok := err.(fmt.Stringer); ok {
		fmt.Println(stringer.String())
	} else {
		fmt.Printf("error: %s\n", err.Error())
	}
	return subcommands.ExitFailure
}

// NoPackageError reports that none of the requested targets resolved to a
// package.
type NoPackageError struct {
	// PackageNames are the targets that were requested.
	PackageNames []string
}

// Error returns the requested targets, space-joined and quoted.
func (n NoPackageError) Error() string {
	return fmt.Sprintf("could not find packages with %+q", strings.Join(n.PackageNames, " "))
}

// String is the user-facing form (Execute prints it for fmt.Stringer errors).
// It adds a hint to retry with a recursive "/..." pattern derived from the
// first target and ends with a newline. No hint is given when there are no
// targets or the first one is already a "..." pattern.
func (n NoPackageError) String() string {
	if len(n.PackageNames) == 0 {
		return n.Error() + "\n"
	}

	name := n.PackageNames[0]
	if strings.HasSuffix(name, "...") {
		// Already recursive; suggesting "x/.../..." would not help.
		return n.Error() + "\n"
	}

	// filepath.Join cleans its result, which drops a leading "./". That would
	// turn "." into "..." (every package) and "./foo" into the import path
	// "foo/...", so re-add the prefix when the target was a relative
	// directory. ToSlash keeps the pattern in the slash form that
	// build.IsLocalImport recognises.
	hint := filepath.ToSlash(filepath.Join(name, "..."))
	if build.IsLocalImport(filepath.ToSlash(name)) && !build.IsLocalImport(hint) {
		hint = "./" + hint
	}

	return fmt.Sprintf("%s. you might want to try using %s to scan subpackages instead\n", n.Error(), hint)
}

// execute performs the scan on behalf of Execute and returns the first error
// encountered.
//
//  1. If -test-only was not given, select naiive.DefaultTests narrowed by
//     s.MinConfidence and s.MinSeverity (the -1 upper bound presumably means
//     "unbounded" — confirm in naiive).
//  2. Default s.Targets to "." when no package arguments were given.
//  3. Load the targets with program.Get when -get or -get-latest is set
//     (-get-latest adds "-u"), otherwise with program.Import. Return a
//     NoPackageError if that leaves no targets.
//  4. Configure the program: tolerate errors unless -error-averse is set, and
//     have the type checker ignore function bodies and unused imports.
//  5. If the first target's directory is GOROOT or lies below it, enable
//     stdlib results.
//  6. Run the tests and print the formatted findings to standard output.
//
// ctx, f and args are accepted to mirror Execute and are not used.
func (s *ScanCommand) execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) (err error) {
	if s.TestOnly == nil {
		s.TestOnly = TestList(naiive.DefaultTests.ConfidenceRange(s.MinConfidence, -1).SeverityRange(s.MinSeverity, -1))
	}

	if glog.V(2) {
		testNames := make([]string, 0, len(s.TestOnly))
		for _, test := range s.TestOnly {
			testNames = append(testNames, test.Name)
		}

		glog.V(2).Infof("running %d tests %s", len(testNames), quotedListify(" and ", testNames))
		glog.V(4).Infof("%+v", s.TestOnly)
	}

	if len(s.Targets) < 1 {
		s.Targets = []string{"."}
	}

	var program = nice.Program{}

	switch {
	case s.GetLatest, s.Get:
		var getArgs []string
		if s.GetLatest {
			getArgs = []string{"-u"} // update
		}
		if err = program.Get(getArgs, s.Targets...); err != nil {
			return err
		}
	default:
		if err = program.Import(s.Targets...); err != nil {
			return err
		}
	}

	if len(program.Targets) == 0 {
		return NoPackageError{PackageNames: s.Targets}
	}

	pkg, err := build.Import(program.Targets[0], ".", build.FindOnly)
	if err != nil {
		return err
	}

	cfg, err := program.Config()
	if err != nil {
		return err
	}

	cfg.AllowErrors = !s.FailOnError
	cfg.TypeChecker = types.Config{
		IgnoreFuncBodies:         true,
		DisableUnusedImportCheck: true,
	}

	// If they've targeted a stdlib package directory, report stdlib results.
	// The check requires a path-separator boundary: a bare string-prefix test
	// would also match siblings such as "/usr/local/gopath" for GOROOT
	// "/usr/local/go". An empty GOROOT would match every directory, so it is
	// skipped.
	if goroot := runtime.GOROOT(); goroot != "" {
		goroot = filepath.Clean(goroot)
		dir := filepath.Clean(pkg.Dir)
		if dir == goroot || strings.HasPrefix(dir, goroot+string(filepath.Separator)) {
			glog.V(1).Infoln("looks like we're scanning the stdlib; enabling results from stdlib")
			s.IncludeStdLib = true
		}
	}

	findings, err := (&naiive.TestConfig{IncludeStdLib: s.IncludeStdLib, Tests: naiive.Tests(s.TestOnly)}).NiceTest(&program)
	if err != nil {
		return err
	}

	text, err := findings.Format(s.LongDescriptions)
	if err != nil {
		return err
	}

	_, err = fmt.Printf("%s\n", text)
	return err
}
