package commands

import (
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/go-resty/resty/v2"
	"github.com/spf13/cobra"

	"a.yandex-team.ru/library/go/certifi"
	"a.yandex-team.ru/security/libs/go/ioatomic"
	"a.yandex-team.ru/security/yadi/yadi-os/pkg/config"
	"a.yandex-team.ru/security/yadi/yadi-os/pkg/splunk"
)

// splunkSyncPeriod is how long a previous feed sync stays valid: syncFeed
// skips downloading when meta.json records a sync younger than this.
const splunkSyncPeriod = time.Hour * 24

// splunkArgs holds the option values shared by every splunk subcommand.
var splunkArgs struct {
	// FixableOnly is bound to the --fixable flag.
	FixableOnly bool
	// FeedPath is bound to the --cache flag: the yadi cache directory.
	FeedPath string
	// dbPath is the absolute "<FeedPath>/db" directory; it is not a flag and
	// is filled in by the splunk command's pre-run hook.
	dbPath string
}

// splunkCmd groups the splunk lookup subcommands. Before any of them runs, its
// persistent pre-run hook executes rootPreRun, resolves --cache to an absolute
// path, ensures "<cache>/db" exists (storing it in splunkArgs.dbPath) and
// refreshes the vulnerability feeds there via syncFeed.
var splunkCmd = &cobra.Command{
	Use:   "splunk",
	Short: "various splunk lookups",
	PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
		// rootPreRun is invoked explicitly — presumably because cobra, by
		// default, runs only the nearest PersistentPreRun hook, which this one
		// would otherwise shadow.
		if err := rootPreRun(cmd, args); err != nil {
			return err
		}

		cachePath, err := filepath.Abs(splunkArgs.FeedPath)
		if err != nil {
			return fmt.Errorf("can't get absolute cache path %q: %w", splunkArgs.FeedPath, err)
		}

		splunkArgs.dbPath = filepath.Join(cachePath, "db")
		if err := os.MkdirAll(splunkArgs.dbPath, 0755); err != nil {
			return fmt.Errorf("can't create cache folder %q: %w", splunkArgs.dbPath, err)
		}

		if err := syncFeed(splunkArgs.dbPath); err != nil {
			return fmt.Errorf("can't sync feed: %w", err)
		}

		return nil
	},
}

// splunkPackagesCmd hands stdin and stdout to splunk.ProcessPackagesLookup,
// configured with the global minimum severity, the --fixable flag and the
// feed files synced into splunkArgs.dbPath.
var splunkPackagesCmd = &cobra.Command{
	Use:   "packages",
	Short: "mark vulnerable packages for splunk",
	RunE: func(_ *cobra.Command, _ []string) error {
		// "{ecosystem}" is kept literally; presumably the splunk package
		// substitutes it per ecosystem — confirm in splunk.WithFeedURI.
		feedURI := filepath.Join(splunkArgs.dbPath, "linux-{ecosystem}.json")
		in, out := os.Stdin, os.Stdout

		return splunk.ProcessPackagesLookup(in, out,
			splunk.WithMinSeverity(config.MinimumSeverity),
			splunk.WithFixableOnly(splunkArgs.FixableOnly),
			splunk.WithFeedURI(feedURI),
		)
	},
}

// splunkKernelCmd hands stdin and stdout to splunk.ProcessKernelLookup,
// configured with the global minimum severity, the --fixable flag and the
// feed files synced into splunkArgs.dbPath.
var splunkKernelCmd = &cobra.Command{
	Use:   "kernel",
	Short: "mark vulnerable kernels for splunk",
	RunE: func(_ *cobra.Command, _ []string) error {
		// "{ecosystem}" is kept literally; presumably the splunk package
		// substitutes it per ecosystem — confirm in splunk.WithFeedURI.
		feedURI := filepath.Join(splunkArgs.dbPath, "linux-{ecosystem}.json")
		in, out := os.Stdin, os.Stdout

		return splunk.ProcessKernelLookup(in, out,
			splunk.WithMinSeverity(config.MinimumSeverity),
			splunk.WithFixableOnly(splunkArgs.FixableOnly),
			splunk.WithFeedURI(feedURI),
		)
	},
}

// init registers the persistent --fixable and --cache flags on splunkCmd,
// attaches the packages and kernel subcommands, and mounts splunkCmd on the
// root command.
func init() {
	pf := splunkCmd.PersistentFlags()
	pf.BoolVar(&splunkArgs.FixableOnly, "fixable", splunkArgs.FixableOnly, "report only fixable issues")
	pf.StringVar(&splunkArgs.FeedPath, "cache", "/tmp/yadi-os", "path to yadi cache")

	splunkCmd.AddCommand(splunkPackagesCmd, splunkKernelCmd)
	rootCmd.AddCommand(splunkCmd)
}

