package nvd

import (
	"compress/gzip"
	"context"
	"encoding/json"
	"io"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/go-resty/resty/v2"
	"golang.org/x/net/html"

	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/security/libs/go/simplelog"
)

const (
	// nvdIndexPage is the URL of the NVD data feeds index page. ParseFeed
	// hands it to parseNVD as the page to scan for links to the JSON feeds.
	nvdIndexPage = "https://nvd.nist.gov/vuln/data-feeds"
)

type (
	// FeedEntry holds the fields extracted from a single CVE item of an NVD
	// feed.
	FeedEntry struct {
		Description   string    // first description whose lang is "en", if any
		CWEID         string    // first problem-type value starting with "CWE-", if any
		Score         float32   // value returned by the NVD item's Score method
		PublishedDate time.Time // parsed published date; falls back to Jan 1 (UTC) of the year in the CVE ID
	}

	// Feed maps a CVE ID to the entry parsed for it.
	Feed map[string]FeedEntry
)

type (
	// nvdFeedJSON is the top-level shape of a decoded NVD JSON feed file; only
	// the list of CVE items is read.
	nvdFeedJSON struct {
		Items []nvdEntry `json:"CVE_Items"`
	}

	// nvdParser accumulates entries from one or more NVD feeds into a single
	// Feed.
	nvdParser struct {
		httpc *resty.Client   // HTTP client used for the index page and the feed downloads
		feed  Feed            // entries collected so far, keyed by CVE ID
		ctx   context.Context // context attached to every outgoing request
	}

	// parseNvdOpts carries the parameters of a parseNVD run.
	parseNvdOpts struct {
		indexPageURI string          // page that is scanned for links to the feeds
		ctx          context.Context // context for all HTTP requests of the run
	}
)

var (
	// parseMu serializes ParseFeed calls and guards parsedFeed.
	parseMu sync.Mutex
	// parsedFeed caches the result of the first successful parse. It stays nil
	// until a parse succeeds, so a failed attempt is retried by the next call.
	parsedFeed Feed
)

// ParseFeed returns the parsed NVD feed, downloading and parsing it on the
// first successful call and serving the cached result afterwards.
//
// Concurrent callers block until the in-flight parse finishes. If that parse
// fails, its caller gets the error and the next caller performs a fresh
// attempt with its own ctx; no caller ever receives a nil feed together with
// a nil error.
func ParseFeed(ctx context.Context) (Feed, error) {
	parseMu.Lock()
	defer parseMu.Unlock()

	if parsedFeed != nil {
		return parsedFeed, nil
	}

	feed, err := parseNVD(parseNvdOpts{
		indexPageURI: nvdIndexPage,
		ctx:          ctx,
	})
	if err != nil {
		return nil, err
	}

	parsedFeed = feed
	return parsedFeed, nil
}

// parseNVD discovers the feed URLs on opts.indexPageURI, consumes every feed
// in the order parseFeeds returned them and yields the merged, compacted
// result. The first failure aborts the run and no partial feed is returned.
func parseNVD(opts parseNvdOpts) (Feed, error) {
	p := &nvdParser{
		httpc: resty.New(),
		feed:  make(Feed),
		ctx:   opts.ctx,
	}

	feedURIs, err := p.parseFeeds(opts.indexPageURI)
	if err != nil {
		return nil, err
	}

	for i := range feedURIs {
		uri := feedURIs[i]
		simplelog.Info("parse NVD feed", "uri", uri)

		consumeErr := p.consumeFeed(uri)
		if consumeErr != nil {
			return nil, xerrors.Errorf("failed to consume feed %q: %w", uri, consumeErr)
		}
	}

	// Drop rejected/disputed entries only after every feed has been merged.
	p.compact()

	return p.feed, nil
}

// compact removes from the collected feed every entry whose description
// carries the "** REJECT **" or "** DISPUTED **" marker, logging each removal.
func (p *nvdParser) compact() {
	for id := range p.feed {
		text := p.feed[id].Description

		rejected := strings.Contains(text, "** REJECT **")
		if !rejected && !strings.Contains(text, "** DISPUTED **") {
			continue
		}

		// TODO(buglloc): ugly
		simplelog.Warn("skip rejected CVE", "cve_id", id)
		delete(p.feed, id)
	}
}

