package commands

import (
	"encoding/hex"
	"errors"
	"fmt"
	"os"
	"path/filepath"

	"github.com/klauspost/compress/zstd"
	"github.com/spf13/cobra"
	"github.com/stripe/krl"

	"a.yandex-team.ru/security/skotty/skotty/pkg/sshutil"
)

// parseCmd implements the "parse" subcommand. It reads one or more SSH KRL
// files, zstd-decompressing those whose name ends in ".zst", and prints the
// KRL signing keys followed by every section the KRL contains.
//
// Every file given on the command line is attempted even if an earlier one
// fails. Per-file failures are reported inline, and once all files have been
// processed the command returns an error, so the process exit status reflects
// the failure.
var parseCmd = &cobra.Command{
	Use:          "parse [/path/to/krl]",
	SilenceUsage: false,
	Short:        "parse SSH KRL",
	RunE: func(cmd *cobra.Command, args []string) error {
		if len(args) == 0 {
			return errors.New("no krl file specified")
		}

		// The arguments are valid from here on. Failures past this point are
		// runtime (per-file) errors, so cobra should not print the usage text.
		cmd.SilenceUsage = true

		// doParse reads, optionally decompresses, parses and prints one KRL file.
		doParse := func(filename string) error {
			in, err := os.ReadFile(filename)
			if err != nil {
				return err
			}

			if filepath.Ext(filename) == ".zst" {
				// A single-goroutine decoder is enough for a one-shot DecodeAll.
				r, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(1))
				if err != nil {
					return fmt.Errorf("create zstd decoder: %w", err)
				}
				defer r.Close()

				in, err = r.DecodeAll(in, nil)
				if err != nil {
					return fmt.Errorf("decompress zstd: %w", err)
				}
			}

			parsedKRL, err := krl.ParseKRL(in)
			if err != nil {
				return fmt.Errorf("parse krl: %w", err)
			}

			if len(parsedKRL.SigningKeys) == 0 {
				fmt.Println("signed by: none")
			} else {
				fmt.Println("signed by:")
				for _, key := range parsedKRL.SigningKeys {
					fmt.Printf("  - %s\n", sshutil.Fingerprint(key))
				}
			}

			fmt.Println("sections:")
			for secNum, section := range parsedKRL.Sections {
				switch s := section.(type) {
				case *krl.KRLCertificateSection:
					fmt.Printf("  - [%d] KRL_SECTION_CERTIFICATES:\n", secNum)
					// A nil CA is printed as "any".
					ca := "any"
					if s.CA != nil {
						ca = sshutil.Fingerprint(s.CA)
					}
					fmt.Printf("    * CA: %s\n", ca)

					for subSecNum, subSection := range s.Sections {
						switch ss := subSection.(type) {
						case *krl.KRLCertificateSerialList:
							fmt.Printf("        + [%d] KRL_SECTION_CERT_SERIAL_LIST:\n", subSecNum)
							for _, serial := range *ss {
								fmt.Printf("          %v\n", serial)
							}
						case *krl.KRLCertificateSerialRange:
							fmt.Printf("        + [%d] KRL_SECTION_CERT_SERIAL_RANGE:\n", subSecNum)
							fmt.Printf("          %v - %v\n", ss.Min, ss.Max)
						case *krl.KRLCertificateSerialBitmap:
							fmt.Printf("        + [%d] KRL_SECTION_CERT_SERIAL_BITMAP:\n", subSecNum)
							// NOTE(review): only the bitmap offset is printed; the
							// individual serials encoded in the bitmap are not expanded.
							fmt.Printf("          %v +\n", ss.Offset)
						case *krl.KRLCertificateKeyID:
							fmt.Printf("        + [%d] KRL_SECTION_CERT_KEY_ID:\n", subSecNum)
							for _, keyID := range *ss {
								fmt.Printf("          %v\n", keyID)
							}
						default:
							return fmt.Errorf("unsupported certificate section: %T", ss)
						}
					}
				case *krl.KRLExplicitKeySection:
					fmt.Printf("  - [%d] KRL_SECTION_EXPLICIT_KEY:\n", secNum)
					fmt.Println("    * keys:")
					for _, key := range *s {
						fmt.Printf("        %s\n", sshutil.Fingerprint(key))
					}
				case *krl.KRLFingerprintSection:
					fmt.Printf("  - [%d] KRL_SECTION_FINGERPRINT_SHA1:\n", secNum)
					fmt.Println("    * fingerprints:")
					for _, fp := range *s {
						fmt.Printf("        %s\n", hex.EncodeToString(fp[:]))
					}
				case *krl.KRLFingerprintSHA256Section:
					fmt.Printf("  - [%d] KRL_SECTION_FINGERPRINT_SHA256:\n", secNum)
					fmt.Println("    * fingerprints:")
					for _, fp := range *s {
						fmt.Printf("        %s\n", hex.EncodeToString(fp[:]))
					}
				default:
					return fmt.Errorf("unsupported section: %T", s)
				}
			}
			return nil
		}

		// Keep going after a failure so that every file gets reported, but
		// remember how many failed so the exit status can signal it.
		failed := 0
		for _, arg := range args {
			fmt.Printf("parse file: %s\n", arg)
			if err := doParse(arg); err != nil {
				failed++
				fmt.Printf("fail: %v\n", err)
			}
			fmt.Println()
		}

		if failed > 0 {
			return fmt.Errorf("%d of %d krl file(s) could not be parsed", failed, len(args))
		}
		return nil
	},
}
