package requirements

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io/ioutil"
	"path/filepath"
	"regexp"
	"sort"
	"strings"

	"a.yandex-team.ru/security/libs/go/simplelog"
	"a.yandex-team.ru/security/yadi/libs/pypi"
	"a.yandex-team.ru/security/yadi/libs/pypi/pkgparser"
	"a.yandex-team.ru/security/yadi/yadi/pkg/manager"
)

const (
	// lang is the language tag set on every dependency built in this file.
	lang = "python"

	// cvRegex matches a single version token: an optional "v", then one to three
	// dot-separated segments made of digits or the wildcards x, X and * (this
	// part is captured), then any remaining characters up to a separator
	// (",", ";", ":", "#" or whitespace), such as "rc1" or a fourth segment.
	cvRegex = `v?(` +
		`(?:[0-9xX*]+)` +
		`(?:\.[0-9xX*]+)?` +
		`(?:\.[0-9xX*]+)?` +
		`)` +
		`[^,;:#\s]*`
)

var (
	// requirementParseRe matches a complete requirement line; compiled in init.
	requirementParseRe *regexp.Regexp

	// versionParseRe matches one "<operator> <version>" specifier; compiled in init.
	versionParseRe *regexp.Regexp

	// specifierOps maps each PEP 440 comparison operator to the function that
	// renders the equivalent semver constraint for a version. Its keys also
	// define the operator alternation used by both regular expressions.
	specifierOps = map[string]sfunc{
		"===": specifierArbitrary,
		"~=":  specifierCompatible,
		"==":  specifierEqual,
		"!=":  specifierNotEqual,
		">=":  specifierGreaterThanEqual,
		"<=":  specifierLessThanEqual,
		">":   specifierGreaterThan,
		"<":   specifierLessThan,
	}
)

// sfunc renders the semver constraint for one PEP 440 operator applied to
// version, e.g. turning "1.2" into ">=1.2" for the ">=" operator.
type sfunc func(version string) string

// init builds requirementParseRe and versionParseRe from the operators in
// specifierOps and the cvRegex version pattern.
func init() {
	ops := make([]string, 0, len(specifierOps))
	for op := range specifierOps {
		ops = append(ops, regexp.QuoteMeta(op))
	}
	// Map iteration order is random. Sort longest-first (ties lexically) so the
	// compiled patterns are deterministic and, under Go's leftmost-first
	// alternation, longer operators such as "===" and "<=" are tried before the
	// shorter operators they start with.
	sort.Slice(ops, func(i, j int) bool {
		if len(ops[i]) != len(ops[j]) {
			return len(ops[i]) > len(ops[j])
		}
		return ops[i] < ops[j]
	})
	alternation := strings.Join(ops, "|")

	// requirementParseRe matches a requirement line: the package name, optional
	// "[extras]", optional version specifiers, then end of line or one of "-",
	// ";" or "#" (presumably a trailing option, environment marker or comment).
	// Specifiers are comma-separated and whitespace is allowed on both sides of
	// a comma, so "pkg >= 1.0, < 2.0" parses just like "pkg>=1.0,<2.0".
	requirementParseRe = regexp.MustCompile(fmt.Sprintf(
		`^(?P<name>[\w.\-]+)\s*(?P<extras>\[[\w.\-,]+\])?\s*(?P<versions>(?:(?:%s)\s*(?:%s)(?:\s*,\s*)?)+)?(?:\s*)(?:-|;|#|$)`,
		alternation, cvRegex))

	// versionParseRe matches a single specifier such as ">= 1.2". Group 1 is the
	// operator and group 2 is the whole version token matched by cvRegex;
	// cvRegex's own capture becomes group 3.
	versionParseRe = regexp.MustCompile(fmt.Sprintf(
		`^\s*(%s)\s*(%s)\s*$`,
		alternation, cvRegex))
}

// ReadRequirementsFile reads the requirements file at path and parses it with
// ParseRequirements. path is also passed on so that nested "-r"/"--requirement"
// references are resolved relative to it. withDev is forwarded unchanged.
// Read and parse errors are returned as-is with a nil result.
func ReadRequirementsFile(path string, withDev bool) (*Requirements, error) {
	var (
		content []byte
		reqs    *Requirements
		err     error
	)

	if content, err = ioutil.ReadFile(path); err != nil {
		return nil, err
	}
	if reqs, err = ParseRequirements(content, withDev, path); err != nil {
		return nil, err
	}
	return reqs, nil
}

