package agent

import (
	"context"
	"crypto/sha256"
	"crypto/subtle"
	"errors"
	"fmt"
	"math"
	"os"
	"sort"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
	"google.golang.org/protobuf/proto"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/ctxlog"
	"a.yandex-team.ru/security/skotty/skotty/internal/agent/agentkey"
	"a.yandex-team.ru/security/skotty/skotty/internal/config"
	"a.yandex-team.ru/security/skotty/skotty/internal/confirm"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring"
	"a.yandex-team.ru/security/skotty/skotty/internal/ui"
	"a.yandex-team.ru/security/skotty/skotty/pkg/psudo"
)

// Agent-wide constants.
//
// SudoExtensionName and SudoCheckExtensionName are the extension types handled
// by RealAgent.Extension: the former requests a signature made with the sudo
// key, the latter only probes whether a sudo key is available.
//
// NotifyPostpone is the delay passed to the notifier when scheduling the
// "Waiting for touch..." notification. NotifyTTL is the TTL passed with the
// "signed" usage notifications, ErrNotifyTTL the TTL passed with warning
// notifications. AuthErrBreak is how long signing stays blocked after a
// keyring authentication error.
const (
	SudoExtensionName      = "skotty-sudo"
	SudoCheckExtensionName = "skotty-sudo-check"
	NotifyPostpone         = 500 * time.Millisecond
	NotifyTTL              = 5 * time.Second
	ErrNotifyTTL           = 10 * time.Second
	AuthErrBreak           = 5 * time.Second
)

// Compile-time assertion that *RealAgent implements the Agent interface.
var _ Agent = (*RealAgent)(nil)

// RealAgent is an Agent implementation that serves both keys loaded from the
// keyring store and keys added at runtime through Add.
//
// It embeds a listener configured in newRealAgent with the agent itself as the
// handler. sshKeys holds the identities exposed by List and used for signing,
// kept sorted according to keysOrder; sudoKey is used only by the sudo
// extension. lockPassphrase is nil while the agent is unlocked and holds the
// encoded passphrase while it is locked. lastAuthErrTime records the most
// recent keyring authentication failure and is used by sshSign to temporarily
// block further signing.
//
// mu is taken by the exported methods that read or modify the key set and the
// lock state.
type RealAgent struct {
	listener
	notifyUsage     bool
	notifier        ui.Notifier
	confirm         confirm.Confirmator
	keysOrder       []agentkey.KeyKind
	conf            config.Socket
	store           *KeyHolder
	sshKeys         []agentkey.Key
	sudoKey         agentkey.Key
	lockPassphrase  []byte
	mu              sync.Mutex
	log             log.Logger
	lastAuthErrTime time.Time
}

// RealAgentParams bundles the dependencies consumed by newRealAgent: the key
// store, the confirmation and notification UIs, the socket configuration the
// agent is built for, and the base logger (which newRealAgent scopes with the
// socket name).
type RealAgentParams struct {
	store   *KeyHolder
	confirm confirm.Confirmator
	notify  ui.Notifier
	sock    config.Socket
	logger  log.Logger
}

// newRealAgent builds a RealAgent for the socket described by params.sock and
// performs the initial load of the keyring-backed keys.
//
// The key ordering is taken from the socket configuration; when that is empty
// the built-in order (added, insecure, legacy, secure) is used. The agent is
// returned even when the initial key load fails, together with the load error.
func newRealAgent(params RealAgentParams) (*RealAgent, error) {
	sock := params.sock
	logger := params.logger.WithName(sock.Name)

	// Start from the default ordering and let the config override it.
	order := []agentkey.KeyKind{
		agentkey.KeyKindAdded,
		agentkey.KeyKindInsecure,
		agentkey.KeyKindLegacy,
		agentkey.KeyKindSecure,
	}
	if len(sock.KeysOrder) > 0 {
		order = sock.KeysOrder
	}

	ra := &RealAgent{
		notifyUsage: sock.NotifyUsage,
		notifier:    params.notify,
		confirm:     params.confirm,
		keysOrder:   order,
		conf:        sock,
		store:       params.store,
		log:         logger,
	}

	// The listener needs the agent as its handler, so it is wired up after
	// the agent value exists.
	ra.listener = listener{
		uid:     os.Getuid(),
		kind:    sock.Kind,
		addr:    sock.Path,
		handler: ra,
		log:     logger,
	}

	err := ra.reloadPersistentKeys()
	return ra, err
}

