package filering

import (
	"bytes"
	"crypto"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"os"
	"path/filepath"

	"a.yandex-team.ru/security/skotty/libs/skotty"
	"a.yandex-team.ru/security/skotty/skotty/internal/certgen"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring"
	"a.yandex-team.ru/security/skotty/skotty/internal/passutil"
	"a.yandex-team.ru/security/skotty/skotty/internal/pinstore"
	"a.yandex-team.ru/security/skotty/skotty/pkg/softattest"
)

// Identity of the file-backed keyring, reported through both the
// keyring.Keyring and keyring.Tx implementations in this file.
const (
	// TokenType is the skotty token type reported by this keyring.
	TokenType = skotty.TokenTypeSoft
	// Name is the short identifier of this keyring.
	Name      = "files"
	// HumanName is the display name of this keyring.
	HumanName = "Files"
)

// supportedKeyTypes enumerates every key purpose the file keyring can hold a
// key pair for. Each purpose is stored in its own PEM file (see fileName).
// Treat this slice as read-only.
var supportedKeyTypes = []keyring.KeyPurpose{
	keyring.KeyPurposeSudo,
	keyring.KeyPurposeSecure,
	keyring.KeyPurposeInsecure,
	keyring.KeyPurposeLegacy,
}

// Compile-time checks that FileRing and Tx implement the keyring interfaces.
var (
	_ keyring.Keyring = (*FileRing)(nil)
	_ keyring.Tx      = (*Tx)(nil)
)

// FileRing is a keyring.Keyring that keeps each key pair as a PEM file inside
// a local directory, with the private key encrypted under a passphrase.
type FileRing struct {
	// passphrase supplies the passphrase used to encrypt/decrypt private keys.
	passphrase pinstore.Provider
	// keysPath is the directory that holds one PEM file per key purpose.
	keysPath   string
}

// Tx is the keyring.Tx handle handed out by FileRing.Tx. It carries copies of
// the owning FileRing's passphrase provider and key directory and holds no
// other state.
type Tx struct {
	// passphrase supplies the passphrase used to encrypt/decrypt private keys.
	passphrase pinstore.Provider
	// keysPath is the directory that holds one PEM file per key purpose.
	keysPath   string
}

// KeyPair is the in-memory form of one key file.
type KeyPair struct {
	// PrivKey is the decrypted private key in the encoding Signer parses:
	// PKCS#1 RSA for keyring.KeyPurposeLegacy, SEC 1 EC otherwise.
	PrivKey []byte
	// PubKey is, despite its name, the DER-encoded X.509 certificate; it is
	// parsed with x509.ParseCertificate everywhere in this file.
	PubKey  []byte
}

// NewFilering returns a keyring that stores key pairs as PEM files under
// basePath, encrypting private keys with a passphrase obtained from the given
// provider.
//
// If basePath does not exist it is created (including parents) with mode 0700.
// An error is returned when basePath cannot be inspected, cannot be created,
// or exists but is not a directory.
func NewFilering(basePath string, passphrase pinstore.Provider) (keyring.Keyring, error) {
	info, err := os.Stat(basePath)
	switch {
	case err == nil:
		// Reject early: otherwise every later key write would fail with a
		// confusing error.
		if !info.IsDir() {
			return nil, fmt.Errorf("can't use base path: %q is not a directory", basePath)
		}
	case os.IsNotExist(err):
		if mkErr := os.MkdirAll(basePath, 0o700); mkErr != nil {
			return nil, fmt.Errorf("failed to create keyring: %w", mkErr)
		}
	default:
		return nil, fmt.Errorf("can't use base path: %w", err)
	}

	return &FileRing{
		passphrase: passphrase,
		keysPath:   basePath,
	}, nil
}

// TokenType reports the skotty token type of this keyring; it is always the
// package constant TokenType.
func (k *FileRing) TokenType() skotty.TokenType {
	return TokenType
}

// Name returns the short identifier of this keyring (the package constant Name).
func (k *FileRing) Name() string {
	return Name
}

// HumanName returns the display name of this keyring (the package constant
// HumanName).
func (k *FileRing) HumanName() string {
	return HumanName
}

// SupportedKeyTypes returns the key purposes this keyring can store.
//
// A fresh copy is returned on every call so that callers which sort, append
// to, or overwrite the result cannot corrupt the package-wide list shared by
// every FileRing and Tx.
func (k *FileRing) SupportedKeyTypes() []keyring.KeyPurpose {
	return append([]keyring.KeyPurpose(nil), supportedKeyTypes...)
}

// IsTouchableKey always reports false: no key purpose stored in this keyring
// requires a touch. The purpose argument is ignored.
func (k *FileRing) IsTouchableKey(_ keyring.KeyPurpose) bool {
	const touchable = false
	return touchable
}

