package kernel

import (
	"context"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/security/libs/go/semver"
	"a.yandex-team.ru/security/libs/go/simplelog"
	"a.yandex-team.ru/security/yadi/libs/nvd"
	"a.yandex-team.ru/security/yadi/snatcher/internal/gitutil"
	"a.yandex-team.ru/security/yadi/snatcher/pkg/feed"
)

// Opts carries the settings accepted by NewFeed.
type Opts struct {
	// TmpDir is the directory handed to gitutil.NewRepo when the tracker
	// repository is cloned.
	TmpDir string
}

// Feed is the Linux kernel CVE feed.
type Feed struct {
	// tmpDir is copied from Opts.TmpDir by NewFeed.
	tmpDir string
}

// trackerURI is the git repository cloned by Feed.Dump to obtain CVE data.
const trackerURI = "https://github.com/nluedtke/linux_kernel_cves.git"

// cveURL is the format string for the per-CVE reference link; it takes the CVE ID.
const cveURL = "https://www.linuxkernelcves.com/cves/%s"

// fixURL is the format string for the fix commit reference link; it takes the commit ID.
const fixURL = "https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/commit?id=%s"

// feedName is returned by Feed.Name and stored as SrcType of every produced vulnerability.
const feedName = "linux-kernel-cves"

// unknownVersion is the placeholder that stands for an unknown bound in
// the "affected versions" string of a CVE.
const unknownVersion = "unk"

// pkgName is the package name assigned to every produced vulnerability.
const pkgName = "linux"

// platforms maps each supported platform to the alias accepted by
// Feed.GetPlatformByAlias.
var platforms = map[feed.Platform]string{
	feed.KernelPlatform: feed.KernelPlatform,
}

// isValidRef is the set of reference titles from a CVE record that
// processCVE copies into the resulting vulnerability; other titles are dropped.
var isValidRef = map[string]struct{}{
	"Debian":    {},
	"ExploitDB": {},
	"NVD":       {},
	"Ubuntu":    {},
}

// actualKernels lists the kernel streams passed to NewKernels by Feed.Dump.
//
// TODO(buglloc): find a source for the actual releases. For now the list is
// copy-pasted from https://www.kernel.org/category/releases.html
var actualKernels = []string{
	"4.9",
	"4.14",
	"4.19",
	"5.4",
	"5.10",
	"5.15",
}

// NewFeed builds a Feed that uses opts.TmpDir as its working directory.
// The returned error is always nil.
func NewFeed(opts Opts) (Feed, error) {
	created := Feed{tmpDir: opts.TmpDir}
	return created, nil
}

// Name returns the feed identifier (feedName).
func (Feed) Name() string {
	return feedName
}

// GetPlatformByAlias returns the platform registered in platforms under the
// given alias, or an error when no platform uses that alias.
func (Feed) GetPlatformByAlias(alias string) (feed.Platform, error) {
	for platform, known := range platforms {
		if known != alias {
			continue
		}
		return platform, nil
	}

	return "", xerrors.New("not supported platform")
}