// syncFeed makes sure dbPath holds a recent copy of the yadi "linux-*.json"
// vulnerability feeds.
//
// The time of the last successful sync is kept in dbPath/meta.json; when it is
// younger than splunkSyncPeriod nothing is downloaded. Otherwise the manifest
// https://yadi.yandex-team.ru/db/manifest.json is fetched, every listed
// "linux-*.json" file is downloaded into dbPath, and meta.json is rewritten
// only after all downloads succeed, so a partial sync is retried next run.
//
// It returns an error if the CA pool cannot be built, the manifest cannot be
// fetched or parsed, it lists no usable linux feeds, or any download or write
// fails.
func syncFeed(dbPath string) error {
	// httpTimeout bounds each request including reading its body, so a stalled
	// connection cannot hang the lookup forever. It is deliberately generous
	// because feed files are streamed straight to disk.
	const httpTimeout = 10 * time.Minute

	metaPath := filepath.Join(dbPath, "meta.json")
	var metaInfo struct {
		SyncedAt time.Time `json:"synced_at"`
	}
	// A missing or malformed meta file is not an error: it just forces a resync.
	if rawMeta, err := ioutil.ReadFile(metaPath); err == nil {
		err = json.Unmarshal(rawMeta, &metaInfo)
		if err == nil && time.Since(metaInfo.SyncedAt) < splunkSyncPeriod {
			return nil
		}
	}

	certPool, err := certifi.NewCertPool()
	if err != nil {
		return fmt.Errorf("can't load CA certificates: %w", err)
	}

	httpc := resty.New().
		SetBaseURL("https://yadi.yandex-team.ru").
		SetRedirectPolicy(resty.NoRedirectPolicy()).
		SetTimeout(httpTimeout).
		SetTLSClientConfig(&tls.Config{RootCAs: certPool})

	// listFiles returns the feed names from the remote manifest that are plain
	// "linux-*.json" file names.
	listFiles := func() ([]string, error) {
		rsp, err := httpc.R().Get("/db/manifest.json")
		if err != nil {
			return nil, fmt.Errorf("failed to list yadi files: %w", err)
		}

		if rsp.StatusCode() != http.StatusOK {
			return nil, fmt.Errorf("failed to list yadi files, not 200 status code: %d", rsp.StatusCode())
		}

		// Decode explicitly rather than via resty's SetResult, which skips
		// decoding when the response is not labelled as JSON and would leave
		// the list silently empty.
		var index struct {
			Files []string `json:"files"`
		}
		if err := json.Unmarshal(rsp.Body(), &index); err != nil {
			return nil, fmt.Errorf("failed to parse yadi manifest: %w", err)
		}

		var out []string
		for _, f := range index.Files {
			if !strings.HasSuffix(f, ".json") || !strings.HasPrefix(f, "linux-") {
				continue
			}
			// Names come from the server and are joined onto dbPath: accept only
			// bare file names so an entry like "linux-x/../../y.json" cannot
			// write outside the cache directory.
			if strings.ContainsAny(f, `/\`) {
				continue
			}

			out = append(out, f)
		}

		// Recording an empty sync as fresh would leave the lookups without
		// feeds for a whole splunkSyncPeriod.
		if len(out) == 0 {
			return nil, fmt.Errorf("yadi manifest lists no linux feeds")
		}
		return out, nil
	}

	// downloadFile streams /db/<filename> into dbPath/<filename> atomically.
	downloadFile := func(filename string) error {
		res, err := httpc.R().
			SetDoNotParseResponse(true).
			Get("/db/" + filename)
		if err != nil {
			return fmt.Errorf("failed to get db file %q: %w", filename, err)
		}

		// With SetDoNotParseResponse the caller owns the raw body and must close
		// it; a close error on a read-only body carries nothing actionable.
		body := res.RawBody()
		defer func() { _ = body.Close() }()

		if !res.IsSuccess() {
			return fmt.Errorf("failed to get db file %q: non-200 status code: %d", filename, res.StatusCode())
		}

		dbFilePath := filepath.Join(dbPath, filename)
		if err := ioatomic.WriteFile(dbFilePath, body, 0644); err != nil {
			return fmt.Errorf("failed to save file %q: %w", dbFilePath, err)
		}
		return nil
	}

	files, err := listFiles()
	if err != nil {
		return err
	}

	for _, f := range files {
		if err := downloadFile(f); err != nil {
			return err
		}
	}

	metaInfo.SyncedAt = time.Now()
	rawMeta, err := json.Marshal(metaInfo)
	if err != nil {
		return err
	}

	return ioutil.WriteFile(metaPath, rawMeta, 0600)
}