// ParseRequirements parses the content of a pip requirements file.
//
// data is the raw content; path is where it was read from ("" if unknown).
// Each line is trimmed and handled as follows:
//   - blank lines and lines starting with "#" are ignored;
//   - when path is non-empty, "-r FILE", "-rFILE", "--requirement FILE" and
//     "--requirement=FILE" reference another requirements file. The file is
//     resolved relative to path's directory (unless it is absolute), parsed
//     recursively and merged into the result. With withDev false, references
//     whose base name contains "dev" are skipped. A reference back to a file
//     that is already being parsed (a cycle) is logged and skipped instead of
//     recursing forever. Unreadable references are logged and skipped;
//   - any other line starting with "-" (an unsupported option) is ignored;
//   - a "[name]" line sets the extras name for the requirements that follow;
//   - everything else is parsed by parseSingleRequirement. Failures are logged
//     and the line is skipped.
//
// An error is returned only when data itself cannot be scanned.
func ParseRequirements(data []byte, withDev bool, path string) (*Requirements, error) {
	active := make(map[string]bool)
	if path != "" {
		active[requirementsPathKey(path)] = true
	}
	return parseRequirementsData(data, withDev, path, active)
}

// parseRequirementsData implements ParseRequirements. active holds the
// normalized paths (see requirementsPathKey) of the files on the current
// reference chain and is used to detect cycles.
func parseRequirementsData(data []byte, withDev bool, path string, active map[string]bool) (*Requirements, error) {
	scanner := bufio.NewScanner(bytes.NewReader(data))
	// The input is already in memory, so let a single line be as long as the
	// whole input instead of failing with bufio.ErrTooLong past 64 KiB.
	maxLine := len(data) + 1
	if maxLine < bufio.MaxScanTokenSize {
		maxLine = bufio.MaxScanTokenSize
	}
	scanner.Buffer(nil, maxLine)

	result := NewRequirements()
	extras := ""
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		// References can only be resolved when we know where this file lives.
		if path != "" {
			if ref, ok := requirementsReference(line); ok {
				mergeReferencedRequirements(result, path, ref, withDev, active)
				continue
			}
		}

		switch {
		case strings.HasPrefix(line, "#"):
			// skip comments
			continue
		case strings.HasPrefix(line, "-"):
			// skip unsupported options
			continue
		case strings.HasPrefix(line, "["):
			// switch extras for the requirements that follow
			extras = requirementsSectionName(line)
			continue
		}

		dependency, err := parseSingleRequirement(line)
		if err != nil {
			simplelog.Error("failed to parse requirements: "+err.Error(), "line", line)
			continue
		}

		_ = result.AddDependency(*dependency, extras)
	}

	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("scan requirements %q: %w", path, err)
	}
	return result, nil
}

// requirementsReference reports whether line is a "-r"/"--requirement" option.
// If so, it also returns the referenced file name, without an optional "="
// separator and without a trailing comment.
func requirementsReference(line string) (string, bool) {
	var rest string
	switch {
	case strings.HasPrefix(line, "--requirement"):
		rest = line[len("--requirement"):]
	case strings.HasPrefix(line, "-r"):
		rest = line[len("-r"):]
	default:
		return "", false
	}
	rest = strings.TrimPrefix(strings.TrimSpace(rest), "=")
	return stripRequirementsComment(rest), true
}

// requirementsSectionName returns the extras name from a "[name]" line. It
// tolerates a lone "[" or a missing "]" and ignores a trailing comment.
func requirementsSectionName(line string) string {
	name := strings.TrimPrefix(stripRequirementsComment(line), "[")
	return strings.TrimSpace(strings.TrimSuffix(name, "]"))
}

// stripRequirementsComment cuts s at the first "#" that starts the string or
// follows a space or tab, then trims surrounding whitespace.
func stripRequirementsComment(s string) string {
	for i := 0; i < len(s); i++ {
		if s[i] == '#' && (i == 0 || s[i-1] == ' ' || s[i-1] == '\t') {
			s = s[:i]
			break
		}
	}
	return strings.TrimSpace(s)
}

// requirementsPathKey normalizes a requirements file path for cycle detection,
// preferring the absolute path and falling back to the cleaned path.
func requirementsPathKey(p string) string {
	if abs, err := filepath.Abs(p); err == nil {
		return abs
	}
	return filepath.Clean(p)
}

