package agentkey

import (
	"crypto/rand"
	"fmt"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"

	"a.yandex-team.ru/security/skotty/skotty/pkg/sshutil"
)

// Key is a single identity held by the agent: a signer together with the
// metadata the agent protocol attached to it when it was added (comment,
// lifetime, confirmation requirement and, for certificates, validity window).
type Key struct {
	// name is the key comment, or a fingerprint (sshutil.Fingerprint) when the
	// comment is empty.
	name        string
	// fingerprint is ssh.FingerprintSHA256 of the signer's public key.
	fingerprint string
	// stale is a flag toggled via SetStale; guarded by mu.
	stale       bool
	// mustConfirm mirrors agent.AddedKey.ConfirmBeforeUse.
	mustConfirm bool
	// expire is the absolute expiry derived from LifetimeSecs; nil when the key
	// was added without a lifetime. Guarded by mu (replaced by Update).
	expire      *time.Time
	// validAfter/validBefore hold the certificate validity window; they stay
	// zero for plain keys. A zero validBefore disables the expiry check in Sign.
	validAfter  time.Time
	validBefore time.Time
	// signer produces signatures; for certificates it is a certificate signer.
	signer      ssh.Signer
	// agentKey is the public representation returned to agent clients.
	agentKey    *agent.Key
	// mu guards the mutable fields (stale, expire).
	mu          sync.RWMutex
}

// NewKey builds a Key from a key added through the SSH agent protocol.
//
// If key carries a certificate, the returned Key signs with a certificate
// signer and records the certificate's validity window. A certificate that
// never expires (ssh.CertTimeInfinity) leaves ValidBefore at the zero time.
// A positive LifetimeSecs sets an absolute expiry measured from now. The key's
// name is its comment, falling back to a fingerprint when the comment is empty.
//
// It returns an error if the private key or certificate cannot be turned into
// a signer.
func NewKey(key agent.AddedKey) (*Key, error) {
	signer, err := ssh.NewSignerFromKey(key.PrivateKey)
	if err != nil {
		return nil, err
	}

	var validBefore, validAfter time.Time
	if cert := key.Certificate; cert != nil {
		validAfter = time.Unix(int64(cert.ValidAfter), 0)
		// ssh.CertTimeInfinity (2^64-1) marks a certificate that never expires.
		// Converting it, or any value above math.MaxInt64, to int64 wraps to a
		// negative timestamp in 1969, which would make Sign reject the key as
		// expired. Keep the zero time instead; Sign treats it as "no upper bound".
		if before := int64(cert.ValidBefore); before >= 0 {
			validBefore = time.Unix(before, 0)
		}
		signer, err = ssh.NewCertSigner(cert, signer)
		if err != nil {
			return nil, err
		}
	}

	pubKey := signer.PublicKey()
	fingerprint := ssh.FingerprintSHA256(pubKey)

	var expire *time.Time
	if key.LifetimeSecs > 0 {
		t := time.Now().Add(time.Duration(key.LifetimeSecs) * time.Second)
		expire = &t
	}

	name := key.Comment
	if name == "" {
		name = sshutil.Fingerprint(pubKey)
	}

	return &Key{
		name:        name,
		fingerprint: fingerprint,
		mustConfirm: key.ConfirmBeforeUse,
		validAfter:  validAfter,
		validBefore: validBefore,
		signer:      signer,
		expire:      expire,
		agentKey: &agent.Key{
			Format:  pubKey.Type(),
			Blob:    pubKey.Marshal(),
			Comment: key.Comment,
		},
	}, nil
}

// Name reports the display label chosen at construction time: the key comment
// if one was given, otherwise a fingerprint of the public key.
func (k *Key) Name() string {
	label := k.name
	return label
}

// Fingerprint reports the SHA256 fingerprint of the key's public key, as
// computed by ssh.FingerprintSHA256 when the key was created.
func (k *Key) Fingerprint() string {
	fp := k.fingerprint
	return fp
}

