package keychain

import (
	"crypto"
	"crypto/x509"
	"fmt"
	"sync"

	"a.yandex-team.ru/security/skotty/libs/skotty"
	"a.yandex-team.ru/security/skotty/skotty/internal/certgen"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring/keychain/internal/chaintypes"
	"a.yandex-team.ru/security/skotty/skotty/internal/pinstore"
	"a.yandex-team.ru/security/skotty/skotty/pkg/softattest"
)

const (
	// TokenType is the skotty token type reported by this keyring.
	TokenType = skotty.TokenTypeSoft

	// Name is the machine-readable identifier of this keyring.
	Name = "keychain"

	// HumanName is the display name of this keyring.
	HumanName = "Keychain"
)

// supportedKeyTypes enumerates, in order, every key purpose this keyring
// can generate and store.
var supportedKeyTypes = []keyring.KeyPurpose{
	keyring.KeyPurposeSudo,
	keyring.KeyPurposeSecure,
	keyring.KeyPurposeInsecure,
	keyring.KeyPurposeLegacy,
}

// Compile-time checks that *Keychain and *Tx implement the keyring interfaces.
var (
	_ keyring.Keyring = (*Keychain)(nil)
	_ keyring.Tx      = (*Tx)(nil)
)

// Keychain is a software keyring whose key pairs are persisted through a
// chaintypes.Service and cached in memory per key purpose.
type Keychain struct {
	svc chaintypes.Service

	// mu guards keys and serializes the service sessions opened by Tx.
	mu   sync.Mutex
	keys map[keyring.KeyPurpose]chaintypes.KeyPair
}

// Tx is a keyring transaction over a Keychain. It holds no state of its own;
// every operation goes through the parent Keychain.
type Tx struct {
	k *Keychain // parent keychain; owns the service, cache and lock
}

// NewKeychain creates the keychain service for collection and returns a
// Keyring backed by it, starting with an empty key cache. An error from
// creating the service is returned as-is.
func NewKeychain(collection string) (keyring.Keyring, error) {
	svc, err := newKeyChainService(collection)
	if err != nil {
		return nil, err
	}

	kc := &Keychain{svc: svc}
	kc.keys = make(map[keyring.KeyPurpose]chaintypes.KeyPair)
	return kc, nil
}

// TokenType reports the package-level TokenType (a soft token).
func (k *Keychain) TokenType() skotty.TokenType {
	return TokenType
}

// Name reports the machine-readable keyring identifier.
func (k *Keychain) Name() string {
	return Name
}

// HumanName reports the display name of the keyring.
func (k *Keychain) HumanName() string {
	return HumanName
}

// SupportedKeyTypes returns the key purposes this keyring can hold.
//
// A fresh copy is returned on every call so that callers cannot mutate the
// shared package-level list (which Tx.SupportedKeyTypes also reads).
func (k *Keychain) SupportedKeyTypes() []keyring.KeyPurpose {
	return append([]keyring.KeyPurpose(nil), supportedKeyTypes...)
}

// IsTouchableKey always reports false, whatever the key purpose.
func (k *Keychain) IsTouchableKey(_ keyring.KeyPurpose) bool {
	const touchable = false
	return touchable
}

// Serial returns the token serial reported by the shared software attestator.
func (k *Keychain) Serial() (string, error) {
	attestator := softattest.SharedAttestator()
	return attestator.TokenSerial()
}

// Tx starts a transaction bound to this Keychain. It never returns an error;
// no lock is taken here, as individual Tx operations lock the Keychain
// as needed.
func (k *Keychain) Tx() (keyring.Tx, error) {
	tx := &Tx{k: k}
	return tx, nil
}

// PinStoreOpts returns no extra PIN store options for this keyring (nil).
func (k *Keychain) PinStoreOpts() []pinstore.Option {
	var opts []pinstore.Option
	return opts
}