// Name returns the name from the socket configuration this agent was created with.
func (a *RealAgent) Name() string {
	return a.conf.Name
}

// ReloadKeys drops every persistent key from the agent and loads the
// persistent keys again from the store. Non-persistent keys are kept as they
// are. A reload failure is logged, not returned.
func (a *RealAgent) ReloadKeys() {
	a.mu.Lock()
	defer a.mu.Unlock()

	a.log.Info("reload keys")

	// Compact the slice in place, keeping only the non-persistent keys.
	kept := a.sshKeys[:0]
	for _, k := range a.sshKeys {
		if !k.Persistent() {
			kept = append(kept, k)
			continue
		}

		a.log.Info("remove key", log.String("name", k.Name()), log.String("fingerprint", k.Fingerprint()))
	}
	a.sshKeys = kept

	// With the stale persistent keys gone, load the current ones.
	err := a.reloadPersistentKeys()
	if err == nil {
		a.log.Info("keys reloaded")
		return
	}

	a.log.Error("failed to reload agent keys", log.Error(err))
}

// reloadPersistentKeys loads every key purpose listed in the socket
// configuration from the keyring store and registers it with the agent.
//
// The sudo key is reset first; at most one sudo key may be configured. Renew
// keys are loaded but not registered, and every other purpose is appended to
// the ssh key list. The list is re-sorted once all keys are loaded. The first
// load failure, or a second sudo key, aborts with an error; keys registered
// before that point stay registered.
func (a *RealAgent) reloadPersistentKeys() error {
	a.sudoKey = nil

	mode := a.conf.Confirm
	needConfirm := mode == config.ConfirmKindAny || mode == config.ConfirmKindKeyring

	for _, purpose := range a.conf.Keys {
		loaded, err := agentkey.NewKeyringKey(a.store.Keyring, a.store.PubStore, purpose, needConfirm)
		if err != nil {
			return fmt.Errorf("failed to load %q key: %w", purpose, err)
		}

		if purpose == keyring.KeyPurposeRenew {
			// not needed
			continue
		}

		if purpose != keyring.KeyPurposeSudo {
			a.log.Info("adding key",
				log.String("name", loaded.Name()),
				log.String("fingerprint", loaded.Fingerprint()),
			)
			a.sshKeys = append(a.sshKeys, loaded)
			continue
		}

		if a.sudoKey != nil {
			return fmt.Errorf("duplicate sudo key: %q, %q", a.sudoKey.Name(), purpose)
		}

		a.log.Info("adding sudo key",
			log.String("name", loaded.Name()),
			log.String("fingerprint", loaded.Fingerprint()),
		)
		a.sudoKey = loaded
	}

	a.reorderKeys()
	return nil
}

// reorderKeys stably sorts the ssh keys by the position of their kind in
// keysOrder. Kinds that are not listed in keysOrder sort after all listed
// ones; keys of the same kind keep their relative order.
func (a *RealAgent) reorderKeys() {
	rank := make(map[agentkey.KeyKind]int, len(a.keysOrder))
	for pos, kind := range a.keysOrder {
		rank[kind] = pos * 1000
	}

	weight := func(k agentkey.Key) int {
		if w, known := rank[k.Kind()]; known {
			return w
		}
		return math.MaxInt32
	}

	sort.SliceStable(a.sshKeys, func(l, r int) bool {
		return weight(a.sshKeys[l]) < weight(a.sshKeys[r])
	})
}

