package signer

import (
	"bytes"
	"crypto/ecdsa"
	"crypto/rand"
	"crypto/x509"
	"crypto/x509/pkix"
	"errors"
	"fmt"
	"time"

	"golang.org/x/crypto/ssh"

	"a.yandex-team.ru/security/skotty/libs/certutil"
	"a.yandex-team.ru/security/skotty/libs/skotty"
	"a.yandex-team.ru/security/skotty/service/internal/config"
)

// certDrift is a grace period that CAStorage.Verify applies past a
// certificate's NotAfter: a certificate is still accepted until
// NotAfter+certDrift. It is a temporary workaround for invalid renewed
// certificates and is meant to be removed (see the TODO in Verify).
const certDrift = 7 * 24 * time.Hour

// certificate holds the parsed key material of a single CA generation
// (current, next or prev), as returned by parseCertificateInfo and consumed
// by parseCA and the build* helpers.
type certificate struct {
	// PublicBytes is the raw encoded CA certificate (presumably PEM or DER —
	// confirm in parseCertificateInfo).
	PublicBytes    []byte
	// Public is the parsed X.509 CA certificate.
	Public         *x509.Certificate
	// SSHPublicBytes is the encoded SSH public key of the CA; buildTrustedSSHPub
	// joins these with '\n'.
	SSHPublicBytes []byte
	// Private is the ECDSA key used to sign issued X.509 certificates.
	Private        *ecdsa.PrivateKey
	// SSHSigner signs SSH user certificates and is collected into KRLSigners.
	SSHSigner      ssh.Signer
	// SSHFingerprint is the fingerprint of the CA's SSH key; for the current
	// generation it is reported as SSHCert.CAFingerprint by Sign.
	SSHFingerprint string
}

// SSHCert is the result of CAStorage.Sign: a freshly issued X.509 certificate
// together with the matching SSH user certificate.
type SSHCert struct {
	// Cert is the issued X.509 client-authentication certificate.
	Cert          *x509.Certificate
	// SSHPub is the issued SSH user certificate, parsed back as an ssh.PublicKey.
	SSHPub        ssh.PublicKey
	// CAFingerprint is the SSH fingerprint of the CA that signed SSHPub.
	CAFingerprint string
}

// CA is the runtime view of one configured certificate authority, built by
// parseCA. Signing fields come from the current key generation; the Trusted*
// and KRLSigners fields combine the current, next and (optional) previous
// generations to support key rotation.
type CA struct {
	// Public is the current CA certificate; it is the parent of issued certs.
	Public           *x509.Certificate
	// Private is the current CA signing key.
	Private          *ecdsa.PrivateKey
	// SSHSigner is the current CA SSH signer used for SSH user certificates.
	SSHSigner        ssh.Signer
	// SSHPubBytes is the encoded SSH public key of the current CA.
	SSHPubBytes      []byte
	// SSHFingerprint is the fingerprint of the current CA SSH key.
	SSHFingerprint   string
	// TrustedVerifiers lists the current, next and, if present, previous CA
	// certificates, in that order.
	TrustedVerifiers []*x509.Certificate
	// TrustedSSHPub is the '\n'-separated SSH public keys of the current, next
	// and, if present, previous CA.
	TrustedSSHPub    []byte
	// KRLSigners holds the SSH signers of the current and, if present, previous
	// CA (the next generation is intentionally excluded by buildKRLSigners).
	KRLSigners       []ssh.Signer
}

// CAStorage indexes the configured certificate authorities by certificate
// type. It has no internal locking; in this file the map is only written by
// NewCAStorage, so it is presumably treated as read-only afterwards.
type CAStorage struct {
	authorities map[skotty.CertType]CA
}

// NewCAStorage parses every configured authority and indexes it by its cert
// type. Construction aborts on the first duplicated type or on the first
// authority that parseCA rejects.
func NewCAStorage(authorities []config.CA) (*CAStorage, error) {
	byType := make(map[skotty.CertType]CA, len(authorities))

	for _, cfg := range authorities {
		if _, dup := byType[cfg.Type]; dup {
			return nil, fmt.Errorf("duplicate ca type: %s", cfg.Type)
		}

		parsed, err := parseCA(cfg)
		if err != nil {
			return nil, fmt.Errorf("failed to parse ca %s key: %w", cfg.Type, err)
		}
		byType[cfg.Type] = parsed
	}

	return &CAStorage{authorities: byType}, nil
}

