package yubiring

import (
	"crypto"
	"crypto/x509"
	"errors"
	"fmt"
	"strconv"
	"sync"
	"sync/atomic"

	"a.yandex-team.ru/security/libs/go/pcsc"
	"a.yandex-team.ru/security/libs/go/piv"
	"a.yandex-team.ru/security/skotty/libs/skotty"
	"a.yandex-team.ru/security/skotty/skotty/internal/certgen"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring"
	"a.yandex-team.ru/security/skotty/skotty/internal/pinstore"
	"a.yandex-team.ru/security/skotty/skotty/internal/yubikey"
)

// Identity of the Yubikey keyring backend, shared by YubiRing and Tx.
const (
	// TokenType is the skotty token type reported for Yubikey-backed keyrings.
	TokenType = skotty.TokenTypeYubikey
	// Name is the short lowercase identifier of this keyring.
	Name = "yubikey"
	// HumanName is the keyring name intended for display.
	HumanName = "Yubikey"
)

// supportedKeyTypes lists every key purpose this keyring can hold, in the
// order it is reported to callers; slotForPurpose maps each to a PIV slot.
var supportedKeyTypes = []keyring.KeyPurpose{
	keyring.KeyPurposeSudo,
	keyring.KeyPurposeSecure,
	keyring.KeyPurposeInsecure,
	keyring.KeyPurposeLegacy,
	keyring.KeyPurposeRenew,
}

// Compile-time checks that YubiRing and Tx implement the keyring interfaces.
var (
	_ keyring.Keyring = (*YubiRing)(nil)
	_ keyring.Tx      = (*Tx)(nil)
)

// YubiRing is a keyring.Keyring backed by the Yubikey with a given serial
// number. The card is resolved lazily (see lazyInit) and every Tx opens the
// device on its own.
type YubiRing struct {
	// inited is 1 once card has been resolved, 0 otherwise; accessed via sync/atomic.
	inited uint32
	// mu serializes (re)initialization and guards writes to card.
	mu     sync.Mutex
	// card is the device handle descriptor passed to yubikey.Open; set by lazyInit.
	card   yubikey.Card
	// serial is the Yubikey serial number used to locate the card.
	serial uint32
	// pin supplies the PIN for operations that need it.
	pin    pinstore.Provider
	// pubs caches public keys per key purpose; a pointer to it is shared with every Tx.
	pubs   pubCache
}

// Tx is a keyring.Tx holding an opened Yubikey. It is created by
// YubiRing.Tx and must be released with Close.
type Tx struct {
	// yk is the opened device; closed by Close.
	yk   *yubikey.Yubikey
	// pubs points at the owning YubiRing's public-key cache.
	pubs *pubCache
	// pin supplies the PIN for PIN-protected operations.
	pin  pinstore.Provider
}

// NewYubiring returns a keyring for the Yubikey with the given serial number.
//
// No device access happens here: the card is looked up lazily on the first
// Tx call. The returned error is currently always nil.
func NewYubiring(serial uint32, pin pinstore.Provider) (*YubiRing, error) {
	ring := &YubiRing{
		serial: serial,
		pin:    pin,
	}
	ring.pubs.cache = make(map[keyring.KeyPurpose]crypto.PublicKey)
	return ring, nil
}

// reInit clears the initialized flag and runs lazyInit again, so the card is
// looked up afresh by serial number.
func (s *YubiRing) reInit() error {
	// Whether the swap happened is irrelevant: afterwards the flag is 0 either
	// way (unless a concurrent lazyInit has already re-published it).
	_ = atomic.CompareAndSwapUint32(&s.inited, 1, 0)

	err := s.lazyInit()
	return err
}

