package scenario

import (
	"bytes"
	"crypto/x509"
	"fmt"
	"strings"

	"golang.org/x/crypto/ssh"

	"a.yandex-team.ru/security/skotty/skotty/internal/keyring"
	"a.yandex-team.ru/security/skotty/skotty/pkg/osutil"
)

// Certificates groups the certificate-related scenario steps (token key
// generation, authorized_keys rendering, certificate collection and cleanup).
// It embeds *Scenario, whose members (such as LogInfo and pubStore) the
// methods below rely on.
type Certificates struct {
	*Scenario
}

// NewCertificates returns a Certificates scenario whose embedded *Scenario is
// built by NewScenario from the supplied options.
func NewCertificates(opts ...Option) *Certificates {
	base := NewScenario(opts...)
	return &Certificates{Scenario: base}
}

// GenKeys generates a certificate on the token for every key type listed in
// supportedKeys that the token also reports through tx.SupportedKeyTypes. It
// then hands the generated certificates, in supportedKeys order, to next.
//
// Key types the token does not report are skipped. The first generation
// failure is returned, wrapped with the offending key type, and next is not
// called in that case. Errors from keyringCall itself are returned unchanged.
func (s *Certificates) GenKeys(kr keyring.Keyring, next func([]TokenCert) error) error {
	s.LogInfo("generate private keys")

	generated := make([]TokenCert, 0, len(supportedKeys))
	genErr := keyringCall(kr, func(tx keyring.Tx) error {
		// Set of key types the token itself is able to handle.
		onToken := make(map[keyring.KeyPurpose]struct{}, len(supportedKeys))
		for _, available := range tx.SupportedKeyTypes() {
			onToken[available] = struct{}{}
		}

		for _, wanted := range supportedKeys {
			if _, ok := onToken[wanted]; ok {
				cert, err := tx.GenCertificate(wanted)
				if err != nil {
					return fmt.Errorf("generate certificate for key %q: %w", wanted, err)
				}

				generated = append(generated, TokenCert{
					KeyType: wanted,
					Cert:    cert,
				})
			}
		}

		return nil
	})
	if genErr != nil {
		return genErr
	}

	return next(generated)
}

// AuthorizedKeys renders one authorized_keys entry per purpose in keys for the
// keyring named keyringName, then passes the entries to next.
//
// Entries are built as follows:
//   - Sudo and renew purposes are skipped.
//   - For the legacy purpose, the stored public key is emitted as-is, with
//     surrounding whitespace trimmed.
//   - For every other purpose, the stored public key must parse as an SSH
//     certificate. The entry then trusts the certificate's signing key
//     ("cert-authority"), restricted to the certificate's valid principals.
//
// Every entry ends with a "# Skotty key <purpose> on <keyringName>" comment.
// A read, parse, or unexpected-type error aborts the call, and next is not
// invoked. Underlying errors are wrapped with %w.
func (s *Certificates) AuthorizedKeys(keyringName string, keys []keyring.KeyPurpose, next func([]AuthorizedKey) error) error {
	out := make([]AuthorizedKey, 0, len(keys))
	for _, purpose := range keys {
		switch purpose {
		case keyring.KeyPurposeSudo, keyring.KeyPurposeRenew:
			// sudo/renew keys aren't needed on staff
			continue
		}

		pubBytes, err := s.pubStore.ReadKey(keyringName, purpose)
		if err != nil {
			return fmt.Errorf("can't read ssh pub for key type %s: %w", purpose, err)
		}

		var entry string
		if purpose == keyring.KeyPurposeLegacy {
			// The stored key may end with a newline (authorized_keys format).
			// Trim it so the comment appended below stays on the same line.
			entry = string(bytes.TrimSpace(pubBytes))
		} else {
			pubKey, _, _, _, parseErr := ssh.ParseAuthorizedKey(pubBytes)
			if parseErr != nil {
				return fmt.Errorf("can't parse ssh pub for key type %s: %w", purpose, parseErr)
			}

			pubCert, ok := pubKey.(*ssh.Certificate)
			if !ok {
				return fmt.Errorf("unexpected ssh pub type for key type %s: %T", purpose, pubKey)
			}

			// MarshalAuthorizedKey output ends with a newline; trim it for the same reason.
			ca := bytes.TrimSpace(ssh.MarshalAuthorizedKey(pubCert.SignatureKey))
			principals := strings.Join(pubCert.ValidPrincipals, ",")
			entry = fmt.Sprintf("cert-authority,principals=%q %s", principals, string(ca))
		}

		comment := fmt.Sprintf("Skotty key %s on %s", purpose, keyringName)
		out = append(out, AuthorizedKey{
			Purpose: purpose.String(),
			Blob:    fmt.Sprintf("%s # %s", entry, comment),
		})
	}

	return next(out)
}

// CollectCerts reads, within a keyring transaction, the certificate for each
// key type returned by kr.SupportedKeyTypes. It passes the certificates it
// could read to next.
//
// Key types for which tx.Certificate returns an error are skipped silently;
// presumably no certificate has been generated for them yet. As a result, the
// slice given to next may be shorter than the key-type list, or nil. Errors
// from keyringCall itself are returned unchanged.
func (s *Certificates) CollectCerts(kr keyring.Keyring, next func([]*x509.Certificate) error) error {
	s.LogInfo("collect token certificates")

	var collected []*x509.Certificate
	if err := keyringCall(kr, func(tx keyring.Tx) error {
		for _, keyType := range kr.SupportedKeyTypes() {
			if cert, certErr := tx.Certificate(keyType); certErr == nil {
				collected = append(collected, cert)
			}
		}
		return nil
	}); err != nil {
		return err
	}

	return next(collected)
}

// CleanupCerts removes certs via osutil.RemoveFromCertStorage and then calls
// next. If removal fails, its error is returned unchanged and next is not
// called.
func (s *Certificates) CleanupCerts(certs []*x509.Certificate, next func() error) error {
	s.LogInfo("cleanup %d old certificates", len(certs))

	err := osutil.RemoveFromCertStorage(certs...)
	if err != nil {
		return err
	}
	return next()
}
