package agent

import (
	"crypto/sha256"
	"crypto/subtle"
	"errors"
	"fmt"
	"sync"

	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/security/skotty/libs/askpass"
	"a.yandex-team.ru/security/skotty/robossh/internal/agentkey"
	"a.yandex-team.ru/security/skotty/robossh/internal/keystore"
	"a.yandex-team.ru/security/skotty/robossh/internal/logger"
)

// Compile-time assertion that *Handler implements agent.ExtendedAgent.
var _ agent.ExtendedAgent = (*Handler)(nil)

// Sentinel errors returned by Handler methods.
var (
	// ErrNotImplemented is returned by operations the handler does not support.
	ErrNotImplemented = errors.New("not implemented")
	// ErrLocked is returned when an operation is attempted while the agent is locked.
	ErrLocked = errors.New("locked")
	// ErrNotLocked is returned by Unlock when the agent is not locked.
	ErrNotLocked = errors.New("not locked")
)

// Handler is an SSH agent implementation (see the agent.ExtendedAgent
// assertion in this file) backed by a keystore.Store.
type Handler struct {
	// keys holds the identities served by this agent.
	keys           *keystore.Store
	// lifetime is the default key lifetime, in seconds, applied by Add to keys
	// that arrive without their own lifetime; zero means no default.
	lifetime       uint32
	// lockPassphrase is the SHA-256 digest of the passphrase the agent was
	// locked with; nil means the agent is not locked.
	lockPassphrase []byte
	// mu is taken by the handler methods that touch keys or lockPassphrase.
	mu             sync.Mutex
}

// List returns the identities currently held by the agent.
//
// When the agent is locked, a warning is logged and an empty (nil) list is
// returned without an error. Otherwise expired keys are purged from the store
// first, and keys reported as stale are left out of the result.
func (h *Handler) List() ([]*agent.Key, error) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.lockPassphrase != nil {
		logger.Warn("trying to list on locked agent resulted in empty keys list")
		return nil, nil
	}

	removed := h.keys.RemoveExpired()
	if removed > 0 {
		logger.Info("removed expired keys", log.Int("count", removed))
	}

	identities := make([]*agent.Key, 0, h.keys.Len())
	h.keys.Range(func(k *agentkey.Key) bool {
		if !k.IsStale() {
			identities = append(identities, k.AgentKey())
		}

		// The callback never asks to stop: every key in the store is visited.
		return true
	})

	return identities, nil
}

// Sign has the agent sign the data using a protocol 2 key as defined
// in [PROTOCOL.agent] section 2.6.2.
//
// It delegates to SignWithFlags with no signature flags set, so it returns
// whatever signature or error SignWithFlags produces for reqKey and data.
func (h *Handler) Sign(reqKey ssh.PublicKey, data []byte) (*ssh.Signature, error) {
	return h.SignWithFlags(reqKey, data, 0)
}

// SignWithFlags signs like Sign, but allows for additional flags to be sent/received.
//
// The key is looked up in the store by the SHA-256 fingerprint of reqKey.
// It returns ErrLocked while the agent is locked, an "unknown key" error when
// no key with that fingerprint is stored, and otherwise whatever the
// confirmation/signing step returns. Every failure after the lock check is
// logged together with its cause.
//
// NOTE(review): the handler mutex is held for the whole call, including the
// interactive confirmation that h.sign may trigger, so other agent requests
// wait until the user answers.
// NOTE(review): expired keys are purged explicitly only in List; confirm that
// keystore.Store.Get does not hand out expired keys.
func (h *Handler) SignWithFlags(reqKey ssh.PublicKey, data []byte, flags agent.SignatureFlags) (*ssh.Signature, error) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.lockPassphrase != nil {
		return nil, ErrLocked
	}

	// doSign groups the lookup and signing steps so that all of their
	// failures are reported through the single error log below.
	doSign := func(fp string) (*ssh.Signature, error) {
		key, ok := h.keys.Get(fp)
		if !ok {
			return nil, fmt.Errorf("unknown key: %v", fp)
		}

		out, err := h.sign(key, data, flags)
		if err != nil {
			return nil, err
		}

		logger.Info("signed request", log.String("fingerprint", fp), log.String("key_name", key.Name()))
		return out, nil
	}

	fp := ssh.FingerprintSHA256(reqKey)
	out, err := doSign(fp)
	if err != nil {
		// Include the cause: without it an unknown key, a denied confirmation
		// and a signer failure are indistinguishable in the logs.
		logger.Error("sign request failed", log.String("fingerprint", fp), log.Error(err))
		return nil, err
	}

	return out, nil
}