// List returns the identities known to the agent, in their current order.
//
// A locked agent reports no identities (nil, nil). Expired certificates are
// still listed, but each one is logged; if any expired key is one of our own,
// a single warning notification is sent for the last such key encountered.
func (a *RealAgent) List(ctx context.Context) ([]*agent.Key, error) {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.lockPassphrase != nil {
		ctxlog.Warn(ctx, a.log, "trying to list on locked agent")
		return nil, nil
	}

	var expiredWarning string
	identities := make([]*agent.Key, 0, len(a.sshKeys))
	for _, k := range a.sshKeys {
		if k.IsExpired() {
			expiredAt := k.ValidBefore().Format(time.RFC822)
			ctxlog.Warn(ctx, a.log, "listed expired certificate",
				log.String("name", k.Name()),
				log.String("fingerprint", k.Fingerprint()),
				log.String("valid_before", expiredAt),
			)

			if k.IsOurKey() {
				expiredWarning = fmt.Sprintf("certificate %q expired at %s, you must to renew them", k.Name(), expiredAt)
			}
		}

		identities = append(identities, k.AgentKey())
	}

	if expiredWarning != "" {
		_ = a.notifier.NotifyAndForget(ui.NotificationKindWarning, expiredWarning, ErrNotifyTTL)
	}

	return identities, nil
}

// Sign has the agent sign the data using a protocol 2 key as defined
// in [PROTOCOL.agent] section 2.6.2.
//
// It is a shorthand for SignWithFlags with no signature flags set.
func (a *RealAgent) Sign(ctx context.Context, reqKey ssh.PublicKey, data []byte) (*ssh.Signature, error) {
	return a.SignWithFlags(ctx, reqKey, data, 0)
}

// SignWithFlags signs like Sign, but allows for additional flags to be sent/received.
//
// The key is looked up by the legacy MD5 fingerprint of reqKey. Signing is
// refused while the agent is locked and for unknown keys. A signing failure is
// reported to the user as a warning notification (using the wrapped error
// text when there is one) and returned unchanged; a successful signature
// triggers a usage notification when the socket is configured for it.
func (a *RealAgent) SignWithFlags(ctx context.Context, reqKey ssh.PublicKey, data []byte, flags agent.SignatureFlags) (*ssh.Signature, error) {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.lockPassphrase != nil {
		return nil, errors.New(".Sign method is not allowed on agent locked")
	}

	wantFP := ssh.FingerprintLegacyMD5(reqKey)

	// Locate the first key with a matching fingerprint.
	idx := -1
	for i := range a.sshKeys {
		if a.sshKeys[i].LegacyFingerprint() == wantFP {
			idx = i
			break
		}
	}
	if idx < 0 {
		return nil, fmt.Errorf("unknown key: %v", wantFP)
	}
	signer := a.sshKeys[idx]

	sig, err := a.sshSign(ctx, signer, data, flags)
	if err == nil {
		if a.notifyUsage {
			_ = a.notifier.NotifyAndForget(ui.NotificationKindAlert, makeSignedNotifyMsg(ctx, signer.Name()), NotifyTTL)
		}
		return sig, nil
	}

	// Prefer the wrapped cause for the user-facing text; it is shorter.
	cause := err
	if inner := errors.Unwrap(err); inner != nil {
		cause = inner
	}
	failMsg := fmt.Sprintf("Key %q sign fail: %s", signer.Name(), cause)

	_ = a.notifier.NotifyAndForget(ui.NotificationKindWarning, failMsg, ErrNotifyTTL)
	return nil, err
}

