package vulnparser

import (
	"bytes"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"net/http"
	"os"
	"os/exec"
	"path"
	"strings"

	"github.com/go-resty/resty/v2"
	"golang.org/x/mod/module"

	"a.yandex-team.ru/library/go/certifi"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/security/libs/go/simplelog"
	"a.yandex-team.ru/security/yadi/libs/versionarium"
	"a.yandex-team.ru/security/yadi/snatcher/pkg/feed"
)

// Compile-time check that *GolangParser satisfies the Parser interface.
var _ Parser = (*GolangParser)(nil)

// ErrBlacklisted is the sentinel returned by GolangParser.Parse when the
// vulnerability's package matches an entry of pkgBlacklist.
var ErrBlacklisted = xerrors.NewSentinel("blacklisted pkgs")

// postV1Err is the stderr fragment that runGoList searches for in failed
// `go list` output; the text up to the following double quote is used as the
// module path for the retry.
var postV1Err = []byte(`invalid version: go.mod has post-v1 module path "`)

// pkgBlacklist lists package paths that Parse rejects. An entry matches the
// package itself and anything nested under it ("<entry>/...").
var pkgBlacklist = []string{
	"github.com/golang/go",
}

// GolangParser converts Snyk vulnerabilities for Go packages into Yadi feed
// vulnerabilities.
type GolangParser struct {
	// pkgsCache holds paths that findPkg has already confirmed against the
	// proxy; findPkg appends to it without any locking.
	pkgsCache []string
	// httpc is the HTTP client used for module proxy requests; it is
	// configured in NewGolangParser.
	httpc *resty.Client
}

// goVerRange is one range: a list of comparator/version constraints that
// buildGoRanges joins with single spaces.
type goVerRange []goVersion

// goVersion is a single parsed constraint of a range.
type goVersion struct {
	// Hash is the constraint operand: a lowercased 40-character hash, the
	// wildcard "*", or a version string normalized by versionarium.
	// hashRangesToPkgVersions later overwrites it with the version reported
	// by `go list`.
	Hash string
	// Comparator is the operator text collected from '<', '>' and '='
	// characters; it is empty when the operand had no operator.
	Comparator string
}

// goRawVersion accumulates the runes of one constraint while parseGoRanges
// scans a raw range string; goVersion() turns it into a goVersion.
type goRawVersion struct {
	// comparator collects operator runes ('<', '>', '=').
	comparator []rune
	// hash collects the operand runes (hash or version characters).
	hash []rune
}

// goListResponse is the subset of `go list -json -m` output that runGoList
// decodes.
type goListResponse struct {
	// Version is the module version reported for the requested query.
	Version string
}

// NewGolangParser builds a GolangParser whose HTTP client targets
// https://proxy.golang.org and retries failed requests up to 3 times.
//
// The system certificate pool is installed as the TLS root set when it can be
// loaded; otherwise the failure is logged and the client keeps its default TLS
// configuration.
//
// NOTE(review): no request timeout is configured on the client here.
func NewGolangParser() *GolangParser {
	parser := &GolangParser{
		httpc: resty.New(),
	}
	parser.httpc.SetBaseURL("https://proxy.golang.org")
	parser.httpc.SetRetryCount(3)

	if certPool, err := certifi.NewCertPoolSystem(); err == nil {
		parser.httpc.SetTLSClientConfig(&tls.Config{RootCAs: certPool})
	} else {
		simplelog.Error("failed to configure TLS cert pool", "err", err)
	}

	return parser
}