// Serial returns the token serial reported by the shared software attestator.
func (k *FileRing) Serial() (string, error) {
	attestator := softattest.SharedAttestator()
	return attestator.TokenSerial()
}

// PinStoreOpts returns the pinstore options this keyring uses when prompting
// for its passphrase. Like pinstoreOpts, it panics if the token serial cannot
// be obtained.
func (k *FileRing) PinStoreOpts() []pinstore.Option {
	opts := pinstoreOpts()
	return opts
}

// Tx returns a transaction handle sharing this keyring's directory and
// passphrase provider. It never returns an error.
func (k *FileRing) Tx() (keyring.Tx, error) {
	tx := &Tx{
		keysPath:   k.keysPath,
		passphrase: k.passphrase,
	}
	return tx, nil
}

// Close is a no-op: FileRing keeps no open handles between calls.
func (k *FileRing) Close() {
	// Nothing to release; each operation reads/writes files in a single call.
}

// Serial returns the token serial reported by the shared software attestator.
func (t *Tx) Serial() (string, error) {
	attestator := softattest.SharedAttestator()
	return attestator.TokenSerial()
}

// TokenType reports the skotty token type of this keyring; it is always the
// package constant TokenType.
func (t *Tx) TokenType() skotty.TokenType {
	return TokenType
}

// Name returns the short identifier of this keyring (the package constant Name).
func (t *Tx) Name() string {
	return Name
}

// HumanName returns the display name of this keyring (the package constant
// HumanName).
func (t *Tx) HumanName() string {
	return HumanName
}

// SupportedKeyTypes returns the key purposes this keyring can store.
//
// A fresh copy is returned on every call so that callers which sort, append
// to, or overwrite the result cannot corrupt the package-wide list shared by
// every FileRing and Tx.
func (t *Tx) SupportedKeyTypes() []keyring.KeyPurpose {
	return append([]keyring.KeyPurpose(nil), supportedKeyTypes...)
}

// AttestationCertificate returns the certificate of the shared software
// attestator that signs the attestations produced by Attest.
func (t *Tx) AttestationCertificate() (*x509.Certificate, error) {
	attestator := softattest.SharedAttestator()
	return attestator.Certificate()
}

// Attest loads the stored certificate for keyType and asks the shared software
// attestator to attest it with softattest.PINPolicyOnce and
// softattest.TouchPolicyNever.
//
// NOTE(review): getCert also decrypts the private key (prompting for the
// passphrase) even though only the certificate is used here — confirm whether
// that proof-of-possession step is intended.
func (t *Tx) Attest(keyType keyring.KeyPurpose) (*x509.Certificate, error) {
	kp, err := t.getCert(keyType)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch keypair: %w", err)
	}

	// kp.PubKey holds the DER certificate, not a bare public key.
	crt, err := x509.ParseCertificate(kp.PubKey)
	if err != nil {
		return nil, fmt.Errorf("failed to parse certificate: %w", err)
	}

	attestator := softattest.SharedAttestator()
	return attestator.Attest(crt, softattest.PINPolicyOnce, softattest.TouchPolicyNever)
}

// RenewCertificate is implemented as a full regeneration: a brand-new key
// pair and certificate replace the stored ones (see GenCertificate).
func (t *Tx) RenewCertificate(keyType keyring.KeyPurpose) (*x509.Certificate, error) {
	renewed, err := t.GenCertificate(keyType)
	return renewed, err
}

// GenCertificate creates a new key pair and certificate for keyType via
// certgen, stores them in the key file for keyType (replacing any previous
// contents), and returns the parsed certificate.
func (t *Tx) GenCertificate(keyType keyring.KeyPurpose) (*x509.Certificate, error) {
	certDER, privKey, err := certgen.GenCertificate(keyType)
	if err != nil {
		return nil, fmt.Errorf("failed to generate new keypair: %w", err)
	}

	kp := KeyPair{PrivKey: privKey, PubKey: certDER}
	if err := t.saveCert(keyType, kp); err != nil {
		return nil, fmt.Errorf("failed to save keypair: %w", err)
	}

	return x509.ParseCertificate(certDER)
}

// SetCertificate replaces the certificate stored for keyType with crt while
// keeping the existing private key. The private key is decrypted (getCert) and
// re-encrypted (saveCert), so the passphrase is required.
//
// crt must be non-nil and carry its DER encoding in crt.Raw, as certificates
// returned by x509.ParseCertificate do. Without that check, a nil crt would
// panic after the passphrase prompt. An empty Raw would be written as an empty
// CERTIFICATE block, leaving a key file that getCert can no longer load.
func (t *Tx) SetCertificate(keyType keyring.KeyPurpose, crt *x509.Certificate) error {
	if crt == nil || len(crt.Raw) == 0 {
		return errors.New("certificate is empty")
	}

	keyPair, err := t.getCert(keyType)
	if err != nil {
		return fmt.Errorf("failed to fetch keypair: %w", err)
	}

	keyPair.PubKey = crt.Raw
	return t.saveCert(keyType, keyPair)
}

