/*
	Package nice implements a static analysis framework for Go.
*/
package nice // import "code.justin.tv/tshadwell/nice"

import (
	"bytes"
	"errors"
	"fmt"
	"go/token"
	"os"
	"os/exec"
	"reflect"
	"sort"
	"strings"
	"testing"
	"text/template"

	"golang.org/x/tools/go/callgraph/rta"
	"golang.org/x/tools/go/loader"
	"golang.org/x/tools/go/ssa"
	"golang.org/x/tools/go/ssa/ssautil"
)

// Program lazily collects static information on a Go program.
//
// The zero value is ready to use: when Targets is empty, nice loads the
// package in the current working directory, which is the package under test
// when invoked by `go test`.
type Program struct {
	Targets []string // list of target packages or directories, defaults to current working directory
	ldcfg   *loader.Config
	ld      *loader.Program

	// Ssa holds the static single assignment (SSA) form of this program.
	// It is populated by LoadSsa.
	Ssa *Ssa
}

// Ssa wraps an *ssa.Program (the static single assignment form from
// golang.org/x/tools/go/ssa) and adds helper functions for locating values,
// types and positions, and for computing call graphs.
type Ssa struct {
	*ssa.Program

	// functions is the result slot read by Functions.
	functions map[*ssa.Function]bool

	// callgraphs caches Callgraph results keyed by root function.
	callgraphs map[*ssa.Function]*rta.Result
}

// Functions returns every function in the SSA program, as reported by
// ssautil.AllFunctions. The returned error is always nil.
//
// NOTE(review): because the receiver is a value, nothing computed here is
// stored back into the caller's Ssa. The functions field is returned as-is
// only when it is already non-nil; otherwise the set is recomputed on every
// call.
func (p Ssa) Functions() (fns map[*ssa.Function]bool, err error) {
	if cached := p.functions; cached != nil {
		return cached, nil
	}
	return ssautil.AllFunctions(p.Program), nil
}

// ValuePos returns the source position of v. When v has no valid position it
// ascends through the enclosing functions (v.Parent(), then that function's
// Parent(), and so on) and returns the first valid position found.
//
// It returns an error if v is nil or if no value in the chain has a valid
// position.
func (p Ssa) ValuePos(v ssa.Value) (pos token.Pos, err error) {
	if v == nil {
		return token.NoPos, errors.New("cannot get position value for nil ssa.Value")
	}
	for {
		if pos = v.Pos(); pos.IsValid() {
			return pos, nil
		}
		// Parent returns a concrete *ssa.Function. Compare that pointer to nil
		// before storing it in the ssa.Value interface: a typed nil inside an
		// interface compares non-nil and would be dereferenced next iteration.
		parent := v.Parent()
		if parent == nil {
			return token.NoPos, fmt.Errorf("cannot get position value for %+v", v)
		}
		v = parent
	}
}

// ValuePosition resolves v to a full token.Position: ValuePos finds the
// token.Pos, and Position converts it using the program's file set.
func (p Ssa) ValuePosition(v ssa.Value) (position token.Position, err error) {
	pos, posErr := p.ValuePos(v)
	if posErr != nil {
		return token.Position{}, posErr
	}
	return p.Position(pos)
}

// NodePos returns the source position of n. When n has no valid position it
// ascends through the enclosing functions (n.Parent(), then that function's
// Parent(), and so on) and returns the first valid position found, in the
// same way as ValuePos.
//
// It returns an error if n is nil or if no node in the chain has a valid
// position.
func (p Ssa) NodePos(n ssa.Node) (pos token.Pos, err error) {
	if n == nil {
		return token.NoPos, errors.New("cannot get position value for nil ssa.Node")
	}
	for {
		if pos = n.Pos(); pos.IsValid() {
			return pos, nil
		}
		// Check the concrete *ssa.Function for nil before storing it in the
		// ssa.Node interface; a typed nil would compare non-nil and then be
		// dereferenced on the next iteration.
		parent := n.Parent()
		if parent == nil {
			return token.NoPos, fmt.Errorf("cannot get position value for %+v", n)
		}
		n = parent
	}
}

// NodePosition resolves n to a full token.Position: NodePos finds the
// token.Pos, and Position converts it using the program's file set.
func (p Ssa) NodePosition(n ssa.Node) (position token.Position, err error) {
	pos, posErr := p.NodePos(n)
	if posErr != nil {
		return token.Position{}, posErr
	}
	return p.Position(pos)
}