// mergeReferencedRequirements resolves ref relative to the directory of
// parentPath (absolute refs are used as-is). It then parses that file and
// merges its requirements into result. Empty, dev (when !withDev), cyclic and
// unreadable references are skipped; all except dev skips are logged as errors.
func mergeReferencedRequirements(result *Requirements, parentPath, ref string, withDev bool, active map[string]bool) {
	if ref == "" {
		simplelog.Error("empty requirements reference", "path", parentPath)
		return
	}

	reqPath := ref
	if !filepath.IsAbs(ref) {
		reqPath = filepath.Join(filepath.Dir(parentPath), ref)
	}
	if !withDev && isDev(reqPath) {
		simplelog.Debug("skip dev requirements: " + reqPath)
		return
	}

	key := requirementsPathKey(reqPath)
	if active[key] {
		simplelog.Error("skip cyclic reference requirements: "+ref, "path", reqPath)
		return
	}

	simplelog.Info("parse reference requirements: "+ref, "path", reqPath)
	data, err := ioutil.ReadFile(reqPath)
	if err != nil {
		simplelog.Error("failed to read reference requirements: "+err.Error(), "path", reqPath)
		return
	}

	// Only the current chain is tracked, so a file included from two different
	// places (a diamond) is still parsed each time, as before.
	active[key] = true
	reqs, err := parseRequirementsData(data, withDev, reqPath, active)
	delete(active, key)
	if err != nil {
		simplelog.Error("failed to parse reference requirements: "+err.Error(), "path", reqPath)
		return
	}

	_ = result.Merge(reqs)
}

// ParseRemote builds Requirements from package requirements fetched from PyPI.
// Entries of requirements.Requires are registered without an extras name, and
// each entry of requirements.Extras is registered under its map key. A
// dependency name gets a "[a,b]" suffix when the requirement itself requests
// extras. Requirements whose version specifiers cannot be converted to semver
// are logged and skipped. The returned error is always nil.
func ParseRemote(requirements *pypi.PkgRequirements) (*Requirements, error) {
	result := NewRequirements()

	register := func(req pypi.PkgRequire, section string) {
		semverRange, convErr := convertPyVersionsToSemver(req.Versions)
		if convErr != nil {
			simplelog.Error(
				"failed to convert python version to semver",
				"pkg_name", req.Name,
				"raw_version", req.Versions,
				"err", convErr,
			)
			return
		}

		name := req.Name
		if len(req.Extras) != 0 {
			name += "[" + strings.Join(req.Extras, ",") + "]"
		}

		_ = result.AddDependency(manager.Dependency{
			Name:        name,
			RawVersions: semverRange,
			Language:    lang,
		}, section)
	}

	for _, req := range requirements.Requires {
		register(req, "")
	}
	for section, sectionReqs := range requirements.Extras {
		for _, req := range sectionReqs {
			register(req, section)
		}
	}
	return result, nil
}

// ParseLocal builds Requirements from locally parsed package metadata.
// requires are registered without an extras name, and every entry of extras is
// registered under its map key. A dependency name gets a "[a,b]" suffix when
// the requirement itself requests extras. Requirements whose version
// specifiers cannot be converted to semver are logged and skipped. The
// returned error is always nil.
func ParseLocal(requires []pkgparser.Require, extras map[string][]pkgparser.Require) (*Requirements, error) {
	result := NewRequirements()

	register := func(req pkgparser.Require, section string) {
		semverRange, convErr := convertPyVersionsToSemver(req.Versions)
		if convErr != nil {
			simplelog.Error(
				"failed to convert python version to semver",
				"pkg_name", req.Name,
				"raw_version", req.Versions,
				"err", convErr,
			)
			return
		}

		name := req.Name
		if len(req.Extras) != 0 {
			name += "[" + strings.Join(req.Extras, ",") + "]"
		}

		_ = result.AddDependency(manager.Dependency{
			Name:        name,
			RawVersions: semverRange,
			Language:    lang,
		}, section)
	}

	for _, req := range requires {
		register(req, "")
	}
	for section, sectionReqs := range extras {
		for _, req := range sectionReqs {
			register(req, section)
		}
	}
	return result, nil
}