// Certificate returns the certificate stored for keyType. Because it goes
// through getCert, the private key is decrypted as well, which requires the
// passphrase.
func (t *Tx) Certificate(keyType keyring.KeyPurpose) (*x509.Certificate, error) {
	kp, fetchErr := t.getCert(keyType)
	if fetchErr != nil {
		return nil, fmt.Errorf("failed to fetch keypair: %w", fetchErr)
	}

	certDER := kp.PubKey
	return x509.ParseCertificate(certDER)
}

// Signer decrypts the stored private key for keyType and returns it as a
// crypto.Signer. Keys for keyring.KeyPurposeLegacy are parsed as PKCS#1 RSA
// keys; every other purpose is parsed as a SEC 1 EC key.
//
// On failure the returned Signer is an untyped nil. Returning the parser's
// results directly would wrap a nil *rsa.PrivateKey / *ecdsa.PrivateKey in a
// non-nil crypto.Signer interface.
func (t *Tx) Signer(keyType keyring.KeyPurpose) (crypto.Signer, error) {
	keyPair, err := t.getCert(keyType)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch keypair: %w", err)
	}

	switch keyType {
	case keyring.KeyPurposeLegacy:
		rsaKey, err := x509.ParsePKCS1PrivateKey(keyPair.PrivKey)
		if err != nil {
			return nil, fmt.Errorf("failed to parse RSA private key: %w", err)
		}
		return rsaKey, nil
	default:
		ecKey, err := x509.ParseECPrivateKey(keyPair.PrivKey)
		if err != nil {
			return nil, fmt.Errorf("failed to parse EC private key: %w", err)
		}
		return ecKey, nil
	}
}

// getCert loads the key file for keyType and returns its DER certificate
// (KeyPair.PubKey) and decrypted private key (KeyPair.PrivKey).
//
// The file must contain exactly one "CERTIFICATE" block and exactly one
// "SKOTTY PRIVATE KEY" block. Any other PEM block type is rejected, and
// trailing non-PEM data is ignored. The layout is validated before the
// passphrase is requested, so a malformed file fails without prompting.
// Decryption errors are returned to the pinstore provider as ordinary
// (non-Permanent) errors, presumably so the user can retry the passphrase.
//
// On error the returned KeyPair is always the zero value.
func (t *Tx) getCert(keyType keyring.KeyPurpose) (KeyPair, error) {
	keyBytes, err := os.ReadFile(t.keyPath(keyType))
	if err != nil {
		return KeyPair{}, fmt.Errorf("failed to read certificate: %w", err)
	}

	// Pass 1: split the file into its PEM blocks without touching the
	// passphrase.
	var certDER, encryptedPriv []byte
	rest := keyBytes
	for {
		var block *pem.Block
		block, rest = pem.Decode(rest)
		if block == nil {
			break
		}

		switch block.Type {
		case "CERTIFICATE":
			if len(certDER) > 0 {
				return KeyPair{}, errors.New("duplicate 'CERTIFICATE' block")
			}
			certDER = block.Bytes
		case "SKOTTY PRIVATE KEY":
			if len(encryptedPriv) > 0 {
				return KeyPair{}, errors.New("duplicate 'SKOTTY PRIVATE KEY' block")
			}
			encryptedPriv = block.Bytes
		default:
			return KeyPair{}, fmt.Errorf("unexpected PEM block type: %s", block.Type)
		}
	}

	if len(certDER) == 0 {
		return KeyPair{}, errors.New("no certificate found")
	}

	if len(encryptedPriv) == 0 {
		return KeyPair{}, errors.New("no priv key found")
	}

	// Pass 2: decrypt the private key. The closure uses its own error variable
	// so it does not clobber err in the enclosing scope.
	var privKey []byte
	err = t.doWithPassphrase(func(passphrase string) error {
		decrypted, decErr := passutil.Decrypt(encryptedPriv, []byte(passphrase))
		if decErr != nil {
			return fmt.Errorf("failed to decrypt private key: %w", decErr)
		}

		privKey = decrypted
		return nil
	})
	if err != nil {
		return KeyPair{}, err
	}

	if len(privKey) == 0 {
		return KeyPair{}, errors.New("no priv key found")
	}

	return KeyPair{PrivKey: privKey, PubKey: certDER}, nil
}