// Position converts pos into a full token.Position using the file set of the
// wrapped ssa.Program. The error result is always nil; the (value, error)
// shape lets callers pass the result straight to MustPosition.
func (p Ssa) Position(pos token.Pos) (position token.Position, err error) {
	fset := p.Program.Fset
	position = fset.Position(pos)
	return position, nil
}

func MustPosition(possiblePosition token.Position, err error) (position token.Position) {
	if err != nil {
		panic(err)
	}

	return possiblePosition
}

// Callgraph runs rta.Analyze rooted at root (with call-graph construction
// enabled) and returns its result.
//
// Results are cached per root in p.callgraphs, but only when that map has
// already been allocated. Ssa uses value receivers, so a map created here
// would be lost with the receiver copy. When the map is nil the result is
// computed fresh each time instead of being written into a nil map, which
// would panic.
//
// It returns an error if root is nil.
func (p Ssa) Callgraph(root *ssa.Function) (result *rta.Result, err error) {
	if root == nil {
		return nil, errors.New("cannot compute callgraph for nil root function")
	}
	// Reading from a nil map is safe and simply misses.
	if cached, ok := p.callgraphs[root]; ok {
		return cached, nil
	}

	result = rta.Analyze([]*ssa.Function{root}, true)
	if p.callgraphs != nil {
		p.callgraphs[root] = result
	}
	return result, nil
}

// Findings is a list of Finding values. It implements sort.Interface,
// ordering findings by their String() form (positions first, then test name).
type Findings []Finding

var _ sort.Interface = Findings(nil)

// Len implements sort.Interface.
func (fl Findings) Len() int { return len(fl) }

// Less implements sort.Interface by comparing the String() forms.
func (fl Findings) Less(i, j int) bool {
	left, right := fl[i].String(), fl[j].String()
	return left < right
}

// Swap implements sort.Interface.
func (fl Findings) Swap(i, j int) { fl[i], fl[j] = fl[j], fl[i] }

/*
func (fl FindingList) Error() string {
	var names = make([]string, len(fl))
	for i, f := range fl {
		names[i] = f.TestName()
	}

	return fmt.Sprintf("multiple tests failed %s", strings.Join(names, " "))
}

*/

// FindingTemplate renders grouped findings; Findings.Format executes it.
//
// It expects a value with a LongDescription bool field and a Batch field whose
// elements embed a Describer and a Findings list. For each batch it prints the
// test name, confidence and severity; when LongDescription is true it also
// prints the test's TestDescription. For each Finding it prints the begin
// position, that position's line number, and the result of Interesting (an
// error from Interesting aborts execution).
//
// NOTE(review): no trim markers are used, so the literal's indentation and
// newlines appear verbatim in the rendered output.
var FindingTemplate = template.Must(template.New("nice-finding").Parse(`
{{if .LongDescription}}
	{{range .Batch}}
	{{.Describer.TestName}} [confidence {{.Describer.TestConfidence}}] [severity {{.Describer.TestSeverity}}]
		{{.Describer.TestDescription}}
		{{range .Findings}}
			{{with index .Position 0}}{{.}}
			{{.Line}}: {{end}}{{.Interesting}}
		{{end}}
	{{end}}
{{else}}
	{{range .Batch}}
		{{.Describer.TestName}} [confidence {{.Describer.TestConfidence}}] [severity {{.Describer.TestSeverity}}]
		{{range .Findings}}
			{{with index .Position 0}}{{.}}
			{{.Line}}: {{end}}{{.Interesting}}
		{{end}}
	{{end}}
{{end}}
`))

// MarshalText renders the findings in the short (non-long-description)
// format produced by Format.
func (fl Findings) MarshalText() ([]byte, error) {
	const longDescription = false
	return fl.Format(longDescription)
}

// Report currently has no fields and no behavior of its own.
type Report struct{}

// Format renders the findings as text using FindingTemplate. Findings are
// grouped and ordered by Sort; when longDescription is true each test's
// TestDescription is included as well. An empty list renders as
// "(no findings)".
func (fl Findings) Format(longDescription bool) (text []byte, err error) {
	if len(fl) == 0 {
		return []byte("(no findings)"), nil
	}

	input := struct {
		LongDescription bool
		Batch           []struct {
			Describer
			Findings
		}
	}{
		LongDescription: longDescription,
		Batch:           fl.Sort(),
	}

	var out bytes.Buffer
	err = FindingTemplate.Execute(&out, input)
	if err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}

