package syncer

import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"net"
	"path/filepath"
	"strings"

	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/security/libs/go/ioatomic"
	"a.yandex-team.ru/security/skotty/robossh/internal/issuer"
	"a.yandex-team.ru/security/skotty/robossh/internal/logger"
)

// CertSyncer obtains certificates from an issuer and delivers them either to
// an ssh-agent (SyncAgent) or to a directory on disk (SyncDir).
type CertSyncer struct {
	// issuer is the source of certificates queried by IssueCertificates.
	issuer issuer.Issuer
}

// AgentCert pairs a key as listed by the ssh-agent with the certificate parsed
// from that key's blob.
type AgentCert struct {
	// Key is the raw agent entry; embedding it promotes the methods of *agent.Key.
	*agent.Key
	// Cert is the certificate decoded from Key.Blob.
	Cert *ssh.Certificate
}

// NewCertSyncer returns a CertSyncer that obtains its certificates from issuer.
func NewCertSyncer(issuer issuer.Issuer) *CertSyncer {
	syncer := new(CertSyncer)
	syncer.issuer = issuer
	return syncer
}

// IssueCertificates asks the configured issuer for certificates and returns
// them once every entry has been verified to be complete.
//
// An error from the issuer is returned unchanged. If any entry lacks a
// certificate or a private key, nothing is returned and the error names the
// comment of the first offending entry.
func (c *CertSyncer) IssueCertificates() ([]agent.AddedKey, error) {
	logger.Info("issue certificates")
	defer logger.Info("issued")

	issued, err := c.issuer.Certificates(context.Background())
	if err != nil {
		return nil, err
	}

	// Reject the whole batch on the first incomplete entry; the certificate
	// is checked before the private key.
	for i := range issued {
		entry := &issued[i]
		switch {
		case entry.Certificate == nil:
			return nil, fmt.Errorf("invalid certificate %q: no cert", entry.Comment)
		case entry.PrivateKey == nil:
			return nil, fmt.Errorf("invalid certificate %q: no priv", entry.Comment)
		}
	}

	return issued, nil
}

// SyncAgent adds keys to the ssh-agent listening on the unix socket socketPath
// and then removes the skotty certificates that were present in the agent
// before the call and whose principals equal those of keys[0].
//
// With no keys it is a no-op. Every key must carry a certificate; this is
// verified before the agent is contacted, so a malformed key produces an error
// rather than a nil dereference after the agent has already been modified.
//
// Failing to connect, list or add returns an error. Failing to remove an old
// certificate is only logged, and processing continues with the next one.
func (c *CertSyncer) SyncAgent(socketPath string, keys ...agent.AddedKey) error {
	logger.Info("sync to authentication agent", log.String("path", socketPath))
	defer logger.Info("synced")

	if len(keys) == 0 {
		return nil
	}

	// agent.AddedKey.Certificate is optional, but this method reads KeyId and
	// ValidPrincipals from it, so reject incomplete keys up front.
	for _, key := range keys {
		if key.Certificate == nil {
			return fmt.Errorf("invalid certificate %q: no cert", key.Comment)
		}
	}

	conn, err := net.Dial("unix", socketPath)
	if err != nil {
		return fmt.Errorf("failed to communicate with ssh-agent: %w", err)
	}

	defer func() {
		// Nothing useful can be done about a close error at this point.
		_ = conn.Close()
	}()

	agentClient := agent.NewClient(conn)

	// Snapshot the agent content before adding anything, so the removal pass
	// below only considers keys that were there before this call.
	agentKeys, err := agentClient.List()
	if err != nil {
		return fmt.Errorf("can't list agent keys: %w", err)
	}

	for _, cert := range keys {
		if err := agentClient.Add(cert); err != nil {
			return fmt.Errorf("can't add certificate %q: %w", cert.Comment, err)
		}

		logger.Info("certificate added", log.String("key_id", cert.Certificate.KeyId))
	}

	// NOTE(review): only the principals of the first key select what gets
	// removed; presumably all keys of one batch share principals - confirm.
	for _, key := range c.extractSkottyCerts(agentKeys, keys[0].Certificate.ValidPrincipals) {
		if err := agentClient.Remove(key); err != nil {
			// Best effort: keep going, but record why the removal failed.
			logger.Error("can't remove certificate", log.String("key_id", key.Cert.KeyId), log.Error(err))
			continue
		}

		logger.Info("removed old certificate", log.String("key_id", key.Cert.KeyId))
	}
	return nil
}