// Parse fills yadiVuln.Package and yadiVuln.VulnerableVersions from snykVuln.
//
// It returns ErrBlacklisted (with a frame attached) when the package equals,
// or is nested under, an entry of pkgBlacklist.
//
// Version ranges come from snykVuln.VulnerableVersions. The parser switches to
// snykVuln.HashesRange, resolved to module versions via `go list`, when either
//   - the version ranges fail to parse or yield nothing (this is logged), or
//   - the first parsed constraint is the wildcard "*" and hash ranges exist.
//
// An error is returned if hash ranges cannot be resolved or if no vulnerable
// versions remain afterwards. yadiVuln.Package is set before any of these
// later failures can occur.
func (p *GolangParser) Parse(yadiVuln *feed.Vulnerability, snykVuln Vulnerability) error {
	for _, blacklisted := range pkgBlacklist {
		if snykVuln.Package == blacklisted || strings.HasPrefix(snykVuln.Package, blacklisted+"/") {
			return ErrBlacklisted.WithFrame()
		}
	}

	// TODO(buglloc): revert with own proxy
	//pkgName, err := p.findPkg(snykVuln.Package)
	//if err != nil {
	//	return err
	//}

	yadiVuln.Package = snykVuln.Package
	vulnerableVersions, err := parseGoRanges(fixupGoRanges(snykVuln.VulnerableVersions))

	parseFailed := err != nil || len(vulnerableVersions) == 0
	if parseFailed {
		simplelog.Error("failed to parse vulnerable versions",
			"pkg", snykVuln.Package,
			"vulnerable_versions", strings.Join(snykVuln.VulnerableVersions, " || "),
			"err", err,
		)
	}

	// The !parseFailed guard keeps the [0][0] index safe: parseGoRanges never
	// emits an empty range, so a non-empty result has a first constraint.
	wildcardWithHashes := !parseFailed &&
		vulnerableVersions[0][0].Hash == "*" &&
		len(snykVuln.HashesRange) > 0

	if parseFailed || wildcardWithHashes {
		// switch to hash ranges
		vulnerableVersions, err = p.hashRangesToPkgVersions(snykVuln.Package, snykVuln.HashesRange)
		if err != nil {
			return xerrors.Errorf("failed to parse hash ranges for pkg '%s': %w", snykVuln.Package, err)
		}
	}

	if len(vulnerableVersions) == 0 {
		return xerrors.Errorf("no vulnerableVersions for pkg '%s'", snykVuln.Package)
	}

	yadiVuln.VulnerableVersions = buildGoRanges(vulnerableVersions)
	return nil
}

// findPkg maps a package import path to the path the proxy knows it under.
//
// Surrounding slashes are trimmed first. If the path equals, or lies under, an
// entry already in p.pkgsCache, that entry is returned without any request.
// Otherwise the path is probed with a HEAD request to "/<escaped>/@v/list",
// dropping the last path element after every miss, until the proxy answers
// 200 OK. The matching path is cached and returned. Probe failures are logged
// as warnings and treated as a miss. An error is returned once no path
// elements are left.
func (p *GolangParser) findPkg(pkgName string) (string, error) {
	pkgName = strings.Trim(pkgName, "/")

	for _, known := range p.pkgsCache {
		if pkgName == known {
			return known, nil
		}
		if strings.HasPrefix(pkgName, known+"/") {
			return known, nil
		}
	}

	// existsOnProxy reports whether the proxy answers 200 OK for the version
	// list of candidate.
	existsOnProxy := func(candidate string) (bool, error) {
		escaped, err := module.EscapePath(candidate)
		if err != nil {
			return false, err
		}

		listURL := fmt.Sprintf("/%s/@v/list", escaped)
		resp, err := p.httpc.R().Head(listURL)
		if err != nil {
			return false, xerrors.Errorf("proxy request failed: %w", err)
		}

		return resp.StatusCode() == http.StatusOK, nil
	}

	candidate := pkgName
	for {
		found, err := existsOnProxy(candidate)
		if err != nil {
			simplelog.Warn("failed to list pkg", "pkg", candidate, "err", err)
		}

		if found {
			p.pkgsCache = append(p.pkgsCache, candidate)
			return candidate, nil
		}

		// Step up to the parent path and try again.
		parent, _ := path.Split(candidate)
		if candidate = strings.TrimRight(parent, "/"); candidate == "" {
			return "", xerrors.Errorf("pkg '%s' not found", pkgName)
		}
	}
}

