package vulndb

import (
	"context"
	"crypto/tls"
	"fmt"
	"sort"
	"sync"
	"time"

	"github.com/go-resty/resty/v2"

	"a.yandex-team.ru/library/go/certifi"
	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/security/yadi/web/pkg/advisories"
)

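/*
   Example of one entry in the advisories JSON document. fetch downloads a
   JSON array of such entries from Options.DatabaseURL and decodes it into
   []advisories.Advisory (the JSON-to-field mapping is defined by that type).

   {
        "cvss_score": 8.0,
        "desc": "",
        "disclosed": null,
        "external_references": [
            {
                "title": null,
                "url": "https://st.yandex-team.ru/YADI-4"
            }
        ],
        "id": "ya:1",
        "lang": "Node.js",
        "module_name": "express-yandex-csp",
        "patched_versions": ">=2.1.1",
        "reference": "https://st.yandex-team.ru/YADI-4",
        "severity": "High",
        "title": "CSP policy injection",
        "type": "nodejs",
        "vulnerable_versions": "<2.1.1"
    }
*/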

// DB is an in-memory snapshot of the advisories database, downloaded over
// HTTP by New and refreshed by Reload.
type DB struct {
	// Set once by New.
	ctx   context.Context
	httpc *resty.Client
	path  string

	// lock guards the fields below it.
	lock           sync.RWMutex
	allAdvisories  []advisories.Advisory
	advisories     map[string]advisories.Advisory
	advisoriesStat map[string]int
	advisoryTypes  map[string]string
}

// Options configures New.
type Options struct {
	// DatabaseURL is the location of the advisories JSON document. It is used
	// both as the HTTP client's base URL and as the request path.
	DatabaseURL string
}

// AdvisoryTypeTotal is the Stats key that holds the number of advisories of
// all types combined.
const AdvisoryTypeTotal = "total"

// advisoryTypes maps each advisory type ID accepted by New (matched against an
// advisory's Lang field) to a human-readable name.
var advisoryTypes = map[string]string{
	"golang":       "Golang",
	"java":         "Java",
	"linux-alpine": "Alpine",
	"linux-debian": "Debian",
	"linux-kernel": "Linux kernel",
	"linux-ubuntu": "Ubuntu",
	"nodejs":       "Node.JS",
	"python":       "Python",
}

// New builds a DB by downloading the advisories JSON from opts.DatabaseURL.
//
// ctx is stored in the DB and attached to every download, including the ones
// made by later Reload calls.
//
// An error is returned if the TLS root pool cannot be built, if the download
// fails, or if the database contains advisories whose type is missing from
// advisoryTypes. That last case is reported as an error rather than a panic,
// so bad remote data cannot crash the caller.
func New(ctx context.Context, opts Options) (*DB, error) {
	certPool, err := certifi.NewCertPool()
	if err != nil {
		return nil, fmt.Errorf("failed to init certpool: %w", err)
	}

	result := &DB{
		path: opts.DatabaseURL,
		ctx:  ctx,
		httpc: resty.New().
			SetTLSClientConfig(&tls.Config{RootCAs: certPool}).
			SetBaseURL(opts.DatabaseURL).
			SetRetryCount(5).
			SetRetryWaitTime(1 * time.Second).
			SetRetryMaxWaitTime(100 * time.Second).
			// This condition asks for a retry only when the request itself
			// errored; HTTP error statuses are not retried by it.
			AddRetryCondition(func(_ *resty.Response, err error) bool {
				return err != nil
			}),
	}

	err = result.fetch()
	if err != nil {
		return nil, fmt.Errorf("failed to fetch DB %s: %w", result.path, err)
	}

	// Expose only known types that have at least one advisory. Collect every
	// unknown type so the error is complete and deterministic.
	types := make(map[string]string, len(result.advisoriesStat))
	var unknown []string
	for name, count := range result.advisoriesStat {
		if name == AdvisoryTypeTotal || count <= 0 {
			continue
		}

		displayName, ok := advisoryTypes[name]
		if !ok {
			unknown = append(unknown, name)
			continue
		}
		types[name] = displayName
	}

	if len(unknown) > 0 {
		sort.Strings(unknown)
		return nil, fmt.Errorf("unknown advisory types in DB %s: %q", result.path, unknown)
	}

	// result is not shared with other goroutines yet, so no lock is needed.
	result.advisoryTypes = types
	return result, nil
}

// DBPath reports the database URL the DB was created with
// (Options.DatabaseURL).
func (d *DB) DBPath() string {
	dbURL := d.path
	return dbURL
}

// AdvisoryTypes returns the advisory types available for querying. Each key is
// a type ID (e.g. "nodejs") and each value is its human-readable name.
//
// The returned map is a copy. Callers may modify it without corrupting the
// DB or racing with concurrent readers.
func (d *DB) AdvisoryTypes() map[string]string {
	d.lock.RLock()
	defer d.lock.RUnlock()

	types := make(map[string]string, len(d.advisoryTypes))
	for id, name := range d.advisoryTypes {
		types[id] = name
	}
	return types
}

// AdvisoriesByType returns the advisories whose Lang equals typeName, in the
// DB's order (most recently disclosed first). The special value "all" returns
// every advisory. An error is returned when typeName is not one of the types
// listed by AdvisoryTypes.
//
// The returned slice is always owned by the caller, who may sort or modify it
// without corrupting the DB's shared state.
func (d *DB) AdvisoriesByType(typeName string) ([]advisories.Advisory, error) {
	d.lock.RLock()
	defer d.lock.RUnlock()

	if typeName == "all" {
		if d.allAdvisories == nil {
			return nil, nil
		}
		all := make([]advisories.Advisory, len(d.allAdvisories))
		copy(all, d.allAdvisories)
		return all, nil
	}

	if _, ok := d.advisoryTypes[typeName]; !ok {
		return nil, xerrors.New("type not found")
	}

	// advisoriesStat counts advisories per Lang, so it is a good capacity hint.
	result := make([]advisories.Advisory, 0, d.advisoriesStat[typeName])
	for i := range d.allAdvisories {
		if d.allAdvisories[i].Lang == typeName {
			result = append(result, d.allAdvisories[i])
		}
	}
	return result, nil
}

// AdvisoryByID looks up a single advisory by its ID (e.g. "ya:1").
// When no advisory with that ID is loaded, it returns the zero Advisory and
// an error.
func (d *DB) AdvisoryByID(id string) (advisory advisories.Advisory, err error) {
	d.lock.RLock()
	found, exists := d.advisories[id]
	d.lock.RUnlock()

	if !exists {
		return advisories.Advisory{}, xerrors.New("advisory not found")
	}
	return found, nil
}

// Stats returns the number of loaded advisories per type. The combined count
// is stored under AdvisoryTypeTotal. Every type in advisoryTypes is present,
// with zero when it has no advisories. Other types appear too if the loaded
// database contains them.
//
// The returned map is a copy. Callers may modify it without corrupting the
// DB or racing with concurrent readers.
func (d *DB) Stats() map[string]int {
	d.lock.RLock()
	defer d.lock.RUnlock()

	stats := make(map[string]int, len(d.advisoriesStat))
	for typeID, count := range d.advisoriesStat {
		stats[typeID] = count
	}
	return stats
}

// fetch downloads the advisories JSON from d.path and replaces the DB's
// contents with it.
//
// Decoding, sorting and indexing all happen before the write lock is taken.
// Readers are therefore blocked only for the final swap, and a failed download
// leaves the previous snapshot untouched.
//
// The list of queryable advisory types is rebuilt as well, so types that
// appear or disappear on Reload are reflected by AdvisoryTypes and
// AdvisoriesByType.
func (d *DB) fetch() error {
	var allAdvisories []advisories.Advisory
	rsp, err := d.httpc.R().
		SetContext(d.ctx).
		SetResult(&allAdvisories).
		ForceContentType("application/json").
		Get(d.path)
	if err != nil {
		return fmt.Errorf("failed to get advisories file: %w", err)
	}

	if !rsp.IsSuccess() {
		return fmt.Errorf("wrong advisories status code: %s", rsp.Status())
	}

	// Most recently disclosed first. The stable sort keeps the upstream order
	// for advisories disclosed in the same second, so the order does not
	// change between reloads.
	sort.SliceStable(allAdvisories, func(i, j int) bool {
		return allAdvisories[i].DisclosedAtUnix > allAdvisories[j].DisclosedAtUnix
	})

	// Every known type is reported, even when it has zero advisories.
	stats := make(map[string]int, len(advisoryTypes)+1)
	stats[AdvisoryTypeTotal] = 0
	for typeID := range advisoryTypes {
		stats[typeID] = 0
	}

	byID := make(map[string]advisories.Advisory, len(allAdvisories))
	for i := range allAdvisories {
		adv := &allAdvisories[i]
		// NOTE(review): time.Unix yields local time, so the rendered date
		// depends on the server's time zone. Confirm whether UTC is intended.
		adv.DisclosedAt = time.Unix(adv.DisclosedAtUnix, 0).Format("2006-01-02")
		stats[AdvisoryTypeTotal]++
		stats[adv.Lang]++
		byID[adv.ID] = *adv
	}

	// Only known types that actually have advisories can be queried. Unknown
	// types are still counted in stats.
	types := make(map[string]string, len(advisoryTypes))
	for typeID, name := range advisoryTypes {
		if stats[typeID] > 0 {
			types[typeID] = name
		}
	}

	d.lock.Lock()
	defer d.lock.Unlock()

	d.allAdvisories = allAdvisories
	d.advisories = byID
	d.advisoriesStat = stats
	d.advisoryTypes = types
	return nil
}

// Reload re-downloads the advisories database and replaces the in-memory
// snapshot. If the download fails, the previously loaded data stays in place
// and the error is returned.
func (d *DB) Reload() error {
	if err := d.fetch(); err != nil {
		return err
	}
	return nil
}