// SyncDir writes every key into targetDir as a pair of files: the PEM-encoded
// private key at "robossh_<cert type>_<principals joined by '_'>" and the
// certificate, marshalled with ssh.MarshalAuthorizedKey, at the same path with
// a "-cert.pub" suffix. Both files are written through ioatomic.WriteFileBytes
// with mode 0600.
//
// With no keys it is a no-op. Processing stops at the first failure, so files
// of earlier keys may already have been written when an error is returned. A
// key without a certificate is reported as an error instead of causing a nil
// dereference.
func (c *CertSyncer) SyncDir(targetDir string, keys ...agent.AddedKey) error {
	logger.Info("sync to directory agent", log.String("path", targetDir))
	defer logger.Info("synced")

	if len(keys) == 0 {
		return nil
	}

	for _, key := range keys {
		// agent.AddedKey.Certificate is optional, but the file name and the
		// certificate file are both derived from it.
		if key.Certificate == nil {
			return fmt.Errorf("invalid certificate %q: no cert", key.Comment)
		}

		keyID, err := ParseKeyID(key.Certificate.KeyId)
		if err != nil {
			return fmt.Errorf("parse keyid of key %q: %w", key.Comment, err)
		}

		filename := fmt.Sprintf("robossh_%s_%s", keyID.CertType, strings.Join(key.Certificate.ValidPrincipals, "_"))

		// The private key goes first, the certificate next to it afterwards.
		privPath := filepath.Join(targetDir, filename)
		privBytes, err := privToPem(key.PrivateKey)
		if err != nil {
			return fmt.Errorf("marshal private key %q: %w", key.Certificate.KeyId, err)
		}
		if err := ioatomic.WriteFileBytes(privPath, privBytes, 0o600); err != nil {
			return fmt.Errorf("save private key %q: %w", key.Certificate.KeyId, err)
		}

		certPath := filepath.Join(targetDir, filename+"-cert.pub")
		certBytes := ssh.MarshalAuthorizedKey(key.Certificate)
		if err := ioatomic.WriteFileBytes(certPath, certBytes, 0o600); err != nil {
			return fmt.Errorf("save certificate %q: %w", key.Certificate.KeyId, err)
		}

		logger.Info("certificate saved",
			log.String("key_id", key.Certificate.KeyId),
			log.String("priv_path", privPath),
			log.String("cert_path", certPath),
		)
	}

	return nil
}

// extractSkottyCerts filters the agent listing down to skotty certificates
// issued for exactly targetPrincipals.
//
// An entry is kept when its blob parses as an SSH certificate, the KeyId
// starts with "skotty:", and its ValidPrincipals equal targetPrincipals
// element by element in the same order. Blobs that fail to parse are logged
// and skipped. The result is nil when nothing matches.
func (c *CertSyncer) extractSkottyCerts(keys []*agent.Key, targetPrincipals []string) []AgentCert {
	var matched []AgentCert

	for _, agentKey := range keys {
		parsed, err := ssh.ParsePublicKey(agentKey.Blob)
		if err != nil {
			logger.Warn("can't parse key from agent", log.String("key", agentKey.String()), log.Error(err))
			continue
		}

		// Plain keys and certificates issued by something else are ignored.
		cert, isCert := parsed.(*ssh.Certificate)
		if !isCert || !strings.HasPrefix(cert.KeyId, "skotty:") {
			continue
		}

		if len(cert.ValidPrincipals) != len(targetPrincipals) {
			continue
		}

		// Walk the common prefix; stopping before the end means a mismatch.
		same := 0
		for same < len(targetPrincipals) && cert.ValidPrincipals[same] == targetPrincipals[same] {
			same++
		}
		if same != len(targetPrincipals) {
			continue
		}

		matched = append(matched, AgentCert{Key: agentKey, Cert: cert})
	}

	return matched
}

// privToPem serializes privateKey into a single PEM block.
//
// *ecdsa.PrivateKey is encoded as SEC 1 DER under the "EC PRIVATE KEY" type,
// *rsa.PrivateKey as PKCS #1 DER under "RSA PRIVATE KEY". Any other type
// yields an error that names the offending type.
func privToPem(privateKey interface{}) ([]byte, error) {
	var block pem.Block

	// Pick the PEM type and the DER payload; the encoding step is shared.
	switch key := privateKey.(type) {
	case *rsa.PrivateKey:
		block.Type = "RSA PRIVATE KEY"
		block.Bytes = x509.MarshalPKCS1PrivateKey(key)
	case *ecdsa.PrivateKey:
		der, err := x509.MarshalECPrivateKey(key)
		if err != nil {
			return nil, err
		}
		block.Type = "EC PRIVATE KEY"
		block.Bytes = der
	default:
		return nil, fmt.Errorf("unsupported private key type: %T", privateKey)
	}

	out := new(bytes.Buffer)
	err := pem.Encode(out, &block)
	return out.Bytes(), err
}