// Close closes the underlying keychain service. The in-memory key cache is
// not cleared.
func (k *Keychain) Close() {
	svc := k.svc
	svc.Close()
}

// Serial returns the token serial reported by the shared software attestator.
func (t *Tx) Serial() (string, error) {
	attestator := softattest.SharedAttestator()
	return attestator.TokenSerial()
}

// TokenType reports the package-level TokenType (a soft token).
func (t *Tx) TokenType() skotty.TokenType {
	return TokenType
}

// Name reports the machine-readable keyring identifier.
func (t *Tx) Name() string {
	return Name
}

// HumanName reports the display name of the keyring.
func (t *Tx) HumanName() string {
	return HumanName
}

// SupportedKeyTypes returns the key purposes this keyring can hold.
//
// A fresh copy is returned on every call so that callers cannot mutate the
// shared package-level list (which Keychain.SupportedKeyTypes also reads).
func (t *Tx) SupportedKeyTypes() []keyring.KeyPurpose {
	return append([]keyring.KeyPurpose(nil), supportedKeyTypes...)
}

// AttestationCertificate returns the certificate of the shared software
// attestator.
func (t *Tx) AttestationCertificate() (*x509.Certificate, error) {
	attestator := softattest.SharedAttestator()
	return attestator.Certificate()
}

// Attest produces an attestation for the key stored under keyType.
//
// The stored certificate is parsed and passed to the shared software
// attestator with softattest.PINPolicyOnce and softattest.TouchPolicyNever.
func (t *Tx) Attest(keyType keyring.KeyPurpose) (*x509.Certificate, error) {
	stored, err := t.keyPair(keyType)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch keypair: %w", err)
	}

	// PubKey holds the DER-encoded certificate, not a bare public key.
	cert, err := x509.ParseCertificate(stored.PubKey)
	if err != nil {
		return nil, fmt.Errorf("failed to parse certificate: %w", err)
	}

	attestator := softattest.SharedAttestator()
	return attestator.Attest(cert, softattest.PINPolicyOnce, softattest.TouchPolicyNever)
}

// RenewCertificate renews the certificate for keyType. For this keyring,
// renewal is identical to GenCertificate: a brand-new key pair is generated
// and stored rather than re-certifying the existing key.
func (t *Tx) RenewCertificate(keyType keyring.KeyPurpose) (*x509.Certificate, error) {
	return t.GenCertificate(keyType)
}

// GenCertificate generates a new key pair and certificate for keyType via
// certgen, stores them with saveKeyPair (which also updates the in-memory
// cache) and returns the parsed certificate.
//
// The certificate is parsed before anything is written so that a certificate
// that cannot be parsed never overwrites the key pair currently stored for
// keyType.
func (t *Tx) GenCertificate(keyType keyring.KeyPurpose) (*x509.Certificate, error) {
	pubKey, privKey, err := certgen.GenCertificate(keyType)
	if err != nil {
		return nil, fmt.Errorf("failed to generate new keypair: %w", err)
	}

	cert, err := x509.ParseCertificate(pubKey)
	if err != nil {
		return nil, fmt.Errorf("failed to parse generated certificate: %w", err)
	}

	keyPair := chaintypes.KeyPair{PubKey: pubKey, PrivKey: privKey}
	if err := t.saveKeyPair(keyType, keyPair); err != nil {
		return nil, fmt.Errorf("failed to save keypair: %w", err)
	}

	return cert, nil
}

// SetCertificate replaces the certificate stored for keyType with crt and
// keeps the existing private key.
//
// crt must be non-nil and carry its DER encoding in Raw (as certificates
// returned by x509.ParseCertificate do). Otherwise an error is returned and
// nothing is read or written. This avoids a nil-pointer panic and an empty
// certificate being persisted.
func (t *Tx) SetCertificate(keyType keyring.KeyPurpose, crt *x509.Certificate) error {
	if crt == nil || len(crt.Raw) == 0 {
		return fmt.Errorf("certificate is nil or has no raw DER encoding")
	}

	keyPair, err := t.keyPair(keyType)
	if err != nil {
		return fmt.Errorf("failed to fetch keypair: %w", err)
	}

	keyPair.PubKey = crt.Raw
	if err := t.saveKeyPair(keyType, keyPair); err != nil {
		return fmt.Errorf("failed to save keypair: %w", err)
	}
	return nil
}