// Sign issues a certificate pair for csr using the CA registered for
// certType: an X.509 client-authentication certificate and an SSH user
// certificate. Both certify csr.PublicKey, carry the same serial number and
// validity window, and embed the user, token ID and enroll ID from csr.
//
// It returns an error if no CA is configured for certType or if issuing,
// signing or re-parsing either certificate fails.
func (s *CAStorage) Sign(certType skotty.CertType, csr *CertificateRequest) (*SSHCert, error) {
	ca, ok := s.authorities[certType]
	if !ok {
		return nil, fmt.Errorf("unsupported cert type: %s", certType)
	}

	// issueCert creates the X.509 certificate and signs it with the current
	// CA key, then parses the DER back into an *x509.Certificate.
	issueCert := func() (*x509.Certificate, error) {
		crtTemplate := x509.Certificate{
			SerialNumber: NewSerial(),
			Subject: pkix.Name{
				CommonName:         fmt.Sprintf("%s@%s", csr.User, certType),
				Country:            []string{"RU"},
				Province:           []string{"Moscow"},
				Locality:           []string{"Moscow"},
				Organization:       []string{"Yandex"},
				OrganizationalUnit: []string{"Infra"},
				ExtraNames: []pkix.AttributeTypeAndValue{
					{
						Type:  certutil.ExtNameTokenID,
						Value: csr.TokenID,
					},
					{
						Type:  certutil.ExtNameEnrollID,
						Value: csr.EnrollID,
					},
					{
						Type:  certutil.ExtNameUser,
						Value: csr.User,
					},
				},
			},
			PublicKey: csr.PublicKey,
			Issuer:    ca.Public.Subject,
			NotBefore: csr.ValidAfter,
			NotAfter:  csr.ValidBefore,

			// This is an end-entity client certificate. RFC 5280 permits the
			// keyCertSign bit only together with cA=TRUE, so it is not set,
			// and an explicit basicConstraints cA=FALSE keeps lenient
			// verifiers from treating the cert as a possible intermediate CA.
			KeyUsage:              x509.KeyUsageDigitalSignature,
			ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
			BasicConstraintsValid: true,
			IsCA:                  false,
		}

		der, err := x509.CreateCertificate(rand.Reader, &crtTemplate, ca.Public, csr.PublicKey, ca.Private)
		if err != nil {
			return nil, fmt.Errorf("failed to issue certificate: %w", err)
		}

		issued, err := x509.ParseCertificate(der)
		if err != nil {
			return nil, fmt.Errorf("failed to parse issued certificate: %w", err)
		}

		return issued, nil
	}

	// issueSSHCert wraps the X.509 certificate's public key into an SSH user
	// certificate signed by the current CA SSH key.
	issueSSHCert := func(cert *x509.Certificate) (ssh.PublicKey, error) {
		subjectKey, err := ssh.NewPublicKey(cert.PublicKey)
		if err != nil {
			return nil, fmt.Errorf("failed to parse SSH pub key from certificate: %w", err)
		}

		// Only these cert types get the OpenSSH permit-* extensions; any
		// other type gets none (nil map).
		var extensions map[string]string
		switch certType {
		case skotty.CertTypeSecure, skotty.CertTypeInsecure:
			extensions = map[string]string{
				"permit-X11-forwarding":   "",
				"permit-agent-forwarding": "",
				"permit-port-forwarding":  "",
				"permit-pty":              "",
				"permit-user-rc":          "",
			}
		}

		keyID := fmt.Sprintf("skotty:%s:%s:%s:%s", certType, csr.User, csr.TokenID, csr.EnrollID)
		sshCert := ssh.Certificate{
			Key: subjectKey,
			// NOTE(review): big.Int.Uint64 is undefined for values that do not
			// fit in 64 bits; this assumes NewSerial stays below 2^64 — confirm.
			Serial:          cert.SerialNumber.Uint64(),
			CertType:        ssh.UserCert,
			KeyId:           keyID,
			ValidPrincipals: []string{csr.User},
			ValidAfter:      uint64(csr.ValidAfter.Unix()),
			ValidBefore:     uint64(csr.ValidBefore.Unix()),
			Permissions: ssh.Permissions{
				Extensions: extensions,
			},
		}

		if err := sshCert.SignCert(rand.Reader, ca.SSHSigner); err != nil {
			return nil, fmt.Errorf("sign failed: %w", err)
		}

		// Round-trip through the wire format so callers get exactly what
		// ssh.ParsePublicKey produces for the signed certificate.
		issued, err := ssh.ParsePublicKey(sshCert.Marshal())
		if err != nil {
			return nil, fmt.Errorf("failed to parse issued ssh certificate: %w", err)
		}

		return issued, nil
	}

	cert, err := issueCert()
	if err != nil {
		return nil, err
	}

	sshPub, err := issueSSHCert(cert)
	if err != nil {
		return nil, err
	}

	return &SSHCert{
		Cert:          cert,
		SSHPub:        sshPub,
		CAFingerprint: ca.SSHFingerprint,
	}, nil
}

