package softattest

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/md5"
	"crypto/rand"
	"crypto/x509"
	"crypto/x509/pkix"
	_ "embed"
	"encoding/asn1"
	"encoding/hex"
	"errors"
	"fmt"
	"math/big"
	"sync"
	"time"

	"a.yandex-team.ru/security/skotty/libs/certutil"
)

// privKeyBytes is the content of skotty-soft.priv, embedded at build time.
// SharedAttestator passes it to newAttestator, which parses it with
// certutil.PemToECPriv into the key that signs attestation intermediates.
//
//go:embed skotty-soft.priv
var privKeyBytes []byte

// PubKeyBytes is the content of skotty-soft.pub, embedded at build time.
// newAttestator parses it with certutil.PemToCert into Attestator.Public, the
// parent certificate of the attestation intermediates.
//
//go:embed skotty-soft.pub
var PubKeyBytes []byte

// PINPolicy is the PIN policy value that Attest writes, as a single raw byte,
// into the key-policy extension of the certificates it issues.
type PINPolicy uint8

// Supported PIN policies. The zero value is deliberately left unassigned.
const (
	PINPolicyNever  PINPolicy = 1
	PINPolicyOnce   PINPolicy = 2
	PINPolicyAlways PINPolicy = 3
)

// TouchPolicy is the touch policy value that Attest writes, as a single raw
// byte following the PIN policy, into the key-policy extension of the
// certificates it issues.
type TouchPolicy uint8

// Supported touch policies. The zero value is deliberately left unassigned.
//
// NOTE(review): here Cached is 2 and Always is 3, whereas piv-go numbers them
// Always=2, Cached=3; confirm every consumer of the key-policy extension
// decodes these bytes with this package's constants.
const (
	TouchPolicyNever  TouchPolicy = 1
	TouchPolicyCached TouchPolicy = 2
	TouchPolicyAlways TouchPolicy = 3
)

// OIDs of the private extensions that Attest embeds into issued certificates.
var (
	extIDKeyPolicy    = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 31337, 1, 1} // raw PIN/touch policy bytes
	extIDSerialNumber = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 31337, 1, 2} // DER-encoded token serial
)

// attestator is the process-wide instance returned by SharedAttestator;
// attestatorOnce guards its lazy construction.
var (
	attestatorOnce sync.Once
	attestator     *Attestator
)

// Attestator issues software attestation certificates.
//
// Public and private are the certificate and key parsed from the embedded
// key material by newAttestator. initAttestCrt uses them as the parent
// certificate and signing key of the intermediate attestation certificate.
//
// attestPub and attestPriv are that intermediate CA certificate and its
// freshly generated key. initAttestCrt replaces them together. Attest signs
// the certificates it issues with attestPriv, using attestPub as the parent.
//
// username and machineID supply the inputs that TokenSerial hashes into the
// token serial. tokenID caches TokenSerial's result once it is computed.
//
// mu is held by Certificate and Attest while they read or replace attestPub
// and attestPriv. TokenSerial does not take mu itself.
type Attestator struct {
	Public     *x509.Certificate
	private    *ecdsa.PrivateKey
	attestPub  *x509.Certificate
	attestPriv *ecdsa.PrivateKey
	username   func() (string, error)
	machineID  func() (string, error)
	tokenID    string
	mu         sync.Mutex
}

// SharedAttestator returns the process-wide Attestator. It is built from the
// embedded key material on the first call only. Building panics (inside
// newAttestator) if that material cannot be parsed.
func SharedAttestator() *Attestator {
	build := func() {
		attestator = newAttestator(privKeyBytes, PubKeyBytes)
	}
	attestatorOnce.Do(build)

	return attestator
}

func newAttestator(priv, pub []byte) *Attestator {
	privKey, err := certutil.PemToECPriv(priv)
	if err != nil {
		panic(fmt.Sprintf("failed to parse embeded priv key: %v", err))
	}

	pubKey, err := certutil.PemToCert(pub)
	if err != nil {
		panic(fmt.Sprintf("failed to parse embeded pub key: %v", err))
	}

	return &Attestator{
		Public:    pubKey,
		private:   privKey,
		username:  userName,
		machineID: machineID,
	}
}

// TokenSerial returns the identifier of this soft token. The identifier is
// the hex-encoded MD5 digest of the current user name, the fixed string
// "skotty-soft-attest" and the machine ID, concatenated in that order. The
// result is cached in a.tokenID after the first successful call.
//
// NOTE(review): a.tokenID is read and written here without locking. Callers
// such as Attest already hold a.mu, but direct concurrent calls race on the
// cache.
func (a *Attestator) TokenSerial() (string, error) {
	if cached := a.tokenID; cached != "" {
		return cached, nil
	}

	user, err := a.username()
	if err != nil {
		return "", fmt.Errorf("failed to get current user: %w", err)
	}

	machine, err := a.machineID()
	if err != nil {
		return "", fmt.Errorf("failed to get machine ID: %w", err)
	}

	// Hashing the concatenation yields the same digest as writing the three
	// parts to the hash one after another.
	digest := md5.Sum([]byte(user + "skotty-soft-attest" + machine))
	a.tokenID = hex.EncodeToString(digest[:])
	return a.tokenID, nil
}

// Certificate returns the intermediate attestation certificate. Attest uses it
// as the issuer of the certificates it creates.
//
// The certificate and its private key are generated by initAttestCrt on the
// first call. They are regenerated whenever the cached certificate has one
// hour or less of validity left. Certificates issued by Attest are valid for
// one hour, so this helps keep them from outliving their issuer.
func (a *Attestator) Certificate() (*x509.Certificate, error) {
	a.mu.Lock()
	defer a.mu.Unlock()

	// time.Until is positive while NotAfter lies in the future. Requiring more
	// than an hour of remaining validity forces a refresh well before expiry.
	if a.attestPub != nil && time.Until(a.attestPub.NotAfter) > time.Hour {
		return a.attestPub, nil
	}

	if err := a.initAttestCrt(); err != nil {
		return nil, err
	}

	return a.attestPub, nil
}