// hashRangesToPkgVersions parses ranges and replaces every operand with the
// version that `go list` reports for it.
//
// pkgName is first resolved through findPkg. If the ranges cannot be parsed,
// the error is logged and (nil, nil) is returned so the caller can skip the
// broken input. A failed `go list` call or an empty reported version aborts
// with an error.
func (p *GolangParser) hashRangesToPkgVersions(pkgName string, ranges []string) ([]goVerRange, error) {
	// TODO(buglloc): drop after own proxy
	resolvedName, err := p.findPkg(pkgName)
	if err != nil {
		return nil, err
	}
	pkgName = resolvedName

	resolved, err := parseGoRanges(ranges)
	if err != nil {
		// some hash ranges may be broken, just skip it
		simplelog.Error("failed to parse hash ranges for pkg", "pkg", pkgName, "err", err)
		return nil, nil
	}

	for i := range resolved {
		for k := range resolved[i] {
			entry := &resolved[i][k]

			info, err := p.runGoList(pkgName, entry.Hash)
			switch {
			case err != nil:
				return nil, xerrors.Errorf("failed to get hash info: %w", err)
			case info.Version == "":
				return nil, xerrors.Errorf("no version for hash %s", entry.Hash)
			}

			entry.Hash = info.Version
		}
	}

	return resolved, nil
}

// runGoList runs `go list -json -mod=readonly -m <pkgName>@<hash>` with
// GO111MODULE=on and decodes the JSON it prints.
//
// When the command fails and stderr contains the postV1Err marker, the module
// path quoted right after the marker replaces pkgName and the command is
// retried. Retries stop when stderr has no usable path, when the suggested
// path equals the one just tried, or after a fixed number of attempts, so a
// misbehaving toolchain cannot make this loop forever.
//
// On any error the returned goListResponse is the zero value.
func (p *GolangParser) runGoList(pkgName string, hash string) (goListResponse, error) {
	// Upper bound on `go list` invocations for a single query.
	const maxAttempts = 5

	for attempt := 0; attempt < maxAttempts; attempt++ {
		args := []string{"list", "-json", "-mod=readonly", "-m", fmt.Sprintf("%s@%s", pkgName, hash)}

		simplelog.Info("run go list", "command", "go "+strings.Join(args, " "))
		cmd := exec.Command("go", args...)
		var stdout bytes.Buffer
		cmd.Stdout = &stdout
		var stderr bytes.Buffer
		cmd.Stderr = &stderr
		cmd.Env = append(
			os.Environ(),
			"GO111MODULE=on",
		)

		runErr := cmd.Run()
		if runErr == nil {
			var out goListResponse
			if err := json.NewDecoder(&stdout).Decode(&out); err != nil {
				return goListResponse{}, xerrors.Errorf("failed to parse go list output: %w", err)
			}
			return out, nil
		}

		fixedName, ok := extractPostV1ModulePath(stderr.Bytes())
		if !ok || fixedName == pkgName {
			return goListResponse{}, xerrors.Errorf("failed to run go list: %w", runErr)
		}

		// try one more time with fixed pkg name
		pkgName = fixedName
	}

	return goListResponse{}, xerrors.Errorf("failed to run go list: too many module path fixups, last tried '%s'", pkgName)
}

// extractPostV1ModulePath returns the module path quoted immediately after the
// postV1Err marker in errOut. It reports false when the marker is missing,
// the closing quote is missing, or the quoted path is empty.
func extractPostV1ModulePath(errOut []byte) (string, bool) {
	start := bytes.Index(errOut, postV1Err)
	if start < 0 {
		return "", false
	}

	rest := errOut[start+len(postV1Err):]
	end := bytes.IndexByte(rest, '"')
	if end <= 0 {
		return "", false
	}

	return string(rest[:end]), true
}

// goVersion converts the accumulated runes into a goVersion.
//
// The operand is handled by shape:
//   - exactly 40 bytes long: treated as a hash rather than a version (Snyk
//     sometimes supplies one) and lowercased;
//   - "*": kept as is;
//   - anything else: parsed by versionarium as a "golang" version, with every
//     '*' replaced by 'x', and stored in its normalized string form.
//
// A versionarium parse failure is returned unchanged.
func (h *goRawVersion) goVersion() (goVersion, error) {
	result := goVersion{
		Comparator: string(h.comparator),
		Hash:       string(h.hash),
	}

	if len(result.Hash) == 40 {
		result.Hash = strings.ToLower(result.Hash)
		return result, nil
	}

	if result.Hash == "*" {
		return result, nil
	}

	parsed, err := versionarium.NewVersion("golang", strings.ReplaceAll(result.Hash, "*", "x"))
	if err != nil {
		return goVersion{}, err
	}

	result.Hash = parsed.String()
	return result, nil
}

