package yadi

import (
	"bytes"
	"container/list"
	"encoding/csv"
	"fmt"
	"io"
	"log"
	"strings"

	"a.yandex-team.ru/security/yadi/yadi-os/pkg/config"
	"a.yandex-team.ru/security/yadi/yadi-os/pkg/splunk"
)

// TODO: remove
//
// SplunkPkgQuery is a fmt template for the Splunk search that collects
// installed-package inventory; its single %s verb receives the osquery tag
// (matched against data.tag). The search looks at the "yhids" index over the
// last 12 hours, treats a missing pkg_status as "ii", keeps rows whose status
// matches "^.i" (presumably dpkg-style "installed" — confirm), takes the latest
// time and pkg_version per host/platform/codename/name/arch, renames the
// columns to short names (host, os_platform, os_codename, pkg_name, pkg_arch;
// the version is exposed as pkg_ver) and sorts the result by host and time.
const SplunkPkgQuery = `search index="yhids" data.tag="%s"
AND data.name IN ("pack_packages_installed_packages*", "pack_installed_packages*", "pack_apps_installed*")
earliest=-12h
| fillnull value="ii" data.columns.pkg_status
| regex data.columns.pkg_status="^.i"
| stats latest(data.unixTime) AS time latest(data.columns.pkg_version) AS pkg_ver
by data.hostname data.columns.os_platform data.columns.os_codename data.columns.pkg_name data.columns.pkg_arch
| rename data.hostname AS host, data.columns.os_platform AS os_platform, data.columns.os_codename AS os_codename, data.columns.pkg_name AS pkg_name, data.columns.pkg_arch AS pkg_arch
| sort 0 host time`

// PkgChecker turns raw package data into a scan report and then condenses
// that report into vulnerable-package -> hosts groups.
type PkgChecker interface {
	// GetPkgScanReader runs the vulnerability lookup over deps and returns a
	// reader over its output.
	GetPkgScanReader(deps []byte) (io.Reader, error)

	// GetCheckResult reads a scan report and groups vulnerable packages by host.
	GetCheckResult(conf *PkgCheckerConf, report io.Reader) map[string][]string
}

// Checker is the stateless default PkgChecker implementation; its zero value
// is ready to use.
type Checker struct{}

// PkgCheckerConf carries the settings used to build the Splunk query and to
// filter scan results.
type PkgCheckerConf struct {
	// MinSeverity is the lowest vulnerability severity that is reported;
	// it is compared via SeverityToScore.
	MinSeverity string

	// OsqueryTag is substituted into SplunkPkgQuery as the data.tag value.
	OsqueryTag string
}

// NewPkgCheckerConf returns a freshly allocated PkgCheckerConf holding the
// given minimum severity and osquery tag.
func NewPkgCheckerConf(severity string, yadiTag string) *PkgCheckerConf {
	return &PkgCheckerConf{
		MinSeverity: severity,
		OsqueryTag:  yadiTag,
	}
}

// Skip reports whether a scan row should be left out of the report.
//
// A row is kept (Skip returns false) only when all of the following hold:
//   - its "vulnerable" column equals "true" (case-insensitive);
//   - SeverityToScore of its "vuln_severity" is at least
//     SeverityToScore(conf.MinSeverity);
//   - no previously kept row had the same "host", "pkg_name" and "pkg_ver".
//
// knownPackages records the host/package/version combinations that have
// already been kept and is updated in place. Rows rejected by the
// vulnerable/severity filter do not touch knownPackages, so a filtered-out
// row cannot hide a later qualifying row for the same package.
func Skip(conf *PkgCheckerConf, entry map[string]string, knownPackages map[string]bool) bool {
	if !strings.EqualFold(entry["vulnerable"], "true") ||
		SeverityToScore(entry["vuln_severity"]) < SeverityToScore(conf.MinSeverity) {
		return true
	}
	// Quote each component so the key is unambiguous: plain concatenation
	// would make ("ab", "c") and ("a", "bc") collide.
	key := fmt.Sprintf("%q/%q/%q", entry["host"], entry["pkg_name"], entry["pkg_ver"])
	if knownPackages[key] {
		return true
	}
	knownPackages[key] = true
	return false
}