// Dump builds the kernel vulnerability feed.
//
// It clones the tracker repository into the feed's temporary directory,
// loads data/kernel_cves.json and data/stream_fixes.json from the checkout,
// enriches the CVE records with the parsed NVD feed, and then emits the
// non-LTS entry (processEOL) plus one entry per non vendor-specific CVE
// (processCVE). Any failing step aborts the dump with a wrapped error.
func (f Feed) Dump(ctx context.Context, opts feed.DumpingOpts) (feed.Result, error) {
	repo, err := gitutil.NewRepo(trackerURI, f.tmpDir)
	if err != nil {
		return nil, xerrors.Errorf("failed to create git repository: %w", err)
	}
	defer func() {
		// Cleanup is best-effort: a failure here must not mask the dump result.
		_ = repo.Clean()
	}()

	// NOTE(review): the second argument is presumably the clone depth — confirm against gitutil.
	checkoutDir, err := repo.CloneWithContext(ctx, 1)
	if err != nil {
		return nil, xerrors.Errorf("failed to clone git repository: %w", err)
	}
	dataDir := filepath.Join(checkoutDir, "data")

	nvdData, err := nvd.ParseFeed(ctx)
	if err != nil {
		return nil, xerrors.Errorf("failed to parse NVD Feed: %w", err)
	}

	var cves kernelCVEsJSON
	err = parseJSONFile(filepath.Join(dataDir, "kernel_cves.json"), &cves)
	if err != nil {
		return nil, xerrors.Errorf("failed to parse CVEs: %w", err)
	}
	cves.EnrichWithNVD(nvdData)

	var streamFixes kernelStreamFixesJSON
	err = parseJSONFile(filepath.Join(dataDir, "stream_fixes.json"), &streamFixes)
	if err != nil {
		return nil, xerrors.Errorf("failed to parse stream fixes: %w", err)
	}

	kernels, err := NewKernels(streamFixes, actualKernels)
	if err != nil {
		return nil, xerrors.Errorf("failed to create kernels info: %w", err)
	}

	// The per-platform map must exist before processEOL/processCVE write into it.
	out := feed.Result{feed.KernelPlatform: {}}
	err = processEOL(kernels, &out)
	if err != nil {
		return nil, xerrors.Errorf("failed to process EOL kernels: %w", err)
	}

	for id, info := range cves {
		if info.VendorSpecific {
			simplelog.Info("linux-kernel: skip vendor specific CVE", "cve_id", id)
			continue
		}

		// The record does not carry its own ID; it is the key of the CVE collection.
		info.ID = id
		if cveErr := processCVE(info, kernels, &out); cveErr != nil {
			return nil, xerrors.Errorf("failed to parse CVE %q: %w", id, cveErr)
		}
	}

	return out, nil
}

// processCVE converts one kernel CVE record into a feed.Vulnerability and
// stores it in (*results)[feed.KernelPlatform] keyed by the vulnerability ID.
//
// The CVE is skipped (nil is returned and results is left untouched) when:
//   - its affected versions string is empty;
//   - its NVD text mentions "Android";
//   - no vulnerable version range could be generated for the known streams.
//
// An error is returned only when genVulnerableVersions fails.
//
// References are emitted in a stable order: the fix commit (if any), the
// accepted cve.Refs entries sorted by title, then the Linux Kernel CVEs link.
// Sorting matters because cve.Refs is ranged over as a map, whose iteration
// order is random; without it the description would differ between runs.
//
// NOTE(review): assumes (*results)[feed.KernelPlatform] is already initialised
// by the caller (Dump does this).
func processCVE(cve *kernelCVE, kernels *Kernels, results *feed.Result) error {
	if cve.AffectedVersions == "" {
		simplelog.Warn("linux-kernel: failed to parse kernel CVE: empty affected versions", "cve_id", cve.ID)
		return nil
	}

	if strings.Contains(cve.NvdText, "Android") {
		// TODO(buglloc): ugly
		// Skip Android vulnerabilities
		return nil
	}

	vulnerableVersions, err := genVulnerableVersions(cve, kernels)
	if err != nil {
		return xerrors.Errorf("failed to parse vulnerable versions: %w", err)
	}

	if vulnerableVersions == "" {
		simplelog.Warn(
			"linux-kernel: failed to parse kernel CVE: empty vulnerable versions",
			"cve_id", cve.ID,
			"affected_versions", cve.AffectedVersions,
		)
		return nil
	}

	vulnerability := feed.Vulnerability{
		Title:              cve.ID,
		Package:            pkgName,
		ID:                 strings.ToUpper(strings.TrimPrefix(cve.ID, "CVE-")),
		SrcType:            feedName,
		RichDescription:    true,
		Language:           feed.KernelPlatform,
		VulnerableVersions: vulnerableVersions,
		DisclosedAt:        cve.DisclosedAt,
		CvssScore:          float32(cve.CVSS3Score),
	}

	// Prefer the human-readable name for the title suffix, fall back to the CWE.
	switch {
	case cve.Name != "":
		vulnerability.Title += ": " + cve.Name
	case cve.CWE != "":
		vulnerability.Title += ": " + cve.CWE
	}

	vulnerability.YadiID = strings.ToUpper(fmt.Sprintf("yadi-%s-%s", feed.KernelPlatform, vulnerability.ID))
	// accept only sha1 hash as commit id
	if len(cve.Fixes) == 40 {
		vulnerability.References = append(vulnerability.References, feed.Reference{
			Title: fmt.Sprintf("Fix commit: %s", cve.Fixes),
			URL:   fmt.Sprintf(fixURL, cve.Fixes),
		})
	}

	// Keep only whitelisted, non-empty references and sort them by title so
	// the output does not depend on map iteration order.
	refTitles := make([]string, 0, len(cve.Refs))
	for title, url := range cve.Refs {
		if _, ok := isValidRef[title]; !ok {
			continue
		}

		if url == "" {
			continue
		}

		refTitles = append(refTitles, title)
	}
	sort.Strings(refTitles)

	for _, title := range refTitles {
		vulnerability.References = append(vulnerability.References, feed.Reference{
			Title: title,
			URL:   cve.Refs[title],
		})
	}

	vulnerability.References = append(vulnerability.References, feed.Reference{
		Title: "Linux Kernel CVEs",
		URL:   fmt.Sprintf(cveURL, cve.ID),
	})

	// Render the markdown description: overview, reported versions, references.
	var description strings.Builder
	description.WriteString("## Overview")
	description.WriteRune('\n')
	description.WriteString(cve.NvdText)
	if cve.AffectedVersions != "" {
		description.WriteString("\n\n")
		description.WriteString("## Original reported affected versions\n")
		// Show the unknown-bound placeholder as a wildcard.
		description.WriteString(strings.ReplaceAll(cve.AffectedVersions, unknownVersion, "*"))
	}

	if len(vulnerability.References) > 0 {
		description.WriteString("\n\n")
		description.WriteString("## References\n")
		for _, ref := range vulnerability.References {
			if ref.Title != "" {
				fmt.Fprintf(&description, "  - [%s](%s)\n", ref.Title, ref.URL)
			} else {
				fmt.Fprintf(&description, "  - [%s](%s)\n", ref.URL, ref.URL)
			}
		}
	}

	vulnerability.Description = description.String()
	(*results)[feed.KernelPlatform][vulnerability.ID] = vulnerability
	return nil
}