// buildGoRanges renders ranges as a single string.
//
// Each constraint is written as its comparator followed by its operand with
// one leading "v" removed. Constraints of one range are separated by a single
// space and ranges are separated by " || ".
func buildGoRanges(ranges []goVerRange) string {
	rendered := make([]string, len(ranges))
	for i, verRange := range ranges {
		constraints := make([]string, len(verRange))
		for k, ver := range verRange {
			constraints[k] = ver.Comparator + strings.TrimPrefix(ver.Hash, "v")
		}
		rendered[i] = strings.Join(constraints, " ")
	}
	return strings.Join(rendered, " || ")
}

// parseGoRanges parses raw range strings such as ">=1.2.0 <1.4.0" into
// goVerRange values.
//
// Each input string is scanned rune by rune:
//   - whitespace is ignored;
//   - ASCII letters, digits and '+', '-', '.', '*' extend the current operand;
//   - '<', '>' and '=' extend the current comparator while no operand has been
//     read yet, otherwise they close the pending constraint and start a new one;
//   - any other rune is an error.
//
// A single input string may hold several ranges back to back, e.g.
// ">= A <B >= C <D". Once a range already holds two constraints, the next
// closed constraint flushes it to the result and opens a new range.
// A trailing comparator without an operand is dropped. Inputs that yield no
// constraints add nothing to the result.
func parseGoRanges(ranges []string) ([]goVerRange, error) {
	isSpace := func(r rune) bool {
		// '\t', '\n', '\v', '\f', '\r' are the contiguous code points 0x09-0x0D.
		return r == ' ' || ('\t' <= r && r <= '\r') || r == 0x85 || r == 0xA0
	}

	isHashVer := func(r rune) bool {
		switch {
		case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
			return true
		case r == '+', r == '-', r == '.', r == '*':
			return true
		}
		return false
	}

	isComparator := func(r rune) bool {
		return r == '<' || r == '>' || r == '='
	}

	// toVersion finalizes a pending constraint, adding context to failures.
	toVersion := func(raw goRawVersion) (goVersion, error) {
		ver, err := raw.goVersion()
		if err != nil {
			return goVersion{}, xerrors.Errorf("invalid hash or version '%s': %w", string(raw.hash), err)
		}
		return ver, nil
	}

	parsed := make([]goVerRange, 0, len(ranges))
	for pos, rawRange := range ranges {
		var (
			group   goVerRange
			pending goRawVersion
		)

		for _, ch := range rawRange {
			if isSpace(ch) {
				continue
			}

			if isHashVer(ch) {
				pending.hash = append(pending.hash, ch)
				continue
			}

			if !isComparator(ch) {
				return nil, xerrors.Errorf("invalid hash range char '%s' (pos: %d) in %s", string(ch), pos, rawRange)
			}

			if len(pending.hash) == 0 {
				// Still inside the operator, e.g. the '=' of ">=".
				pending.comparator = append(pending.comparator, ch)
				continue
			}

			// A comparator after an operand closes the pending constraint.
			ver, err := toVersion(pending)
			if err != nil {
				return nil, err
			}

			if len(group) >= 2 {
				// range may have MULTIPLE ranges :'(
				// e.g. ">= c81e4f87c20a717b1dc52b2b77780fa789e19148 <ca0518420b931db0923c97ec17e05e150a729a64 >= 6597fdb40134965e26f715854dc85f5e6cfaa6df <e16012435f82afafdfdd7963e95d86c9e8538322"
				parsed = append(parsed, group)
				group = goVerRange{ver}
			} else {
				group = append(group, ver)
			}

			pending = goRawVersion{
				comparator: []rune{ch},
			}
		}

		if len(pending.hash) > 0 {
			ver, err := toVersion(pending)
			if err != nil {
				return nil, err
			}
			group = append(group, ver)
		}

		if len(group) > 0 {
			parsed = append(parsed, group)
		}
	}

	return parsed, nil
}

// fixupGoRanges returns a new slice with every entry of ranges stripped of
// surrounding whitespace; entries that are empty after trimming are left out.
// The result is never nil.
func fixupGoRanges(ranges []string) []string {
	cleaned := make([]string, 0, len(ranges))
	for i := range ranges {
		if trimmed := strings.TrimSpace(ranges[i]); trimmed != "" {
			cleaned = append(cleaned, trimmed)
		}
	}
	return cleaned
}