// Extension handles a custom extension request (see [PROTOCOL.agent] section 4.7).
//
// This handler supports no extensions: the extension name and payload are
// ignored. It returns ErrLocked while the agent is locked and
// agent.ErrExtensionUnsupported otherwise; the returned payload is always nil.
func (h *Handler) Extension(_ string, _ []byte) ([]byte, error) {
	h.mu.Lock()
	defer h.mu.Unlock()

	err := agent.ErrExtensionUnsupported
	if h.lockPassphrase != nil {
		err = ErrLocked
	}

	return nil, err
}

// Signers is part of the agent interface and is meant to return signers for
// all the known keys. This handler does not support it: it always returns a
// nil slice and ErrNotImplemented.
func (h *Handler) Signers() ([]ssh.Signer, error) {
	return nil, ErrNotImplemented
}

// Add stores a private key in the agent.
//
// It returns ErrLocked while the agent is locked, and a wrapped error when the
// key cannot be converted by agentkey.NewKey. A key that the store reports as
// already present is skipped without an error; both outcomes are logged.
func (h *Handler) Add(newKey agent.AddedKey) error {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.lockPassphrase != nil {
		return ErrLocked
	}

	// Keys that come without a lifetime inherit the handler default, if one is set.
	if newKey.LifetimeSecs == 0 && h.lifetime != 0 {
		newKey.LifetimeSecs = h.lifetime
	}

	key, err := agentkey.NewKey(newKey)
	if err != nil {
		return fmt.Errorf("can't parse added ssh key: %w", err)
	}

	msg := "key added"
	if added := h.keys.Add(key); !added {
		msg = "skipped adding a key: already exists"
	}

	logger.Info(msg,
		log.String("name", key.Name()),
		log.String("fingerprint", key.Fingerprint()),
	)
	return nil
}

// Remove deletes the identity matching the given public key.
//
// The key is addressed by its SHA-256 fingerprint. It returns ErrLocked while
// the agent is locked and an "unknown key" error when the store holds no key
// with that fingerprint.
func (h *Handler) Remove(reqKey ssh.PublicKey) error {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.lockPassphrase != nil {
		return ErrLocked
	}

	fp := ssh.FingerprintSHA256(reqKey)
	if removed := h.keys.Remove(fp); removed {
		logger.Info("key removed",
			log.String("fingerprint", fp),
		)

		return nil
	}

	return fmt.Errorf("unknown key: %v", fp)
}

// RemoveAll drops every identity from the agent.
// It returns ErrLocked, leaving the store untouched, while the agent is locked.
func (h *Handler) RemoveAll() error {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.lockPassphrase == nil {
		h.keys.RemoveAll()
		return nil
	}

	return ErrLocked
}

// Lock locks the agent with the given passphrase.
//
// While locked, Sign, Add, Remove, RemoveAll and Extension fail with ErrLocked
// and List returns an empty list. Only a digest of the passphrase is kept.
// Locking an already locked agent returns ErrLocked.
func (h *Handler) Lock(passphrase []byte) error {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.lockPassphrase != nil {
		return ErrLocked
	}

	digest := encodePassphrase(passphrase)
	h.lockPassphrase = digest

	logger.Info("agent locked")
	return nil
}

// Unlock undoes the effect of Lock.
//
// It returns ErrNotLocked when the agent is not locked, and an
// "incorrect passphrase" error when the digest of passphrase does not match
// the one stored by Lock. Digests are compared in constant time.
func (h *Handler) Unlock(passphrase []byte) error {
	h.mu.Lock()
	defer h.mu.Unlock()

	if h.lockPassphrase == nil {
		return ErrNotLocked
	}

	candidate := encodePassphrase(passphrase)
	if match := subtle.ConstantTimeCompare(candidate, h.lockPassphrase); match != 1 {
		return errors.New("incorrect passphrase")
	}

	h.lockPassphrase = nil
	logger.Info("agent unlocked")
	return nil
}

// sign produces a signature over data with the given key and flags.
//
// Keys that require confirmation first go through an askpass prompt; the
// signature is only produced when the prompt succeeds and the user allows it.
// Keys that do not require confirmation are used directly.
func (h *Handler) sign(key *agentkey.Key, data []byte, flags agent.SignatureFlags) (*ssh.Signature, error) {
	if !key.MustConfirm() {
		return key.Sign(data, flags)
	}

	prompt := fmt.Sprintf("Allow use of key %s?\nKey fingerprint %s.", key.Comment(), key.Fingerprint())
	allowed, err := askpass.NewClient().Confirm(prompt)
	switch {
	case err != nil:
		return nil, fmt.Errorf("unable to confirm key usage: %w", err)
	case !allowed:
		return nil, errors.New("user not allowed to use this key")
	}

	return key.Sign(data, flags)
}

// encodePassphrase returns the 32-byte SHA-256 digest of passphrase.
// A nil passphrase hashes the same as an empty one.
func encodePassphrase(passphrase []byte) []byte {
	digest := sha256.Sum256(passphrase)
	return digest[:]
}