// Extension processes a custom extension request. Standard-compliant agents are not
// required to support any extensions, but this method allows agents to implement
// vendor-specific methods or add experimental features. See [PROTOCOL.agent] section 4.7.
// If agent extensions are unsupported entirely this method MUST return an
// ErrExtensionUnsupported error. Similarly, if just the specific extensionType in
// the request is unsupported by the agent then ErrExtensionUnsupported MUST be
// returned.
//
// In the case of success, since [PROTOCOL.agent] section 4.7 specifies that the contents
// of the response are unspecified (including the type of the message), the complete
// response will be returned as a []byte slice, including the "type" byte of the message.
//
// Two extension types are supported, and only when a sudo key is configured:
//   - SudoCheckExtensionName: a probe; answers with an empty response.
//   - SudoExtensionName: contents is a serialized psudo.SudoReq carrying a
//     nonce of at least 32 bytes and a non-empty hostname. The nonce is signed
//     with the sudo key and a serialized psudo.SudoRsp is returned.
//
// A failed sudo request is reported to the user as a warning notification; a
// successful one triggers a usage notification naming the requesting host when
// the socket is configured for it.
func (a *RealAgent) Extension(ctx context.Context, extensionType string, contents []byte) ([]byte, error) {
	// These pre-checks are deliberately done without taking the mutex so that
	// probes and unsupported requests never wait behind an in-flight signature.
	// The state they look at is validated again under the mutex before signing.
	if a.sudoKey == nil {
		return nil, agent.ErrExtensionUnsupported
	}

	if extensionType == SudoCheckExtensionName {
		return nil, nil
	}

	if extensionType != SudoExtensionName {
		return nil, agent.ErrExtensionUnsupported
	}

	if a.lockPassphrase != nil {
		return nil, errors.New(".Extension method is not allowed on agent locked")
	}

	// processSudoReq validates and signs one sudo request. It returns the
	// serialized response and the hostname the request was made for.
	processSudoReq := func(contents []byte) ([]byte, string, error) {
		var req psudo.SudoReq
		err := proto.Unmarshal(contents, &req)
		if err != nil {
			return nil, "", fmt.Errorf("parse fail: %w", err)
		}

		if len(req.Nonce) < 32 {
			return nil, "", fmt.Errorf("invalid nonce length: %d < 32", len(req.Nonce))
		}

		if req.Hostname == "" {
			return nil, "", errors.New("no hostname in request")
		}

		a.mu.Lock()
		defer a.mu.Unlock()

		// The agent may have been locked, or the sudo key dropped by a failed
		// reload, while this request was waiting for the mutex.
		if a.lockPassphrase != nil {
			return nil, "", errors.New(".Extension method is not allowed on agent locked")
		}

		sudoKey := a.sudoKey
		if sudoKey == nil {
			return nil, "", agent.ErrExtensionUnsupported
		}

		sign, err := a.sshSign(ctx, sudoKey, req.Nonce, 0)
		if err != nil {
			return nil, "", fmt.Errorf("sign fail: %w", err)
		}

		rsp := psudo.SudoRsp{
			PubKey: sudoKey.AgentKey().Marshal(),
			Signature: &psudo.Signature{
				Format: sign.Format,
				Blob:   sign.Blob,
				Rest:   sign.Rest,
			},
		}

		rspBytes, err := proto.Marshal(&rsp)
		if err != nil {
			return nil, "", fmt.Errorf("failed to marshal response: %w", err)
		}

		ctxlog.Info(ctx, a.log, "signed sudo request", log.String("hostname", req.Hostname))
		// Hand the hostname back so the usage notification can name the host.
		return rspBytes, req.Hostname, nil
	}

	rsp, hostname, err := processSudoReq(contents)
	if err != nil {
		// Prefer the wrapped cause for the user-facing text; it is shorter.
		var msg string
		if nestedErr := errors.Unwrap(err); nestedErr != nil {
			msg = fmt.Sprintf("Sudo request failed: %s", nestedErr)
		} else {
			msg = fmt.Sprintf("Sudo request failed: %s", err)
		}

		_ = a.notifier.NotifyAndForget(ui.NotificationKindWarning, msg, ErrNotifyTTL)
		return nil, err
	}

	if a.notifyUsage {
		_ = a.notifier.NotifyAndForget(ui.NotificationKindAlert, fmt.Sprintf("Signed sudo request for host: %s", hostname), NotifyTTL)
	}
	return rsp, nil
}

// Signers is meant to return signers for all the known keys, but this agent
// does not support it: it always returns a "not implemented" error.
func (a *RealAgent) Signers(_ context.Context) ([]ssh.Signer, error) {
	return nil, errors.New("not implemented")
}