// Attest issues an attestation certificate for crt's public key. The
// certificate is signed by the intermediate attestation key (see Certificate).
//
// The issued certificate carries:
//   - crt's serial number, in decimal, in its subject SerialNumber;
//   - the requested PIN and touch policies as two raw bytes in the
//     extIDKeyPolicy extension;
//   - the DER-encoded token serial in the extIDSerialNumber extension.
//
// It is valid for one hour, backdated by ten minutes to tolerate clock skew.
//
// An error is returned if crt is nil, if the attestation certificate or token
// serial cannot be obtained, or if issuing the certificate fails.
func (a *Attestator) Attest(crt *x509.Certificate, pinPolicy PINPolicy, touchPolicy TouchPolicy) (*x509.Certificate, error) {
	if crt == nil {
		return nil, errors.New("certificate to attest is nil")
	}

	// Certificate acquires a.mu itself, so refresh the intermediate before
	// taking the lock below (sync.Mutex is not reentrant).
	if _, err := a.Certificate(); err != nil {
		return nil, fmt.Errorf("failed to get attestation certificate: %w", err)
	}

	a.mu.Lock()
	defer a.mu.Unlock()

	tokenSerial, err := a.TokenSerial()
	if err != nil {
		return nil, fmt.Errorf("failed to get token serial: %w", err)
	}

	asnTokenSerial, err := asn1.Marshal(tokenSerial)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal token serial: %w", err)
	}

	now := time.Now()
	serial := big.NewInt(now.UnixNano())
	// Backdate NotBefore so verifiers whose clocks lag slightly behind still
	// accept the certificate immediately after issuance.
	notBefore := now.Add(-10 * time.Minute)
	tmpl := x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			CommonName:         "Skotty Self Cert Attest",
			SerialNumber:       crt.SerialNumber.String(),
			Country:            []string{"RU"},
			Province:           []string{"Moscow"},
			Locality:           []string{"Moscow"},
			Organization:       []string{"Yandex"},
			OrganizationalUnit: []string{"Infra"},
		},
		Issuer:    a.attestPub.Subject,
		NotBefore: notBefore,
		NotAfter:  now.Add(1 * time.Hour),
		KeyUsage:  x509.KeyUsageDigitalSignature,
		ExtraExtensions: []pkix.Extension{
			{
				// Two raw bytes (not DER): PIN policy, then touch policy.
				Id:    extIDKeyPolicy,
				Value: []byte{byte(pinPolicy), byte(touchPolicy)},
			},
			{
				Id:    extIDSerialNumber,
				Value: asnTokenSerial,
			},
		},
	}

	cert, err := x509.CreateCertificate(rand.Reader, &tmpl, a.attestPub, crt.PublicKey, a.attestPriv)
	if err != nil {
		return nil, fmt.Errorf("failed to issue attest certificate: %w", err)
	}

	attestPub, err := x509.ParseCertificate(cert)
	if err != nil {
		return nil, fmt.Errorf("failed to parse issued attest certificate: %w", err)
	}

	return attestPub, nil
}

// initAttestCrt generates a new P-256 attestation key and issues a CA
// certificate for it. The certificate is signed by the embedded root key
// (a.private), with a.Public as the parent certificate. Its serial number is
// the token serial parsed as a hexadecimal integer. It is valid for twelve
// hours, backdated by ten minutes to tolerate clock skew.
//
// On success it replaces a.attestPub and a.attestPriv together. It performs
// no locking itself; Certificate calls it with a.mu held.
//
// NOTE(review): every regeneration reuses the same serial number under the
// same issuer, while RFC 5280 expects serials to be unique per issuer. Confirm
// that verifiers rely on serial == token serial before changing this.
func (a *Attestator) initAttestCrt() error {
	tokenSerial, err := a.TokenSerial()
	if err != nil {
		return fmt.Errorf("failed to get token serial: %w", err)
	}

	serial := big.NewInt(0)
	if _, ok := serial.SetString(tokenSerial, 16); !ok {
		return errors.New("can't setup soft serial")
	}

	attestPriv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return fmt.Errorf("failed to generate attestation private key: %w", err)
	}

	now := time.Now()
	// Backdate NotBefore so the intermediate is valid immediately, even for
	// verifiers whose clocks lag slightly behind.
	notBefore := now.Add(-10 * time.Minute)
	tmpl := x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			CommonName:         fmt.Sprintf("Skotty Self Attest for %s", tokenSerial),
			SerialNumber:       serial.String(),
			Country:            []string{"RU"},
			Province:           []string{"Moscow"},
			Locality:           []string{"Moscow"},
			Organization:       []string{"Yandex"},
			OrganizationalUnit: []string{"Infra"},
		},
		Issuer:                a.Public.Subject,
		NotBefore:             notBefore,
		NotAfter:              now.Add(12 * time.Hour),
		IsCA:                  true,
		BasicConstraintsValid: true,
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
	}

	cert, err := x509.CreateCertificate(rand.Reader, &tmpl, a.Public, &attestPriv.PublicKey, a.private)
	if err != nil {
		return fmt.Errorf("failed to issue attestation certificate: %w", err)
	}

	attestPub, err := x509.ParseCertificate(cert)
	if err != nil {
		return fmt.Errorf("failed to parse issued attestation certificate: %w", err)
	}

	a.attestPriv = attestPriv
	a.attestPub = attestPub
	return nil
}