// consumeFeed downloads the gzip-compressed JSON feed at feedURI, decodes it
// and stores one FeedEntry per CVE item in p.feed, keyed by CVE ID. An entry
// for an ID that is already present is overwritten.
//
// It returns an error when the download fails, the server answers with an
// HTTP error status, or the body is not valid gzip/JSON. Problems with a
// single item (empty ID, unparsable published date) are logged and do not
// fail the whole feed.
func (p *nvdParser) consumeFeed(feedURI string) error {
	rsp, err := p.httpc.R().
		SetContext(p.ctx).
		SetDoNotParseResponse(true).
		Get(feedURI)
	if err != nil {
		return xerrors.Errorf("failed to download feed: %w", err)
	}
	body := rsp.RawBody()
	defer func() { _ = body.Close() }()

	// resty reports HTTP error statuses through the response, not through err.
	// Without this check an error page would be handed to gzip and surface as
	// a misleading "failed to create gzip reader" error.
	if rsp.IsError() {
		return xerrors.Errorf("failed to download feed: unexpected status %q", rsp.Status())
	}

	zr, err := gzip.NewReader(body)
	if err != nil {
		return xerrors.Errorf("failed to create gzip reader: %w", err)
	}
	defer func() { _ = zr.Close() }()

	var feed nvdFeedJSON
	if err := json.NewDecoder(zr).Decode(&feed); err != nil {
		return xerrors.Errorf("failed to parse feed: %w", err)
	}

	for _, entry := range feed.Items {
		cve := entry.CVE
		if cve.Meta.ID == "" {
			simplelog.Warn("empty CVE ID in NVD")
			continue
		}

		feedEntry := FeedEntry{
			Score: entry.Score(),
		}

		if entry.PublishedDate != "" {
			// The error is kept local so a bad date in one item never leaks
			// into the function-level err.
			published, parseErr := time.Parse("2006-01-02T15:04Z", entry.PublishedDate)
			if parseErr != nil {
				// Fall back to the year encoded in the CVE ID.
				published = cveToTime(cve.Meta.ID)
				simplelog.Error(
					"failed to parse CVE published date",
					"cve_id", cve.Meta.ID,
					"published_date", entry.PublishedDate,
					"err", parseErr.Error(),
				)
			}
			feedEntry.PublishedDate = published
		} else {
			feedEntry.PublishedDate = cveToTime(cve.Meta.ID)
		}

		// Only the first CWE identifier found across all problem types is kept.
	loopProblem:
		for _, problemData := range cve.ProblemType.Data {
			for _, desc := range problemData.Description {
				if strings.HasPrefix(desc.Value, "CWE-") {
					feedEntry.CWEID = desc.Value
					break loopProblem
				}
			}
		}

		// Only the first English description is kept.
		for _, desc := range cve.Description.Data {
			if desc.Lang != "en" {
				continue
			}

			feedEntry.Description = desc.Value
			break
		}

		p.feed[cve.Meta.ID] = feedEntry
	}

	return nil
}

// parseFeeds downloads the HTML page at indexPageURI and collects the target
// of every <a href="..."> whose value contains "feeds/json/cve/" and ends with
// ".json.gz". Links are resolved against indexPageURI, so relative hrefs come
// back as absolute URLs.
//
// The result is returned in reverse page order.
// NOTE(review): presumably this makes feeds listed first on the page get
// consumed last, so their entries win on duplicate CVE IDs — confirm.
//
// It returns an error when indexPageURI or a matching href is not a valid
// URL, the download fails, the server answers with an HTTP error status, or
// the HTML tokenizer fails with anything other than io.EOF.
func (p *nvdParser) parseFeeds(indexPageURI string) ([]string, error) {
	// Validate the base URL up front: a nil base would panic in
	// ResolveReference below.
	baseURI, err := url.Parse(indexPageURI)
	if err != nil {
		return nil, xerrors.Errorf("failed to parse index page uri %q: %w", indexPageURI, err)
	}

	rsp, err := p.httpc.R().
		SetContext(p.ctx).
		SetDoNotParseResponse(true).
		Get(indexPageURI)
	if err != nil {
		return nil, xerrors.Errorf("failed to download index page: %w", err)
	}
	body := rsp.RawBody()
	defer func() { _ = body.Close() }()

	// resty reports HTTP error statuses through the response, not through err;
	// an error page would otherwise silently produce an empty feed list.
	if rsp.IsError() {
		return nil, xerrors.Errorf("failed to download index page: unexpected status %q", rsp.Status())
	}

	var feeds []string
	htmlScanner := html.NewTokenizer(body)
	for {
		switch htmlScanner.Next() {
		case html.ErrorToken:
			// io.EOF marks the regular end of the document; anything else is
			// a real failure.
			if err := htmlScanner.Err(); err != io.EOF {
				return nil, xerrors.Errorf("failed to parse index page: %w", err)
			}

			// Reverse in place.
			for i, j := 0, len(feeds)-1; i < j; i, j = i+1, j-1 {
				feeds[i], feeds[j] = feeds[j], feeds[i]
			}

			return feeds, nil
		case html.StartTagToken:
			token := htmlScanner.Token()

			// Only <a> tags can carry feed links.
			if token.Data != "a" {
				continue
			}

			for _, a := range token.Attr {
				if a.Key != "href" {
					continue
				}

				if strings.Contains(a.Val, "feeds/json/cve/") && strings.HasSuffix(a.Val, ".json.gz") {
					feedURI, err := url.Parse(a.Val)
					if err != nil {
						return nil, xerrors.Errorf("failed to parse feed uri %q: %w", a.Val, err)
					}

					feeds = append(feeds, baseURI.ResolveReference(feedURI).String())
				}

				// Only the first href attribute of a tag is considered.
				break
			}
		}
	}
}

// cveToTime derives an approximate timestamp from a CVE identifier of the
// form "<prefix>-<year>-<rest>": midnight UTC on January 1 of that year.
// It returns the zero time when the identifier has fewer than three
// dash-separated parts or the year part is not an integer.
func cveToTime(cveID string) time.Time {
	var zero time.Time

	fields := strings.SplitN(cveID, "-", 3)
	if len(fields) >= 3 {
		if y, convErr := strconv.Atoi(fields[1]); convErr == nil {
			return time.Date(y, time.January, 1, 0, 0, 0, 0, time.UTC)
		}
	}

	return zero
}