// testBatch is the named form of the batch slice returned by Findings.Sort.
// It exists only so those batches can be sorted with sort.Interface.
type testBatch []struct {
	Describer
	Findings
}

var _ sort.Interface = testBatch{}

func (t testBatch) Len() int      { return len(t) }
func (t testBatch) Swap(i, j int) { t[i], t[j] = t[j], t[i] }

// Less orders batches by ascending TestSeverity and breaks ties by TestName.
// The tie-break makes this a total order, so equal-severity batches no longer
// end up in whatever order the caller happened to build them (map order in
// Findings.Sort). It panics if either batch has a nil Describer.
func (t testBatch) Less(i, j int) bool {
	a, b := t[i].Describer, t[j].Describer
	if a == nil || b == nil {
		panic(fmt.Sprintf("some tests are missing describers ?? %+v, %+v", t[i], t[j]))
	}
	if sa, sb := a.TestSeverity(), b.TestSeverity(); sa != sb {
		return sa < sb
	}
	return a.TestName() < b.TestName()
}

// Sort groups the Findings into batches by the TestName of their Describer.
// It sorts each batch in place (Findings order, i.e. by String()) and returns
// the batches ordered by ascending TestSeverity. Batches of equal severity
// are ordered by TestName, so the result is deterministic.
//
// Every Finding must have a non-nil Describer; a nil Describer panics.
func (fl Findings) Sort() (results []struct {
	Describer
	Findings
}) {
	byTest := make(map[string]Findings)
	var names []string
	for _, finding := range fl {
		name := finding.Describer.TestName()
		if _, seen := byTest[name]; !seen {
			names = append(names, name)
		}
		byTest[name] = append(byTest[name], finding)
	}

	// Walk the groups in name order rather than map order, so the output does
	// not vary from run to run.
	sort.Strings(names)

	results = make([]struct {
		Describer
		Findings
	}, 0, len(names))
	for _, name := range names {
		batch := byTest[name]
		sort.Sort(batch)
		results = append(results, struct {
			Describer
			Findings
		}{batch[0].Describer, batch})
	}

	// testBatch has the same underlying type as results, so the conversion
	// shares the backing array and sorts results in place. A stable sort keeps
	// equal-severity batches in the name order established above.
	sort.Stable(testBatch(results))

	return results
}

// Describer describes the test that produced a Finding.
type Describer interface {
	// TestName identifies the test; Findings.Sort groups findings by it.
	TestName() string
	// TestDescription explains the test; it is printed in long-format reports.
	TestDescription() string
	// TestConfidence is printed alongside the test name in reports.
	TestConfidence() Confidence
	// TestSeverity orders batches of findings (ascending) in reports.
	TestSeverity() Severity
}

// Finding is a single result reported by a test: a source span plus the
// Describer of the test that reported it.
type Finding struct {
	// Position holds the beginning (index 0) and end (index 1) of the span.
	Position [2]token.Position
	Describer
}

// Interesting returns the source text spanning f.Position[0] to
// f.Position[1]. The text is read from the file named by
// f.Position[0].Filename, using the byte offsets of the two positions.
//
// It returns an error if the end offset precedes the begin offset, or if the
// file cannot be opened or fully read.
func (f Finding) Interesting() (partial string, err error) {
	begin, end := f.Position[0].Offset, f.Position[1].Offset
	if end < begin {
		return "", fmt.Errorf("invalid finding span in %s: end offset %d before begin offset %d",
			f.Position[0].Filename, end, begin)
	}

	file, err := os.Open(f.Position[0].Filename)
	if err != nil {
		return "", err
	}
	// Read-only handle: a Close error carries no information worth reporting.
	defer file.Close()

	// ReadAt, unlike Read, reports an error instead of silently returning a
	// short read.
	text := make([]byte, end-begin)
	if _, err = file.ReadAt(text, int64(begin)); err != nil {
		return "", err
	}
	return string(text), nil
}

// String formats the finding as its begin/end positions followed by the name
// of the test that reported it.
func (f Finding) String() string {
	name := f.Describer.TestName()
	return fmt.Sprintf("%s %s", f.Position, name)
}