// Certificate returns the parsed certificate stored for keyType, loading it
// from the keychain service if it is not already cached.
func (t *Tx) Certificate(keyType keyring.KeyPurpose) (*x509.Certificate, error) {
	stored, err := t.keyPair(keyType)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch keypair: %w", err)
	}

	der := stored.PubKey
	return x509.ParseCertificate(der)
}

// Signer returns a crypto.Signer for the private key stored under keyType.
//
// Legacy keys are decoded as PKCS #1 RSA private keys. Every other purpose
// is decoded as an EC private key (x509.ParseECPrivateKey).
//
// On error the returned Signer is an untyped nil, so callers may safely
// compare it against nil.
func (t *Tx) Signer(keyType keyring.KeyPurpose) (crypto.Signer, error) {
	keyPair, err := t.keyPair(keyType)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch keypair: %w", err)
	}

	// Parse into concrete locals and return a literal nil on failure:
	// returning the parser's results directly would wrap a (*T)(nil) in a
	// non-nil crypto.Signer interface.
	switch keyType {
	case keyring.KeyPurposeLegacy:
		key, err := x509.ParsePKCS1PrivateKey(keyPair.PrivKey)
		if err != nil {
			return nil, fmt.Errorf("failed to parse PKCS#1 private key: %w", err)
		}
		return key, nil
	default:
		key, err := x509.ParseECPrivateKey(keyPair.PrivKey)
		if err != nil {
			return nil, fmt.Errorf("failed to parse EC private key: %w", err)
		}
		return key, nil
	}
}

// keyPair returns the key pair for keyType. A cached pair is returned when
// present; otherwise the pair is fetched through a new service session and
// cached on success. The Keychain mutex is held for the whole call, so this
// is serialized with other keyPair and saveKeyPair calls on the same Keychain.
func (t *Tx) keyPair(keyType keyring.KeyPurpose) (chaintypes.KeyPair, error) {
	kc := t.k
	kc.mu.Lock()
	defer kc.mu.Unlock()

	cached, found := kc.keys[keyType]
	if found {
		return cached, nil
	}

	sess, err := kc.svc.Session()
	if err != nil {
		return chaintypes.KeyPair{}, fmt.Errorf("failed to open keychain service session: %w", err)
	}
	// Deferred after the unlock, so the session closes while the lock is still held.
	defer sess.Close()

	fetched, err := sess.FetchKeyPair(keyType)
	if err != nil {
		return chaintypes.KeyPair{}, fmt.Errorf("failed to fetch keypair from keychain service: %w", err)
	}

	kc.keys[keyType] = fetched
	return fetched, nil
}

// saveKeyPair persists keyPair for keyType through a new service session and,
// only on success, updates the in-memory cache. The Keychain mutex is held
// for the whole call.
func (t *Tx) saveKeyPair(keyType keyring.KeyPurpose, keyPair chaintypes.KeyPair) error {
	kc := t.k
	kc.mu.Lock()
	defer kc.mu.Unlock()

	sess, err := kc.svc.Session()
	if err != nil {
		return fmt.Errorf("failed to open keychain service session: %w", err)
	}
	defer sess.Close()

	if err := sess.SaveKeyPair(keyType, keyPair); err != nil {
		return fmt.Errorf("failed to save keypair into keychainservice: %w", err)
	}

	kc.keys[keyType] = keyPair
	return nil
}

// Close is a no-op: a Tx owns no resources beyond its parent Keychain, which
// is closed separately.
func (t *Tx) Close() {
}