// CanExpire reports whether the key has an expiry deadline, i.e. it was added
// with a positive lifetime or had one assigned through Update.
//
// expire is replaced by Update under k.mu, so it must be read under the same
// lock to avoid a data race.
func (k *Key) CanExpire() bool {
	k.mu.RLock()
	defer k.mu.RUnlock()

	return k.expire != nil
}

// IsExpired reports whether the key has an expiry deadline that is already in
// the past. Keys without a deadline never expire.
func (k *Key) IsExpired() bool {
	k.mu.RLock()
	deadline := k.expire
	k.mu.RUnlock()

	if deadline == nil {
		return false
	}
	return time.Now().After(*deadline)
}

// IsStale returns the current value of the stale flag set via SetStale.
func (k *Key) IsStale() bool {
	k.mu.RLock()
	stale := k.stale
	k.mu.RUnlock()
	return stale
}

// SetStale sets the stale flag under the key's write lock.
func (k *Key) SetStale(stale bool) {
	k.mu.Lock()
	k.stale = stale
	k.mu.Unlock()
}

// ValidAfter reports the start of the certificate validity window; it is the
// zero time for keys added without a certificate.
func (k *Key) ValidAfter() time.Time {
	start := k.validAfter
	return start
}

// ValidBefore reports the end of the certificate validity window. The zero time
// means Sign applies no upper bound (e.g. keys added without a certificate).
func (k *Key) ValidBefore() time.Time {
	end := k.validBefore
	return end
}

// AgentKey returns the public agent.Key (format, wire blob, comment) built when
// the key was created. The pointer is shared with the Key itself.
func (k *Key) AgentKey() *agent.Key {
	pub := k.agentKey
	return pub
}

// Comment returns the comment the key was added with; it may be empty (unlike
// Name, which falls back to a fingerprint).
func (k *Key) Comment() string {
	pub := k.agentKey
	return pub.Comment
}

// MustConfirm reports whether the key was added with ConfirmBeforeUse set.
func (k *Key) MustConfirm() bool {
	confirm := k.mustConfirm
	return confirm
}

// Sign signs data with the key's signer.
//
// If the key has a certificate validity end (ValidBefore) that has already
// passed, Sign refuses to sign. With zero flags the signer's default algorithm
// is used. Otherwise flags must be exactly agent.SignatureFlagRsaSha256 or
// agent.SignatureFlagRsaSha512, and the signer must implement
// ssh.AlgorithmSigner. Any other combination yields an error.
func (k *Key) Sign(data []byte, flags agent.SignatureFlags) (*ssh.Signature, error) {
	// A zero validBefore means there is no upper bound to enforce.
	if !k.validBefore.IsZero() && time.Now().After(k.validBefore) {
		return nil, fmt.Errorf("certificate expired at %s, you must renew it", k.validBefore)
	}

	if flags == 0 {
		return k.signer.Sign(rand.Reader, data)
	}

	algorithmSigner, ok := k.signer.(ssh.AlgorithmSigner)
	if !ok {
		return nil, fmt.Errorf("signature does not support non-default signature algorithm: %T", k.signer)
	}

	// Map the agent protocol flag to the SSH signature algorithm name.
	var algorithm string
	switch flags {
	case agent.SignatureFlagRsaSha256:
		algorithm = ssh.KeyAlgoRSASHA256
	case agent.SignatureFlagRsaSha512:
		algorithm = ssh.KeyAlgoRSASHA512
	default:
		return nil, fmt.Errorf("unsupported signature flags: %d", flags)
	}

	return algorithmSigner.SignWithAlgorithm(rand.Reader, data, algorithm)
}

// Update copies the expiry deadline of b onto k. Passing a b without a
// deadline clears k's deadline.
//
// b.expire is guarded by b.mu, so it is snapshotted under b's read lock first.
// k's write lock is then taken separately. The two locks are never held at the
// same time, which rules out lock-order deadlocks and keeps k.Update(k) safe.
func (k *Key) Update(b *Key) {
	b.mu.RLock()
	expire := b.expire
	b.mu.RUnlock()

	k.mu.Lock()
	defer k.mu.Unlock()

	k.expire = expire
}
