package kernel

import (
	"encoding/json"
	"sort"

	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/security/libs/go/semver"
	"a.yandex-team.ru/security/libs/go/simplelog"
	"a.yandex-team.ru/security/yadi/libs/nvd"
)

// cvss3 holds a CVSS3 score as a float32. It has a custom JSON decoder
// (UnmarshalJSON) that reads the value from the "score" member of a JSON
// object.
type cvss3 float32

// kernelCVERef is the string-to-string map stored in the "ref_urls" field of
// a kernelCVE.
// NOTE(review): presumably maps a reference name to its URL — confirm against
// the feed format.
type kernelCVERef map[string]string

// kernelCVE is a single CVE record as decoded from the kernel CVE JSON feed.
//
// ID and DisclosedAt are tagged `json:"-"`, so they are never read from or
// written to JSON. Within this file DisclosedAt is filled in by
// kernelCVEsJSON.EnrichWithNVD from the matching NVD entry; the same method
// may also overwrite NvdText and CVSS3Score with NVD data.
//
// CVSS3Score is decoded by cvss3.UnmarshalJSON, which expects an object with
// a "score" member rather than a bare number.
type kernelCVE struct {
	ID               string       `json:"-"`
	Name             string       `json:"name"`
	VendorSpecific   bool         `json:"vendor_specific"`
	AffectedVersions string       `json:"affected_versions"`
	DisclosedAt      int64        `json:"-"`
	CWE              string       `json:"cwe"`
	Fixes            string       `json:"fixes"`
	NvdText          string       `json:"nvd_text"`
	CVSS3Score       cvss3        `json:"cvss3"`
	Refs             kernelCVERef `json:"ref_urls"`
}

// kernelCVEsJSON is a set of CVE records keyed by CVE ID; the same key is
// used to look the CVE up in the NVD feed (see EnrichWithNVD).
type kernelCVEsJSON map[string]*kernelCVE

// kernelStreamFix describes the fixes of one CVE, keyed by kernel stream
// version string. Each value carries the version in which that stream was
// fixed, decoded from the "fixed_version" JSON member.
type kernelStreamFix map[string]struct {
	FixedVersion string `json:"fixed_version"`
}

// kernelStreamFixesJSON maps a CVE ID to its per-stream fixes.
type kernelStreamFixesJSON map[string]kernelStreamFix

// Kernels combines per-CVE stream fixes with the list of kernel streams that
// are considered actual. Build it with NewKernels.
//
// streamFixes holds the fixes keyed by CVE ID. actualStreamsMap is the set of
// actual stream version strings, used for membership checks in FixesForCVE.
// actualStreams holds the same versions parsed by semver and sorted with
// sort.Sort.
type Kernels struct {
	streamFixes      kernelStreamFixesJSON
	actualStreamsMap map[string]struct{}
	actualStreams    semver.Collection
}

// NewKernels builds a Kernels value from the per-CVE stream fixes and the list
// of actual kernel stream versions.
//
// Every entry of actualStreams is recorded verbatim in a lookup set and is
// also parsed with semver.NewVersion; the parsed versions are kept sorted
// (sort.Sort over the collection). streamFixes is stored as passed, without
// copying.
//
// If any stream version fails to parse, NewKernels returns a nil *Kernels and
// an error wrapping the parse failure.
func NewKernels(streamFixes kernelStreamFixesJSON, actualStreams []string) (*Kernels, error) {
	result := &Kernels{
		streamFixes:      streamFixes,
		actualStreamsMap: make(map[string]struct{}, len(actualStreams)),
	}

	// Single pass: remember the raw string for membership checks and the
	// parsed version for ordered access.
	for _, stream := range actualStreams {
		result.actualStreamsMap[stream] = struct{}{}

		parsed, err := semver.NewVersion(stream)
		if err != nil {
			return nil, xerrors.Errorf("failed to parse stream version %q: %w", stream, err)
		}
		result.actualStreams = append(result.actualStreams, parsed)
	}

	sort.Sort(result.actualStreams)
	return result, nil
}

// UnmarshalJSON decodes a CVSS3 score from a JSON object of the form
// {"score": "7.8"}, i.e. the score is a number encoded as a JSON string.
//
// The feed is not consistent about this field, so decoding is best-effort:
// when the payload does not match that shape the receiver is left unchanged
// and the error is intentionally dropped. The method therefore always
// returns nil.
func (s *cvss3) UnmarshalJSON(data []byte) error {
	var payload struct {
		Score float32 `json:"score,string"`
	}

	err := json.Unmarshal(data, &payload)
	if err != nil {
		// Unparseable score: tolerated on purpose, keep the current value.
		return nil
	}

	*s = cvss3(payload.Score)
	return nil
}

// FixesForCVE returns the fixes recorded for cveID, restricted to the streams
// that are marked as actual.
//
// The boolean reports whether cveID is present in the stream fixes at all.
// When it is not, the result is (nil, false). When it is, the result is a
// freshly allocated map (possibly empty, if none of the fixed streams is
// actual) and true.
func (k *Kernels) FixesForCVE(cveID string) (kernelStreamFix, bool) {
	known, found := k.streamFixes[cveID]
	if !found {
		return nil, false
	}

	relevant := make(kernelStreamFix, len(known))
	for stream, fix := range known {
		// Keep only fixes that belong to an actual stream.
		if _, tracked := k.actualStreamsMap[stream]; tracked {
			relevant[stream] = fix
		}
	}

	return relevant, true
}

// ActualStreams returns the parsed actual stream versions, in the order
// produced by sort.Sort in NewKernels. The collection stored in k is returned
// as is, without copying.
func (k *Kernels) ActualStreams() semver.Collection {
	return k.actualStreams
}

// EnrichWithNVD merges data from the NVD feed into the CVE set, modifying it
// in place. For every CVE in the set:
//
//   - if the feed has no entry with the same ID, the CVE is removed and a
//     warning is logged;
//   - if the NVD score is below 1.0, the CVE is removed;
//   - otherwise, if the CVE's own CVSS3 score is below 1.0, the NVD score is
//     used instead.
//
// For every CVE that has an NVD entry, a non-empty NVD description replaces
// NvdText and DisclosedAt is set to PublishedDate.Unix() of that entry. This
// also happens for records that were just removed because of a low NVD score.
func (c *kernelCVEsJSON) EnrichWithNVD(feed nvd.Feed) {
	cves := *c
	for cveID, item := range cves {
		nvdEntry, known := feed[cveID]
		if !known {
			// Not present in NVD: not needed.
			delete(cves, cveID)
			simplelog.Warn("no CVE entry in NVD feed", "cve_id", cveID)
			continue
		}

		switch {
		case nvdEntry.Score < 1.0:
			// NVD score is too low: not needed.
			delete(cves, cveID)
		case item.CVSS3Score < 1.0:
			// The NVD score is usable but the feed's own score is not:
			// fall back to the NVD one.
			item.CVSS3Score = cvss3(nvdEntry.Score)
		}

		if desc := nvdEntry.Description; desc != "" {
			item.NvdText = desc
		}

		item.DisclosedAt = nvdEntry.PublishedDate.Unix()
	}
}