// Tester is a check run against a Program. NiceTest returns the Findings it
// detected, or an error if the check could not run. Program.Test fails the Go
// test when either is reported.
type Tester interface {
	NiceTest(prog *Program) (Findings, error)
}

// DefaultProgram is the shared Program used by the package-level Test
// function. Its zero value targets the package in the current working
// directory, which is loaded lazily on first use.
var DefaultProgram = Program{}

// Test runs nice against p and stops t with t.Fatal if the check errors or
// reports any findings. On findings, the failure message is the rendered
// findings.
func (p *Program) Test(t *testing.T, nice Tester) {
	err := p.test(t, nice)
	if err == nil {
		return
	}
	t.Fatal(err)
}

// test runs nice against p. It returns, in order of precedence:
//   - the error from nice.NiceTest,
//   - an error from rendering the findings, or
//   - when there are findings, an error whose message is the rendered
//     findings.
//
// t is currently unused.
func (p *Program) test(t *testing.T, nice Tester) (err error) {
	findings, err := nice.NiceTest(p)
	switch {
	case err != nil:
		return err
	case len(findings) == 0:
		return nil
	}

	text, err := findings.MarshalText()
	if err != nil {
		return err
	}
	return errors.New(string(text))
}

// Test runs nice against the given packages and fails t if it reports any
// findings or errors.
//
// With no pkgs, DefaultProgram is tested. Otherwise pkgs are resolved with
// `go list` (via Program.Import, so patterns such as "./..." work) into a
// fresh Program that is used for this call only.
func Test(t *testing.T, nice Tester, pkgs ...string) {
	t.Helper()
	if len(pkgs) == 0 {
		DefaultProgram.Test(t, nice)
		return
	}

	var p Program
	if err := p.Import(pkgs...); err != nil {
		t.Fatal(err)
	}
	// Import tolerates a failing `go list`. Without this check an empty
	// Targets would make Config fall back to the current directory, silently
	// testing the wrong package.
	if len(p.Targets) == 0 {
		t.Fatalf("nice: no packages matched %q", pkgs)
	}
	p.Test(t, nice)
}

// Errors collects several errors into a single error value.
type Errors []error

// errorStrings returns the Error() text of each element, in order.
func (c Errors) errorStrings() []string {
	out := make([]string, 0, len(c))
	for _, e := range c {
		out = append(out, e.Error())
	}
	return out
}

// Error joins the element messages with ", ".
func (c Errors) Error() string {
	return strings.Join(c.errorStrings(), ", ")
}

// String renders a "multiple errors:" header followed by one tab-indented
// message per line.
func (c Errors) String() string {
	lines := c.errorStrings()
	for i := range lines {
		lines[i] = "\t" + lines[i]
	}
	return fmt.Sprintf("multiple errors: \n%s", strings.Join(lines, "\n"))
}

// writerFunc adapts a plain function to a Write method, matching io.Writer.
type writerFunc func(p []byte) (n int, err error)

// Write forwards p to the wrapped function and returns its results.
func (w writerFunc) Write(p []byte) (n int, err error) {
	n, err = w(p)
	return n, err
}

// Config lazily builds and returns the loader configuration for p.
//
// If p.Targets is empty, the package in the current working directory is
// first added to Targets via Import("."). The configuration is cached only
// after it has been built successfully. A failed call therefore keeps
// returning its error on retry, instead of a later call handing back a
// half-initialized config with a nil error.
func (p *Program) Config() (config *loader.Config, err error) {
	if p.ldcfg != nil {
		return p.ldcfg, nil
	}

	if len(p.Targets) == 0 {
		if err = p.Import("."); err != nil {
			return nil, err
		}
	}

	cfg := new(loader.Config)
	if _, err = cfg.FromArgs(p.Targets, false); err != nil {
		return nil, err
	}

	p.ldcfg = cfg
	return cfg, nil
}

// Ast lazily loads and caches the *loader.Program for p. The loader program
// holds the files that make up this Program, their ASTs and the Go packages
// they represent.
//
// Targets default to the package in the current working directory (the
// package under test when invoked by `go test`); see Config.
func (p *Program) Ast() (program *loader.Program, err error) {
	if p.ld != nil {
		return p.ld, nil
	}

	cfg, err := p.Config()
	if err != nil {
		return nil, err
	}

	loaded, loadErr := cfg.Load()
	p.ld = loaded
	if loadErr != nil {
		return nil, loadErr
	}
	return loaded, nil
}