// Verify validates a client certificate and a signature made with it.
//
// It performs these checks in order:
//   - a CA is configured for certType;
//   - the current time is not before cert.NotBefore and not later than
//     cert.NotAfter plus the certDrift grace period;
//   - cert's signature verifies against the current CA certificate;
//   - certutil.Verify accepts sig over msg for cert.
//
// On success it returns the extra names parsed from the certificate subject.
//
// NOTE(review): only ca.Public is used for the issuer check, although
// CA.TrustedVerifiers also lists the next and previous CA certificates, so
// certificates issued by a rotated-out CA are rejected — confirm intended.
func (s *CAStorage) Verify(certType skotty.CertType, cert *x509.Certificate, msg, sig []byte) (*certutil.ExtraName, error) {
	authority, known := s.authorities[certType]
	if !known {
		return nil, fmt.Errorf("unsupported cert type: %s", certType)
	}

	now := time.Now()
	switch {
	case now.Before(cert.NotBefore):
		return nil, fmt.Errorf("current time %s is before %s", now.Format(time.RFC3339), cert.NotBefore.Format(time.RFC3339))
	// TODO(buglloc): remove the certDrift grace period after 2 months.
	case now.Add(-certDrift).After(cert.NotAfter):
		return nil, fmt.Errorf("current time %s is after %s", now.Format(time.RFC3339), cert.NotAfter.Format(time.RFC3339))
	}

	err := authority.Public.CheckSignature(cert.SignatureAlgorithm, cert.RawTBSCertificate, cert.Signature)
	if err != nil {
		return nil, err
	}

	valid, err := certutil.Verify(cert, msg, sig)
	switch {
	case err != nil:
		return nil, err
	case !valid:
		return nil, errors.New("invalid signature")
	}

	return certutil.ParseExtraNames(cert.Subject.Names)
}

// AllCAs returns every configured CA keyed by certificate type.
// The returned map is the storage's internal map, not a copy, so callers
// must treat it as read-only.
func (s *CAStorage) AllCAs() map[skotty.CertType]CA {
	all := s.authorities
	return all
}

// CA looks up the authority configured for certType. The boolean reports
// whether one exists; when it is false the returned CA is the zero value.
func (s *CAStorage) CA(certType skotty.CertType) (CA, bool) {
	if found, exists := s.authorities[certType]; exists {
		return found, true
	}
	return CA{}, false
}

// parseCA loads the current, next and previous key generations of one
// configured CA and assembles the runtime CA. Issuing uses the current
// generation; TrustedVerifiers, TrustedSSHPub and KRLSigners combine the
// generations through the build* helpers.
//
// The current and next generations are mandatory, because the build* helpers
// dereference them unconditionally. The previous generation is optional, and
// the helpers accept a nil prev.
func parseCA(ca config.CA) (CA, error) {
	curCert, err := parseCertificateInfo(ca.Current)
	if err != nil {
		return CA{}, fmt.Errorf("can't parse 'current' cert: %w", err)
	}
	// NOTE(review): the helpers' nil checks on prev suggest parseCertificateInfo
	// may return (nil, nil) for an unset entry; report that as a configuration
	// error instead of panicking on a nil dereference below.
	if curCert == nil {
		return CA{}, errors.New("missing 'current' cert")
	}

	nextCert, err := parseCertificateInfo(ca.Next)
	if err != nil {
		return CA{}, fmt.Errorf("can't parse 'next' cert: %w", err)
	}
	if nextCert == nil {
		return CA{}, errors.New("missing 'next' cert")
	}

	prevCert, err := parseCertificateInfo(ca.Prev)
	if err != nil {
		return CA{}, fmt.Errorf("can't parse 'prev' cert: %w", err)
	}

	return CA{
		Public:           curCert.Public,
		Private:          curCert.Private,
		SSHSigner:        curCert.SSHSigner,
		SSHFingerprint:   curCert.SSHFingerprint,
		SSHPubBytes:      curCert.SSHPublicBytes,
		TrustedVerifiers: buildTrustedVerifiers(curCert, nextCert, prevCert),
		TrustedSSHPub:    buildTrustedSSHPub(curCert, nextCert, prevCert),
		KRLSigners:       buildKRLSigners(curCert, nextCert, prevCert),
	}, nil
}

// buildTrustedVerifiers returns the X.509 certificates of the current and
// next CA generations, followed by the previous one when prev is non-nil.
func buildTrustedVerifiers(cur, next, prev *certificate) []*x509.Certificate {
	verifiers := make([]*x509.Certificate, 0, 3)
	verifiers = append(verifiers, cur.Public, next.Public)
	if prev == nil {
		return verifiers
	}
	return append(verifiers, prev.Public)
}

// buildTrustedSSHPub concatenates the SSH public keys of the current and next
// CA generations, plus the previous one when prev is non-nil, using a single
// '\n' between entries. No trailing separator is added.
func buildTrustedSSHPub(cur, next, prev *certificate) []byte {
	keys := [][]byte{cur.SSHPublicBytes, next.SSHPublicBytes}
	if prev != nil {
		keys = append(keys, prev.SSHPublicBytes)
	}
	return bytes.Join(keys, []byte{'\n'})
}

// buildKRLSigners returns the SSH signers of the current CA generation and,
// when prev is non-nil, the previous one. The next generation (second
// argument) is deliberately not included.
func buildKRLSigners(cur, _, prev *certificate) []ssh.Signer {
	signers := make([]ssh.Signer, 1, 2)
	signers[0] = cur.SSHSigner
	if prev == nil {
		return signers
	}
	return append(signers, prev.SSHSigner)
}