// Add adds a private key to the agent.
//
// When the socket's confirm mode is "any" or "added", the key is forced to
// require confirmation before use. Adding is refused while the agent is
// locked. A key whose legacy fingerprint is already present is skipped
// without error; otherwise the key is appended and the key list re-sorted.
func (a *RealAgent) Add(ctx context.Context, newKey agent.AddedKey) error {
	if mode := a.conf.Confirm; mode == config.ConfirmKindAny || mode == config.ConfirmKindAdded {
		newKey.ConfirmBeforeUse = true
	}

	added, err := agentkey.AddedRawKey(newKey, false)
	if err != nil {
		return fmt.Errorf("can't parse added ssh key: %w", err)
	}

	a.mu.Lock()
	defer a.mu.Unlock()

	if a.lockPassphrase != nil {
		return errors.New(".Add method is not allowed on agent locked")
	}

	legacyFP := added.LegacyFingerprint()
	for _, existing := range a.sshKeys {
		if existing.LegacyFingerprint() != legacyFP {
			continue
		}

		// already exists
		ctxlog.Info(ctx, a.log, "skip key adding: already exists", log.String("name", existing.Name()), log.String("fingerprint", existing.Fingerprint()))
		return nil
	}

	ctxlog.Info(ctx, a.log, "added key",
		log.String("name", added.Name()),
		log.String("fingerprint", added.Fingerprint()),
		log.String("legacy_fingerprint", added.LegacyFingerprint()),
	)

	a.sshKeys = append(a.sshKeys, added)
	a.reorderKeys()
	return nil
}

// Remove removes the first identity whose legacy MD5 fingerprint matches
// reqKey. It fails when the agent is locked or when no such key is held.
//
// The key is cut out with an order-preserving shift rather than a swap with
// the last element: the key order is what List reports, and a swap would
// scramble the relative order of keys of the same kind (the stable sort in
// reorderKeys only restores the order across kinds). The vacated tail slot is
// cleared so the removed key is no longer referenced by the backing array.
func (a *RealAgent) Remove(ctx context.Context, reqKey ssh.PublicKey) error {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.lockPassphrase != nil {
		return errors.New(".Remove method is not allowed on agent locked")
	}

	fp := ssh.FingerprintLegacyMD5(reqKey)
	for i, key := range a.sshKeys {
		if key.LegacyFingerprint() != fp {
			continue
		}

		ctxlog.Info(ctx, a.log, "remove key", log.String("name", key.Name()), log.String("fingerprint", key.Fingerprint()))

		last := len(a.sshKeys) - 1
		copy(a.sshKeys[i:], a.sshKeys[i+1:])
		a.sshKeys[last] = nil
		a.sshKeys = a.sshKeys[:last]

		a.reorderKeys()
		return nil
	}

	return fmt.Errorf("unknown key: %v", fp)
}

// RemoveAll removes all identities. It fails when the agent is locked.
//
// The key slice is dropped entirely instead of being truncated to zero
// length: truncation would leave every removed key (including added private
// keys) referenced by the backing array until the slots happen to be
// overwritten.
func (a *RealAgent) RemoveAll(_ context.Context) error {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.lockPassphrase != nil {
		return errors.New(".RemoveAll method is not allowed on agent locked")
	}

	a.sshKeys = nil
	return nil
}

// Lock locks the agent with the given passphrase. While locked, Sign, Add,
// Remove and RemoveAll fail, and List returns an empty list. Only the encoded
// form of the passphrase is kept. Locking an already locked agent is an error.
func (a *RealAgent) Lock(ctx context.Context, passphrase []byte) error {
	a.mu.Lock()
	defer a.mu.Unlock()

	if alreadyLocked := a.lockPassphrase != nil; alreadyLocked {
		return errors.New("can't lock locked agent")
	}

	digest := encodePassphrase(passphrase)
	a.lockPassphrase = digest

	ctxlog.Info(ctx, a.log, "agent locked")
	return nil
}