// LoadSsa builds the static single assignment (SSA) form of the Program from
// its loaded packages (see Ast) and stores it in p.Ssa. See
// golang.org/x/tools/go/ssa for more information on this representation.
//
// The SSA program is rebuilt on every call. If loading fails, p.Ssa is left
// unchanged.
func (p *Program) LoadSsa() (err error) {
	ld, err := p.Ast()
	if err != nil {
		return err
	}

	// callgraphs is allocated here because Ssa's methods use value receivers.
	// Callgraph cannot allocate a map that outlives its receiver copy, but it
	// can fill this shared map; writing into a nil map would panic.
	p.Ssa = &Ssa{
		Program:    ssautil.CreateProgram(ld, 0),
		callgraphs: make(map[*ssa.Function]*rta.Result),
	}
	p.Ssa.Build()

	return nil
}

type ErrAnonymousType struct{ reflect.Type }

func (e ErrAnonymousType) Error() string {
	return fmt.Sprintf("cannot Locate reference to anonymous type %s", e.Type)
}

type ErrLocatingValue struct{ reflect.Type }

func (e ErrLocatingValue) Error() string {
	return fmt.Sprintf("cannot Locate type %s (path: %s, name %s)", e.Type, e.PkgPath(), e.Name())
}

// Get runs `go get` with the given flags (args) followed by the package
// paths, then adds those paths to p.Targets via Import. Paths may use any
// syntax that `go get` and `go list` accept, including "..." patterns.
func (p *Program) Get(args []string, paths ...string) (err error) {
	cmdline := []string{"get"}
	cmdline = append(cmdline, args...)
	cmdline = append(cmdline, paths...)

	if runErr := exec.Command("go", cmdline...).Run(); runErr != nil {
		return runErr
	}
	return p.Import(paths...)
}

// Import resolves the given package paths or patterns with `go list` and
// appends each import path it prints to p.Targets.
//
// A non-zero exit from `go list` (*exec.ExitError) is intentionally not
// returned as an error: whatever packages were printed to stdout are still
// added. Presumably this lets partially matching patterns succeed; confirm
// before tightening. Any other failure, such as the go tool being
// unavailable, is returned and Targets is left unchanged.
func (p *Program) Import(paths ...string) (err error) {
	listArgs := append([]string{"list", "--"}, paths...)
	out, err := exec.Command("go", listArgs...).Output()
	if _, isExit := err.(*exec.ExitError); err != nil && !isExit {
		return err
	}

	for _, field := range bytes.Fields(out) {
		p.Targets = append(p.Targets, string(field))
	}
	return nil
}

// LocateType returns the *ssa.Type member in the SSA program that corresponds
// to the dynamic type of value. If value is a pointer, the pointed-to type is
// used. Types with an empty package path are looked up in a package with
// path "builtin".
//
// It returns:
//   - ErrAnonymousType for unnamed types;
//   - ErrLocatingValue when the package is not in the program, or has no
//     type member with that name;
//   - an error when value is nil.
func (p Ssa) LocateType(value interface{}) (typ *ssa.Type, err error) {
	t := reflect.TypeOf(value)
	if t == nil {
		return nil, errors.New("cannot Locate type of nil value")
	}
	// Look through one level of pointer, as reflect.Indirect would. Working on
	// the type rather than the value means typed nil pointers are accepted.
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}

	path, name := t.PkgPath(), t.Name()
	if name == "" {
		return nil, ErrAnonymousType{t}
	}
	if path == "" {
		// NOTE(review): presumably meant for predeclared types such as
		// 'string'; confirm the SSA program actually contains a "builtin"
		// package (if not, ErrLocatingValue is returned below).
		path = "builtin"
	}

	pkg := p.ImportedPackage(path)
	if pkg == nil {
		return nil, ErrLocatingValue{t}
	}
	// The comma-ok assertion covers both a missing member (nil) and a member
	// that is a function, global or constant rather than a type.
	found, ok := pkg.Members[name].(*ssa.Type)
	if !ok {
		return nil, ErrLocatingValue{t}
	}
	return found, nil
}