// lazyInit resolves s.card from s.serial the first time it is needed.
//
// It uses double-checked locking: a lock-free atomic check of s.inited on
// the fast path, then a re-check under s.mu so only one caller performs the
// lookup. Errors from yubikey.OpenBySerial are returned unwrapped.
func (s *YubiRing) lazyInit() error {
	if atomic.LoadUint32(&s.inited) == 1 {
		return nil
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Re-check: another goroutine may have finished initialization while we
	// were waiting for the lock.
	if atomic.LoadUint32(&s.inited) == 1 {
		return nil
	}

	yk, err := yubikey.OpenBySerial(s.serial)
	if err != nil {
		return err
	}
	// Only yk.Card is kept; the opened device itself is closed before
	// returning, and Tx() opens the card again per transaction.
	defer yk.Close()

	s.card = yk.Card
	// Publish the flag only after s.card is set, so fast-path readers that
	// observe inited == 1 also observe the card.
	atomic.StoreUint32(&s.inited, 1)
	return nil
}

// TokenType reports the skotty token type of this keyring (skotty.TokenTypeYubikey).
func (s *YubiRing) TokenType() skotty.TokenType { return TokenType }

// Name returns the short keyring identifier ("yubikey").
func (s *YubiRing) Name() string { return Name }

// HumanName returns the display name of the keyring ("Yubikey").
func (s *YubiRing) HumanName() string { return HumanName }

// SupportedKeyTypes returns the key purposes this keyring can hold.
//
// The result is a fresh copy: returning the package-level slice directly
// would let any caller reorder or overwrite it for every other user.
func (s *YubiRing) SupportedKeyTypes() []keyring.KeyPurpose {
	return append([]keyring.KeyPurpose(nil), supportedKeyTypes...)
}

// IsTouchableKey reports whether keys for purpose are configured to require
// a physical touch, i.e. their touch policy is cached or always.
func (s *YubiRing) IsTouchableKey(purpose keyring.KeyPurpose) bool {
	switch purpose.TouchPolicy() {
	case keyring.TouchPolicyCached, keyring.TouchPolicyAlways:
		return true
	default:
		return false
	}
}

// Serial returns the configured Yubikey serial number in decimal. The error
// is always nil.
func (s *YubiRing) Serial() (string, error) {
	// FormatUint covers the whole uint32 range; int(s.serial) would wrap to a
	// negative value on 32-bit platforms for serials above math.MaxInt32.
	return strconv.FormatUint(uint64(s.serial), 10), nil
}

// PinStoreOpts returns the pinstore options used when asking for this
// Yubikey's PIN (serial tag and prompt text), as built by pinstoreOpts.
func (s *YubiRing) PinStoreOpts() []pinstore.Option {
	opts := pinstoreOpts(s.serial)
	return opts
}

// Tx opens a transaction against the Yubikey backing this keyring. The
// caller must Close the returned Tx.
//
// The card is resolved lazily by serial number. If opening it fails with
// pcsc.ErrUnknownReader (most likely the key was re-inserted and now shows up
// under a different reader name), the card is looked up again and the open
// is retried once.
func (s *YubiRing) Tx() (keyring.Tx, error) {
	if err := s.lazyInit(); err != nil {
		return nil, fmt.Errorf("can't initialize yubikey: %w", err)
	}

	card := s.currentCard()
	yk, err := yubikey.Open(card)
	if err != nil {
		if !errors.Is(err, pcsc.ErrUnknownReader) {
			return nil, fmt.Errorf("can't open yubikey %q: %w", card, err)
		}

		// Probably the yubikey was re-inserted with a different reader name:
		// resolve it again by serial and try one more time.
		if err := s.reInit(); err != nil {
			return nil, fmt.Errorf("can't initialize yubikey: %w", err)
		}

		card = s.currentCard()
		yk, err = yubikey.Open(card)
		if err != nil {
			return nil, fmt.Errorf("can't open yubikey %q: %w", card, err)
		}
	}

	return &Tx{
		yk:   yk,
		pubs: &s.pubs,
		pin:  s.pin,
	}, nil
}

// currentCard returns s.card under s.mu. lazyInit writes s.card while
// holding s.mu (including when reInit is triggered by a concurrent Tx call),
// so an unlocked read here would be a data race.
func (s *YubiRing) currentCard() yubikey.Card {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.card
}

// Close is a no-op. YubiRing keeps no device open between transactions:
// lazyInit closes the device it opens, and each Tx closes its own.
func (s *YubiRing) Close() {
}

// Serial returns the serial number of the opened Yubikey in decimal. The
// error is always nil.
func (t *Tx) Serial() (string, error) {
	// FormatUint covers the whole uint32 range; int(serial) would wrap to a
	// negative value on 32-bit platforms for serials above math.MaxInt32.
	return strconv.FormatUint(uint64(t.yk.Serial), 10), nil
}

// TokenType reports skotty.TokenTypeYubikey, the same value as YubiRing.TokenType.
func (t *Tx) TokenType() skotty.TokenType { return TokenType }

// Name returns the short keyring identifier ("yubikey").
func (t *Tx) Name() string { return Name }

// HumanName returns the display name of the keyring ("Yubikey").
func (t *Tx) HumanName() string { return HumanName }

// SupportedKeyTypes returns the key purposes this keyring can hold.
//
// The result is a fresh copy so callers cannot mutate the package-level list
// shared by every YubiRing and Tx.
func (t *Tx) SupportedKeyTypes() []keyring.KeyPurpose {
	return append([]keyring.KeyPurpose(nil), supportedKeyTypes...)
}

// AttestationCertificate returns the device attestation certificate as
// reported by the underlying yubikey.Yubikey.
func (t *Tx) AttestationCertificate() (*x509.Certificate, error) {
	attCert, err := t.yk.AttestationCertificate()
	return attCert, err
}

// Attest returns the attestation certificate for the key in keyType's slot.
// It fails if keyType has no slot mapping.
func (t *Tx) Attest(keyType keyring.KeyPurpose) (*x509.Certificate, error) {
	slot, err := slotForPurpose(keyType)
	if err == nil {
		return t.yk.Attest(slot)
	}

	return nil, fmt.Errorf("can't determine yubikey slot for keytype %q: %w", keyType, err)
}

// RenewCertificate returns the certificate already stored in keyType's slot.
// Only if it cannot be read does it fall back to GenCertificate.
//
// NOTE(review): any read error, not only "no certificate present", leads to
// GenCertificate, which presumably replaces the slot's key. Confirm this is
// acceptable for transient card errors.
func (t *Tx) RenewCertificate(keyType keyring.KeyPurpose) (*x509.Certificate, error) {
	slot, err := slotForPurpose(keyType)
	if err != nil {
		return nil, fmt.Errorf("can't determine yubikey slot for keytype %q: %w", keyType, err)
	}

	// Keep the existing private key when there is one: it cannot be exported
	// from the device, so reusing it is considered safe.
	existing, readErr := t.yk.Certificate(slot)
	if readErr != nil {
		return t.GenCertificate(keyType)
	}
	return existing, nil
}

// GenCertificate generates a new key in keyType's slot (algorithm, PIN and
// touch policy from pivKeyForPurpose) together with a certificate whose
// common name comes from certgen.GenCommonName. The PIN is requested via
// doWithPIN.
//
// On success the shared public-key cache is updated with the new
// certificate's key. Without this, a Tx that had already cached the old
// public key would keep pairing the slot with it in Signer.
func (t *Tx) GenCertificate(keyType keyring.KeyPurpose) (*x509.Certificate, error) {
	slot, err := slotForPurpose(keyType)
	if err != nil {
		return nil, fmt.Errorf("can't determine yubikey slot for keytype %q: %w", keyType, err)
	}

	key, err := pivKeyForPurpose(keyType)
	if err != nil {
		return nil, fmt.Errorf("can't determine yubikey piv key for keytype %q: %w", keyType, err)
	}

	var cert *x509.Certificate
	err = t.doWithPIN(func(pin string) error {
		var genErr error
		cert, genErr = t.yk.GenCertificate(slot, pin, yubikey.CertRequest{
			CommonName: certgen.GenCommonName(keyType),
			Key:        key,
		})
		return genErr
	})
	if err != nil {
		return nil, err
	}

	// The slot now holds a fresh key pair; keep the cache in step so Signer
	// does not hand out the previous public key.
	if cert != nil {
		t.pubs.Set(keyType, cert.PublicKey)
	}
	return cert, nil
}

// SetCertificate stores crt in keyType's slot, requesting the PIN via
// doWithPIN.
//
// Signer derives its cached public key from the slot certificate, so on
// success the cache is updated with crt's public key. This keeps it
// identical to what a cold lookup of the slot would now produce.
func (t *Tx) SetCertificate(keyType keyring.KeyPurpose, crt *x509.Certificate) error {
	slot, err := slotForPurpose(keyType)
	if err != nil {
		return fmt.Errorf("can't determine yubikey slot for keytype %q: %w", keyType, err)
	}

	err = t.doWithPIN(func(pin string) error {
		return t.yk.SetCertificate(slot, pin, crt)
	})
	if err != nil {
		return fmt.Errorf("failed to update certificate on yubikey: %w", err)
	}

	if crt != nil {
		t.pubs.Set(keyType, crt.PublicKey)
	}
	return nil
}

// Certificate reads the certificate currently stored in keyType's slot.
func (t *Tx) Certificate(keyType keyring.KeyPurpose) (*x509.Certificate, error) {
	pivSlot, slotErr := slotForPurpose(keyType)
	if slotErr == nil {
		return t.yk.Certificate(pivSlot)
	}
	return nil, fmt.Errorf("can't determine yubikey slot for keytype %q: %w", keyType, slotErr)
}

// Signer returns a crypto.Signer backed by the private key in keyType's slot.
//
// The public key comes from the shared cache. On a cache miss it is read
// from the slot certificate and cached. The key is accessed with
// PINPolicyAlways and doWithPIN as the PIN provider.
func (t *Tx) Signer(keyType keyring.KeyPurpose) (crypto.Signer, error) {
	slot, err := slotForPurpose(keyType)
	if err != nil {
		return nil, fmt.Errorf("can't determine yubikey slot for keytype %q: %w", keyType, err)
	}

	pub := t.pubs.Get(keyType)
	if pub == nil {
		slotCert, certErr := t.yk.Certificate(slot)
		if certErr != nil {
			return nil, fmt.Errorf("unable to get certificate for keytype %q: %w", keyType, certErr)
		}
		pub = slotCert.PublicKey
		t.pubs.Set(keyType, pub)
	}

	privKey, err := t.yk.PrivateKey(slot, piv.KeyAuth{
		PINPolicy:   piv.PINPolicyAlways,
		PINProvider: t.doWithPIN,
	}, pub)
	if err != nil {
		return nil, err
	}
	return &Signer{Signer: privKey}, nil
}

// Close releases the Yubikey opened for this transaction by YubiRing.Tx.
// The Tx must not be used afterwards.
func (t *Tx) Close() {
	yk := t.yk
	yk.Close()
}

// doWithPIN obtains a PIN from t.pin (using this Yubikey's pinstore options)
// and runs validator with it.
//
// Classification of failures:
//   - PINs outside 6-8 characters are rejected before validator runs, as a
//     plain (non-permanent) error.
//   - a *yubikey.AuthErr with retries left is returned as-is, presumably so
//     the PIN store can ask again.
//   - an AuthErr with no retries left, or any other error, is wrapped in
//     pinstore.Permanent.
func (t *Tx) doWithPIN(validator func(pin string) error) error {
	check := func(pin string) error {
		if n := len(pin); n < 6 || n > 8 {
			return errors.New("PIN must be 6-8 characters long")
		}

		err := validator(pin)
		if err == nil {
			return nil
		}

		var authErr *yubikey.AuthErr
		if errors.As(err, &authErr) && authErr.Retries != 0 {
			// Wrong PIN but attempts remain.
			return err
		}
		// PIN is blocked, or the failure is unrelated to the PIN.
		return pinstore.Permanent(err)
	}

	_, err := t.pin.GetPIN(check, pinstoreOpts(t.yk.Serial)...)
	return err
}

// IsAvailable reports whether yubikey.Cards lists at least one card. The
// error from yubikey.Cards is passed through unchanged alongside the result.
func IsAvailable() (bool, error) {
	cards, err := yubikey.Cards()
	found := len(cards) != 0
	return found, err
}

// slotForPurpose maps a key purpose to its PIV slot, identified by key IDs
// 0x91-0x95 (presumably within the PIV retired key-management range).
// Unknown purposes yield an error.
func slotForPurpose(purpose keyring.KeyPurpose) (yubikey.Slot, error) {
	switch purpose {
	case keyring.KeyPurposeRenew:
		return yubikey.SlotFromKeyID(uint32(0x91))
	case keyring.KeyPurposeSecure:
		return yubikey.SlotFromKeyID(uint32(0x92))
	case keyring.KeyPurposeInsecure:
		return yubikey.SlotFromKeyID(uint32(0x93))
	case keyring.KeyPurposeSudo:
		return yubikey.SlotFromKeyID(uint32(0x94))
	case keyring.KeyPurposeLegacy:
		return yubikey.SlotFromKeyID(uint32(0x95))
	}

	return yubikey.Slot{}, fmt.Errorf("unsupported key purpose: %s", purpose)
}

// pivKeyForPurpose builds the PIV key-generation parameters for purpose. The
// algorithm comes from purpose.Algo() and the touch policy from
// purpose.TouchPolicy(); the PIN policy is always piv.PINPolicyAlways.
//
// On error the returned piv.Key is the zero value and must not be used.
func pivKeyForPurpose(purpose keyring.KeyPurpose) (piv.Key, error) {
	out := piv.Key{
		PINPolicy: piv.PINPolicyAlways,
	}

	switch purpose.Algo() {
	case keyring.AlgorithmEC256:
		out.Algorithm = piv.AlgorithmEC256
	case keyring.AlgorithmRSA1024:
		// NOTE(review): an RSA1024 purpose is generated as RSA2048 on the
		// card. Confirm this upgrade is intentional.
		out.Algorithm = piv.AlgorithmRSA2048
	default:
		return piv.Key{}, fmt.Errorf("unsupported algo for key type: %s", purpose.Algo())
	}

	switch purpose.TouchPolicy() {
	case keyring.TouchPolicyNever:
		out.TouchPolicy = piv.TouchPolicyNever
	case keyring.TouchPolicyCached:
		out.TouchPolicy = piv.TouchPolicyCached
	case keyring.TouchPolicyAlways:
		out.TouchPolicy = piv.TouchPolicyAlways
	default:
		return piv.Key{}, fmt.Errorf("unsupported touch policy for key type: %s", purpose.TouchPolicy())
	}

	return out, nil
}

// pinstoreOpts returns the pinstore options for a PIN request for the
// Yubikey with the given serial: the serial (in decimal) and the prompt text.
func pinstoreOpts(serial uint32) []pinstore.Option {
	// FormatUint covers the whole uint32 range; strconv.Itoa(int(serial))
	// would go negative on 32-bit platforms for serials above math.MaxInt32.
	serialStr := strconv.FormatUint(uint64(serial), 10)
	return []pinstore.Option{
		pinstore.WithSerial(serialStr),
		pinstore.WithDescription("Please enter the PIN to unlock Yubikey #" + serialStr),
	}
}