// Unlock undoes the effect of Lock. The supplied passphrase is encoded the
// same way Lock encoded the original and the two are compared in constant
// time. Unlocking an agent that is not locked, or with the wrong passphrase,
// is an error.
func (a *RealAgent) Unlock(ctx context.Context, passphrase []byte) error {
	a.mu.Lock()
	defer a.mu.Unlock()

	if a.lockPassphrase == nil {
		return errors.New("can't unlock not locked agent")
	}

	given := encodePassphrase(passphrase)
	if match := subtle.ConstantTimeCompare(given, a.lockPassphrase); match != 1 {
		return errors.New("incorrect passphrase")
	}

	a.lockPassphrase = nil
	ctxlog.Info(ctx, a.log, "agent unlocked")
	return nil
}

// sshSign signs data with key on behalf of the public signing entry points.
// Its callers in this file hold a.mu while calling it.
//
// Steps, in order:
//  1. Refuse to sign while less than AuthErrBreak has passed since the last
//     keyring authentication error; once the window is over the marker is reset.
//  2. If the key asks for confirmation and a confirmator is configured, ask
//     the user and abort on refusal or on a confirmation failure.
//  3. If the key needs a touch, schedule a "Waiting for touch..." notification
//     that is closed when this function returns; a scheduling failure is only
//     logged.
//  4. Sign. A keyring authentication error starts a new AuthErrBreak window.
func (a *RealAgent) sshSign(ctx context.Context, key agentkey.Key, data []byte, flags agent.SignatureFlags) (*ssh.Signature, error) {
	if last := a.lastAuthErrTime; !last.IsZero() {
		elapsed := time.Since(last)
		if elapsed < AuthErrBreak {
			return nil, fmt.Errorf(
				"sign blocked due to frequent authentication errors and will be unblocked after: %.1fs",
				(AuthErrBreak - elapsed).Seconds(),
			)
		}
		a.lastAuthErrTime = time.Time{}
	}

	if key.ConfirmBeforeUse() && a.confirm != nil {
		prompt := fmt.Sprintf("Allow use of key %s?\nKey fingerprint %s.", key.Name(), key.Fingerprint())
		switch allowed, err := a.confirm.Confirm(prompt); {
		case err != nil:
			return nil, fmt.Errorf("unable to confirm key usage: %w", err)
		case !allowed:
			return nil, errors.New("user not allowed to use this key")
		}
	}

	if key.TouchNeeded() {
		pending, schedErr := a.notifier.ScheduleNotify(NotifyPostpone, ui.NotificationKindUserInteraction, "Waiting for touch...")
		if schedErr == nil {
			defer pending.Close()
		} else {
			ctxlog.Warn(ctx, a.log, "failed to schedule notification", log.Error(schedErr))
		}
	}

	sig, err := key.Sign(data, flags)
	if err == nil {
		ctxlog.Info(ctx, a.log, "signed",
			log.String("name", key.Name()),
			log.String("fingerprint", key.Fingerprint()))
		return sig, nil
	}

	// Remember authentication failures so that step 1 can throttle retries.
	var authErr *keyring.AuthError
	if errors.As(err, &authErr) {
		a.lastAuthErrTime = time.Now()
	}

	ctxlog.Warn(ctx, a.log, "sign fail",
		log.String("name", key.Name()),
		log.String("fingerprint", key.Fingerprint()),
		log.Error(err))
	return nil, err
}

// encodePassphrase returns the SHA-256 digest of passphrase. The result is
// always a non-nil 32-byte slice, even for a nil or empty passphrase.
func encodePassphrase(passphrase []byte) []byte {
	digest := sha256.Sum256(passphrase)
	return digest[:]
}

// makeSignedNotifyMsg builds the text of the "signed" usage notification for
// keyName. If the logging fields attached to ctx contain a "peer_name" field
// with a non-empty value, that value is mentioned in the message as the
// requester; only the first "peer_name" field is considered.
func makeSignedNotifyMsg(ctx context.Context, keyName string) string {
	peer := ""
	for _, field := range ctxlog.ContextFields(ctx) {
		if field.Key() == "peer_name" {
			peer = field.String()
			break
		}
	}

	if peer != "" {
		return fmt.Sprintf("Signed request for %s: %s", peer, keyName)
	}

	return fmt.Sprintf("Signed request: %s", keyName)
}