// SeverityToScore maps a severity name to a comparable score:
// "low" -> 0, "medium" -> 1, "high" -> 2. Matching ignores letter case and
// surrounding whitespace. Any other value (including "" and names such as
// "critical") scores 2, i.e. it is treated like "high".
func SeverityToScore(severity string) int {
	switch strings.ToLower(strings.TrimSpace(severity)) {
	case "low":
		return 0
	case "medium":
		return 1
	default:
		// "high" and every unrecognised value.
		return 2
	}
}

// GetPkgScanReader feeds deps through splunk.ProcessPackagesLookup, using the
// configured minimum severity and feed URI and keeping only fixable findings,
// and returns a reader over everything the lookup wrote. The lookup error, if
// any, is returned alongside that reader.
func (c *Checker) GetPkgScanReader(deps []byte) (io.Reader, error) {
	var scanned bytes.Buffer
	lookupErr := splunk.ProcessPackagesLookup(
		bytes.NewBuffer(deps),
		&scanned,
		splunk.WithMinSeverity(config.MinimumSeverity),
		splunk.WithFixableOnly(true),
		splunk.WithFeedURI(config.FeedURI),
	)
	return bytes.NewReader(scanned.Bytes()), lookupErr
}

// GetCheckResult parses a CSV scan report and groups vulnerable packages by
// the hosts they were found on.
//
// The first CSV record is the header; every following record becomes a
// column-name -> value map. Rows with an empty "vuln_id" are ignored, and the
// rest are filtered through Skip (vulnerable/severity check against
// checkerConf plus per host/package/version deduplication). The result maps
// "<pkg_name><pkg_ver> (severity <vuln_severity> <vuln_reference>)" to at most
// 32 host names, in the order they appeared in the input.
//
// A malformed CSV stream terminates the process via log.Fatal, because the
// PkgChecker interface offers no way to return an error.
func (c *Checker) GetCheckResult(checkerConf *PkgCheckerConf, reader io.Reader) map[string][]string {
	// Only the first maxHosts hosts are kept per package to bound report size.
	const maxHosts = 32

	knownPackages := make(map[string]bool)
	hostsPerPkg := make(map[string]*list.List)

	for _, row := range readPkgScanRows(reader) {
		// Check vuln_id before Skip: Skip records the package as seen, so a row
		// that cannot be reported must not consume that slot.
		if row["vuln_id"] == "" || Skip(checkerConf, row, knownPackages) {
			continue
		}
		vulnerablePkg := fmt.Sprintf("%s%s (severity %s %s)",
			row["pkg_name"], row["pkg_ver"], row["vuln_severity"], row["vuln_reference"])
		hosts, ok := hostsPerPkg[vulnerablePkg]
		if !ok {
			hosts = list.New()
			hostsPerPkg[vulnerablePkg] = hosts
		}
		if hosts.Len() < maxHosts {
			hosts.PushBack(row["host"])
		}
	}

	out := make(map[string][]string, len(hostsPerPkg))
	for pkg, hosts := range hostsPerPkg {
		names := make([]string, 0, hosts.Len())
		for e := hosts.Front(); e != nil; e = e.Next() {
			names = append(names, e.Value.(string))
		}
		out[pkg] = names
	}
	return out
}

// readPkgScanRows reads a CSV stream whose first record is the header and
// returns each following record as a header-name -> value map. Any read error
// other than io.EOF is fatal.
func readPkgScanRows(reader io.Reader) []map[string]string {
	r := csv.NewReader(reader)
	var header []string
	var rows []map[string]string
	for {
		record, err := r.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			log.Fatal(err)
		}
		if header == nil {
			header = record
			continue
		}
		// csv.Reader (FieldsPerRecord == 0) rejects records whose field count
		// differs from the first one, so record[i] is in range here.
		row := make(map[string]string, len(header))
		for i, name := range header {
			row[name] = record[i]
		}
		rows = append(rows, row)
	}
	return rows
}

// GetSplunkQuery fills SplunkPkgQuery with checkerConf.OsqueryTag, logs the
// resulting search and returns it.
func (c *Checker) GetSplunkQuery(checkerConf *PkgCheckerConf) string {
	search := fmt.Sprintf(SplunkPkgQuery, checkerConf.OsqueryTag)
	log.Printf("Running with query %s\n", search)
	return search
}