// saveCert serializes keyPair into the key file for keyType. The certificate
// is written as a "CERTIFICATE" PEM block. The private key is encrypted with
// the keyring passphrase (passutil.Encrypt) and written as a
// "SKOTTY PRIVATE KEY" block. Encryption failures are wrapped with
// pinstore.Permanent, presumably so the provider does not re-prompt for them.
//
// The file is replaced via a temporary file and a rename (writeKeyFileAtomic).
// A failed or interrupted write therefore leaves the previously stored key
// intact instead of a truncated file.
func (t *Tx) saveCert(keyType keyring.KeyPurpose, keyPair KeyPair) error {
	var pemBytes bytes.Buffer
	err := pem.Encode(&pemBytes, &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: keyPair.PubKey,
	})
	if err != nil {
		return fmt.Errorf("failed to encode certificate: %w", err)
	}

	var encodedPriv []byte
	err = t.doWithPassphrase(func(passphrase string) error {
		encoded, encErr := passutil.Encrypt(keyPair.PrivKey, []byte(passphrase))
		if encErr != nil {
			return pinstore.Permanent(fmt.Errorf("failed to encrypt private key: %w", encErr))
		}

		encodedPriv = encoded
		return nil
	})
	if err != nil {
		return err
	}

	err = pem.Encode(&pemBytes, &pem.Block{
		Type:  "SKOTTY PRIVATE KEY",
		Bytes: encodedPriv,
	})
	if err != nil {
		return fmt.Errorf("failed to encode private key: %w", err)
	}

	return writeKeyFileAtomic(t.keyPath(keyType), pemBytes.Bytes())
}

// writeKeyFileAtomic writes data to a temporary file in the same directory as
// path and then renames it over path. os.CreateTemp creates the file with mode
// 0600, so the key file ends up owner-only even if an older file at path had
// wider permissions (os.WriteFile would have kept those). On any failure the
// temporary file is removed and an existing file at path is left untouched.
func writeKeyFileAtomic(path string, data []byte) (err error) {
	tmp, err := os.CreateTemp(filepath.Dir(path), "."+filepath.Base(path)+".tmp-*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()

	closed := false
	defer func() {
		if err == nil {
			return
		}
		if !closed {
			_ = tmp.Close() // best effort: the primary error is already set
		}
		_ = os.Remove(tmpName) // best effort cleanup of the partial temp file
	}()

	if _, err = tmp.Write(data); err != nil {
		return err
	}

	// Flush the new contents before the rename makes them visible at path.
	if err = tmp.Sync(); err != nil {
		return err
	}

	closed = true
	if err = tmp.Close(); err != nil {
		return err
	}

	return os.Rename(tmpName, path)
}

// Close is a no-op: a Tx holds no resources of its own.
func (t *Tx) Close() {
	// Nothing to release.
}

// keyPath returns the location of the PEM file for purpose inside the
// keyring directory.
func (t *Tx) keyPath(purpose keyring.KeyPurpose) string {
	name := fileName(purpose)
	return filepath.Join(t.keysPath, name)
}

// doWithPassphrase requests the keyring passphrase from the pinstore provider
// (with pinstoreOpts) and hands each candidate to validator. Candidates shorter
// than 8 bytes are rejected before validator runs. The passphrase value
// returned by GetPassphrase is discarded; only the error is reported.
func (t *Tx) doWithPassphrase(validator pinstore.PassphraseValidator) error {
	const minPassphraseLen = 8

	check := func(candidate string) error {
		// len counts bytes, not runes.
		if len(candidate) >= minPassphraseLen {
			return validator(candidate)
		}
		return errors.New("passphrase must be at least 8 characters")
	}

	_, err := t.passphrase.GetPassphrase(check, pinstoreOpts()...)
	return err
}

// fileName maps a key purpose to the base name of its PEM file,
// "<purpose>-cert.pem".
func fileName(purpose keyring.KeyPurpose) string {
	const pattern = "%s-cert.pem"
	return fmt.Sprintf(pattern, purpose)
}

// pinstoreOpts builds the options used for every passphrase prompt of this
// keyring: the soft token serial, a human-readable prompt, and WithSync(false).
//
// It panics if the token serial cannot be obtained, because its callers
// (including the exported PinStoreOpts) have no error return.
func pinstoreOpts() []pinstore.Option {
	serial, err := softattest.SharedAttestator().TokenSerial()
	if err != nil {
		panic(fmt.Sprintf("unable to get serial: %v", err))
	}

	opts := make([]pinstore.Option, 0, 3)
	opts = append(opts, pinstore.WithSerial(serial))
	opts = append(opts, pinstore.WithDescription("Please enter the passphrase to unlock files keyring"))
	opts = append(opts, pinstore.WithSync(false))
	return opts
}