// parseSingleRequirement parses one requirement line, e.g.
// "requests[security]>=2.0 ; python_version>'3'", into a dependency.
// The dependency name is the package name with any "[extras]" suffix appended
// verbatim. RawVersions is the result of convertPyVersionsToSemver on the
// version specifiers (an empty list is passed through as ""). It fails when
// the line does not match requirementParseRe or its versions cannot be converted.
func parseSingleRequirement(line string) (*manager.Dependency, error) {
	groups := requirementParseRe.FindStringSubmatch(line)
	if len(groups) == 0 {
		return nil, errors.New("failed to parse line")
	}

	// Submatch indexes follow the named groups: 1 = name, 2 = extras, 3 = versions.
	semverRange, err := convertPyVersionsToSemver(strings.TrimSpace(groups[3]))
	if err != nil {
		return nil, err
	}

	return &manager.Dependency{
		Name:        strings.TrimSpace(groups[1]) + strings.TrimSpace(groups[2]),
		RawVersions: semverRange,
		Language:    lang,
	}, nil
}

// convertPyVersionsToSemver translates a comma-separated list of Python
// version specifiers (e.g. ">=1.0,<2.0") into a space-separated semver
// constraint string (e.g. ">=1.0 <2.0"). Each part must match versionParseRe
// and is rendered by the specifierOps entry for its operator. An empty input
// means "any version" and is treated as ">0.0.0".
func convertPyVersionsToSemver(rawVersions string) (string, error) {
	if len(rawVersions) == 0 {
		rawVersions = ">0.0.0"
	}

	parts := strings.Split(rawVersions, ",")
	converted := make([]string, 0, len(parts))
	for _, part := range parts {
		m := versionParseRe.FindStringSubmatch(part)
		if m == nil {
			return "", errors.New("failed to parse version specifier")
		}

		op, ver := m[1], m[2]
		render, known := specifierOps[op]
		if !known {
			return "", fmt.Errorf("unknown version specifier: %s", op)
		}
		converted = append(converted, render(ver))
	}

	return strings.Join(converted, " "), nil
}

// specifierCompatible converts a "~=V" (compatible release) specifier into a
// semver constraint. The result is a lower bound ">=V" plus an equality on V
// with its last dot-separated segment replaced by a wildcard, e.g.
// "2.2" -> ">=2.2 =2.*" and "1.4.5" -> ">=1.4.5 =1.4.*".
// At most the first two segments are kept for the wildcard bound, so
// "1.4.5.0" -> ">=1.4.5.0 =1.4.*" (presumably because the produced semver
// constraints only go up to major.minor.patch).
// A single-segment version (which PEP 440 forbids for "~=") has no prefix to
// pin, so only the lower bound ">=V" is returned instead of a malformed ".*".
func specifierCompatible(version string) string {
	segments := strings.Split(version, ".")
	keep := len(segments) - 1
	if keep > 2 {
		keep = 2
	}
	if keep < 1 {
		return fmt.Sprintf(">=%s", version)
	}

	return fmt.Sprintf(">=%s =%s.*", version, strings.Join(segments[:keep], "."))
}

// specifierEqual renders a "==" specifier as the semver exact match "=version".
func specifierEqual(version string) string {
	return fmt.Sprint("=", version)
}

// specifierNotEqual renders a "!=" specifier as the semver exclusion "!=version".
func specifierNotEqual(version string) string {
	return fmt.Sprint("!=", version)
}

// specifierLessThanEqual renders a "<=" specifier as the semver bound "<=version".
func specifierLessThanEqual(version string) string {
	return fmt.Sprint("<=", version)
}

// specifierGreaterThanEqual renders a ">=" specifier as the semver bound ">=version".
func specifierGreaterThanEqual(version string) string {
	return fmt.Sprint(">=", version)
}

// specifierLessThan renders a "<" specifier as the semver bound "<version".
func specifierLessThan(version string) string {
	return fmt.Sprint("<", version)
}

// specifierGreaterThan renders a ">" specifier as the semver bound ">version".
func specifierGreaterThan(version string) string {
	return fmt.Sprint(">", version)
}

// specifierArbitrary renders a "===" (arbitrary equality) specifier. It is
// approximated by the semver exact match "=version".
func specifierArbitrary(version string) string {
	return fmt.Sprint("=", version)
}

// isDev reports whether the requirements file at path looks like a development
// requirements file. This is true when its base name contains "dev" anywhere;
// directory components are ignored.
func isDev(path string) bool {
	const marker = "dev"
	base := filepath.Base(path)
	return strings.Contains(base, marker)
}