// processEOL adds the synthetic "Non-LTS kernel" vulnerability to
// (*results)[feed.KernelPlatform] under the ID "NON-LTS".
//
// Its description lists every actual stream reported by kernels, and its
// vulnerable versions are the space-joined patterns "!<stream>.*", one per
// stream. DisclosedAt is set to the current time, so it differs between runs.
//
// An error is returned when kernels reports no actual streams.
//
// NOTE(review): assumes (*results)[feed.KernelPlatform] is already initialised
// by the caller (Dump does this).
func processEOL(kernels *Kernels, results *feed.Result) error {
	// Fetch the streams once instead of on every use.
	streams := kernels.ActualStreams()
	if len(streams) == 0 {
		return xerrors.Errorf("empty actual streams")
	}

	vulnerableVersions := make([]string, len(streams))
	var description strings.Builder
	description.WriteString(`## Overview
Using a non-LTS kernels doesn't guarantee its security and stability.
You should update the kernel to the nearest LTS release:
`)
	for i, version := range streams {
		fmt.Fprintf(&description, "  * %s\n", version.Original)
		vulnerableVersions[i] = fmt.Sprintf("!%s.*", version.Original)
	}

	vulnerability := feed.Vulnerability{
		Title:              "Non-LTS kernel",
		Package:            pkgName,
		ID:                 "NON-LTS",
		DisclosedAt:        time.Now().Unix(),
		YadiID:             strings.ToUpper(fmt.Sprintf("yadi-%s-non-lts", feed.KernelPlatform)),
		SrcType:            feedName,
		Description:        description.String(),
		RichDescription:    true,
		Language:           feed.KernelPlatform,
		VulnerableVersions: strings.Join(vulnerableVersions, " "),
		CvssScore:          9.9,
	}

	vulnerability.References = append(vulnerability.References, feed.Reference{
		Title: "Active kernel releases",
		URL:   "https://www.kernel.org/category/releases.html",
	})

	(*results)[feed.KernelPlatform][vulnerability.ID] = vulnerability
	return nil
}

// genVulnerableVersions builds the vulnerable version expression for cve
// across the actual kernel streams known to kernels.
//
// cve.AffectedVersions is split into a lower and an upper bound with
// parseAffectedVersions. Each bound must either parse as a semantic version
// or be the unknownVersion placeholder; anything else is an error. For every
// actual stream that falls between the bounds the function emits one clause,
// and the clauses are joined with " || ":
//   - "<major>.<minor>.*" when no fix is known for the stream;
//   - nothing when the fix landed in a release candidate ("-rc");
//   - ">=<from> <<fixed version>" otherwise, or ">=<from> <=<upper>" for the
//     last stream when the upper bound lies inside that stream.
//
// An empty string with a nil error means no stream is affected.
func genVulnerableVersions(cve *kernelCVE, kernels *Kernels) (string, error) {
	lowerRawVersion, upperRawVersion := parseAffectedVersions(cve.AffectedVersions)
	if lowerRawVersion == "" || upperRawVersion == "" {
		return "", xerrors.Errorf("unknown affected versions format: %s", cve.AffectedVersions)
	}

	// A parse failure is tolerated only for the unknownVersion placeholder.
	// NOTE(review): the nil checks below presume semver.NewVersion returns a
	// nil version in that case — confirm against the semver package.
	lowerVersion, err := semver.NewVersion(lowerRawVersion)
	if err != nil && lowerRawVersion != unknownVersion {
		return "", xerrors.Errorf("failed to parse lower version %q: %w", lowerRawVersion, err)
	}

	upperVersion, err := semver.NewVersion(upperRawVersion)
	if err != nil && upperRawVersion != unknownVersion {
		return "", xerrors.Errorf("failed to parse upper version %q: %w", upperRawVersion, err)
	}

	// NOTE(review): the second result of FixesForCVE is deliberately ignored;
	// a missing entry is handled per stream by the lookup below.
	fixes, _ := kernels.FixesForCVE(cve.ID)
	var (
		streams = kernels.ActualStreams()
		result  strings.Builder
		last    = len(streams) - 1
	)
	for i, streamVersion := range streams {
		// Skip streams that are not newer than the lower bound.
		if lowerVersion != nil && lowerVersion.Compare(streamVersion) >= 0 {
			continue
		}

		// Stop at the first stream that is not older than the upper bound.
		if upperVersion != nil && upperVersion.Compare(streamVersion) <= 0 {
			break
		}

		fix, ok := fixes[streamVersion.Original]
		if !ok {
			// If we don't have a fix, mark the whole stream as vulnerable.
			fmt.Fprintf(&result, "%d.%d.* || ", streamVersion.Major, streamVersion.Minor)
			continue
		}

		if strings.Contains(fix.FixedVersion, "-rc") {
			// skip whole stream if vulnerability was fixed before release
			continue
		}

		// NOTE(review): the first branch looks unreachable — streams older
		// than lowerVersion were already skipped by the Compare check above.
		// Left as is; confirm the intended handling of a lower bound that
		// lies inside the first stream.
		if i == 0 && lowerVersion != nil && streamVersion.LessThan(lowerVersion) {
			result.WriteString(">=")
			result.WriteString(lowerRawVersion)
		} else {
			result.WriteString(">=")
			result.WriteString(streamVersion.Original)
		}

		if i == last && upperVersion != nil && upperVersion.GreaterThan(streamVersion) && upperVersion.Major == streamVersion.Major && upperVersion.Minor == streamVersion.Minor {
			// The reported upper bound lies inside the last stream: cap the range with it.
			result.WriteString(" <=")
			result.WriteString(upperRawVersion)
		} else {
			result.WriteString(" <")
			result.WriteString(fix.FixedVersion)
			result.WriteString(" || ")
		}
	}

	// Drop the separator left after the final clause.
	vulnerableVersions := strings.TrimSuffix(result.String(), " || ")
	return vulnerableVersions, nil
}

// parseAffectedVersions splits a string of the form "<lower> to <upper>"
// into its two bounds and strips one leading "v" from each.
//
// Both results are empty when the input does not contain exactly one
// " to " separator.
func parseAffectedVersions(affectedVersions string) (lower, upper string) {
	// Asking for up to three pieces is enough to tell "exactly one separator"
	// apart from "none" (one piece) and "more than one" (three pieces).
	parts := strings.SplitN(affectedVersions, " to ", 3)
	if len(parts) != 2 {
		return "", ""
	}

	return strings.TrimPrefix(parts[0], "v"), strings.TrimPrefix(parts[1], "v")
}

// parseJSONFile reads the whole file at path and unmarshals its JSON content
// into target. Read failures are wrapped with "read:", decode failures with
// "parse:".
func parseJSONFile(path string, target interface{}) error {
	raw, readErr := ioutil.ReadFile(path)
	if readErr != nil {
		return xerrors.Errorf("read: %w", readErr)
	}

	if parseErr := json.Unmarshal(raw, target); parseErr != nil {
		return xerrors.Errorf("parse: %w", parseErr)
	}

	return nil
}
