package scenario

import (
	"bytes"
	"context"
	"crypto/x509"
	"errors"
	"fmt"
	"os"
	"strings"
	"time"

	"github.com/cenkalti/backoff/v4"
	"golang.org/x/crypto/ssh"

	"a.yandex-team.ru/security/skotty/libs/certutil"
	"a.yandex-team.ru/security/skotty/libs/skotty"
	"a.yandex-team.ru/security/skotty/skotty/internal/config"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring"
)

// Renew groups the steps of the certificate renewal scenario: renewing keys
// on the token, requesting and approving a renewal at the service, issuing the
// new certificates and storing them.
//
// It embeds *Certificates; the methods in this file call LogInfo, LogWait,
// skottyService and pubStore through that embedded value.
type Renew struct {
	*Certificates
}

// NewRenew returns a Renew whose embedded Certificates is built from opts.
func NewRenew(opts ...Option) *Renew {
	r := new(Renew)
	r.Certificates = NewCertificates(opts...)
	return r
}

// RenewCerts calls tx.RenewCertificate for every key purpose from renewKeys
// that the token lists among its supported key types, and passes the collected
// TokenCert values to next.
//
// Purposes the token does not support are skipped without an error. The first
// failed renewal aborts the loop and is returned wrapped; otherwise the result
// of next is returned.
func (e *Renew) RenewCerts(kr keyring.Keyring, next func([]TokenCert) error) error {
	e.LogInfo("renewing private keys if necessary")

	renewed := make([]TokenCert, 0, len(renewKeys))
	renewAll := func(tx keyring.Tx) error {
		// Set of purposes the token supports, for O(1) membership checks below.
		supported := make(map[keyring.KeyPurpose]struct{}, len(supportedKeys))
		for _, purpose := range tx.SupportedKeyTypes() {
			supported[purpose] = struct{}{}
		}

		for _, purpose := range renewKeys {
			if _, known := supported[purpose]; !known {
				continue
			}

			newCert, renewErr := tx.RenewCertificate(purpose)
			if renewErr != nil {
				return fmt.Errorf("can't renew certificate for keyType %q: %w", purpose, renewErr)
			}

			renewed = append(renewed, TokenCert{Cert: newCert, KeyType: purpose})
		}

		return nil
	}

	if err := keyringCall(kr, renewAll); err != nil {
		return err
	}

	return next(renewed)
}

// Request builds a renew request for the token behind kr, sends it to the
// skotty service and passes the response to next.
//
// The identity part of the request is filled from the first source that
// applies:
//  1. info.EnrollInfo, when it has both a user and an enroll ID and its token
//     serial equals the serial reported by kr;
//  2. info.RenewToken, when it is not empty;
//  3. the extra names parsed from a certificate stored on the token (the renew
//     key if the token supports it, the secure key otherwise).
//
// Errors from reading the serial, from reading the certificate and from the
// service call are returned wrapped; otherwise the result of next is returned.
func (e *Renew) Request(kr keyring.Keyring, info RenewInfo, next func(*skotty.RequestRenewRsp) error) error {
	e.LogInfo("request renewing")

	// A failed hostname lookup is tolerated: the error is dropped and "n/a" is
	// sent instead.
	hostname, _ := os.Hostname()
	if hostname = strings.ToLower(hostname); hostname == "" {
		hostname = "n/a"
	}

	serial, err := kr.Serial()
	if err != nil {
		return fmt.Errorf("failed to get token serial: %w", err)
	}

	req := &skotty.RequestRenewReq{Hostname: hostname, TokenSerial: serial}
	req.TokenType = kr.TokenType()
	req.TokenName = fmt.Sprintf("%s on %s", kr.HumanName(), hostname)

	enrolledHere := info.EnrollInfo.User != "" &&
		info.EnrollInfo.EnrollID != "" &&
		info.EnrollInfo.TokenSerial == serial
	hasRenewToken := info.RenewToken != ""

	switch {
	case enrolledHere:
		req.EnrollmentID = info.EnrollInfo.EnrollID
		req.User = info.EnrollInfo.User
	case hasRenewToken:
		req.RenewToken = info.RenewToken
	default:
		// Nothing usable in info: recover user and enroll ID from the token.
		fromToken := func(tx keyring.Tx) error {
			var purpose keyring.KeyPurpose
			if keyring.ContainsKeyPurpose(tx.SupportedKeyTypes(), keyring.KeyPurposeRenew) {
				purpose = keyring.KeyPurposeRenew
			} else {
				purpose = keyring.KeyPurposeSecure
			}

			names, parseErr := e.parseCertExtraNames(tx, purpose)
			if parseErr != nil {
				return parseErr
			}

			req.User = names.User
			req.EnrollmentID = names.EnrollID
			return nil
		}

		if callErr := keyringCall(kr, fromToken); callErr != nil {
			return fmt.Errorf("can't get renew info: %w", callErr)
		}
	}

	rsp, err := e.skottyService().RequestRenewing(context.Background(), req)
	if err != nil {
		return fmt.Errorf("service fail: %w", err)
	}

	return next(rsp)
}

// Approve tries to approve the renew request without user interaction: it
// signs renew.AuthID with the token's renew key and sends the signature,
// together with the renew certificate in PEM form and renew.AuthToken, to the
// service.
//
// It returns a "not supported" error when the token does not list
// keyring.KeyPurposeRenew among its supported key types. Failures to read the
// certificate, obtain the signer or sign are returned wrapped; the error of
// the service call is returned as is.
func (e *Renew) Approve(kr keyring.Keyring, renew *skotty.RequestRenewRsp) error {
	e.LogInfo("try to auto-approve")

	approve := func(tx keyring.Tx) error {
		if supported := keyring.ContainsKeyPurpose(tx.SupportedKeyTypes(), keyring.KeyPurposeRenew); !supported {
			return errors.New("not supported")
		}

		renewCert, err := tx.Certificate(keyring.KeyPurposeRenew)
		if err != nil {
			return fmt.Errorf("can't get renew pub key: %w", err)
		}

		signer, err := tx.Signer(keyring.KeyPurposeRenew)
		if err != nil {
			return fmt.Errorf("can't get renew priv key: %w", err)
		}

		signature, err := certutil.Sign(signer, []byte(renew.AuthID))
		if err != nil {
			return fmt.Errorf("can't sign request: %w", err)
		}

		return e.skottyService().ApproveRenew(
			context.Background(),
			renew.EnrollmentID,
			&skotty.ApproveRenewReq{
				AuthToken: renew.AuthToken,
				Sign:      signature,
				SignKey:   certutil.CertToPem(renewCert),
			},
		)
	}

	return keyringCall(kr, approve)
}

// WaitAndIssue reads the attestation certificate and builds CSRs for req.Certs
// inside a keyring transaction, then asks the service to issue the renewed
// certificates and passes the parsed result to next.
//
// The issue call is repeated with a constant 5 second pause for as long as the
// service answers with skotty.ServiceErrorUnauthorizedRequest; any other
// outcome (success or a different error) ends the loop. No retry limit or
// timeout is set here.
// NOTE(review): presumably "unauthorized" means the request has not been
// approved yet — confirm against the service contract.
func (e *Renew) WaitAndIssue(kr keyring.Keyring, req IssueRenewReq, next func(*IssuedRenew) error) error {
	e.LogWait("wait auth and issue certificates")

	var (
		attestCert *x509.Certificate
		requested  []skotty.RequestedCertificate
	)
	collect := func(tx keyring.Tx) error {
		var err error
		if attestCert, err = tx.AttestationCertificate(); err != nil {
			return err
		}

		requested, err = csrFromTokenCerts(tx, req.Certs)
		return err
	}
	if err := keyringCall(kr, collect); err != nil {
		return err
	}

	var rsp *skotty.IssueRenewRsp
	issue := func() error {
		var err error
		rsp, err = e.skottyService().IssueRenew(context.Background(), req.EnrollID, &skotty.IssueRenewReq{
			AuthToken:       req.AuthToken,
			TokenType:       kr.TokenType(),
			AttestationCert: certutil.CertToPem(attestCert),
			Certificates:    requested,
		})

		// Only an "unauthorized" answer is worth another attempt; returning the
		// plain error lets backoff.Retry schedule it.
		var svcErr *skotty.ServiceError
		if errors.As(err, &svcErr) && svcErr.Code == skotty.ServiceErrorUnauthorizedRequest {
			return err
		}

		// Everything else, success included, stops the retry loop.
		return backoff.Permanent(err)
	}
	if err := backoff.Retry(issue, backoff.NewConstantBackOff(5*time.Second)); err != nil {
		return err
	}

	issuedCerts, err := parseIssuedCerts(req.Certs, rsp.Certificates)
	if err != nil {
		return err
	}

	issued := &IssuedRenew{
		RenewToken: rsp.RenewToken,
		ExpiresAt:  time.Unix(rsp.ExpiresAt, 0),
		Certs:      issuedCerts,
		EnrollInfo: config.EnrollInfo{
			EnrollID:    rsp.EnrollInfo.EnrollID,
			TokenSerial: rsp.EnrollInfo.TokenSerial,
			User:        rsp.EnrollInfo.User,
		},
	}
	return next(issued)
}

// UpdateKeys stores every certificate from certs in the keyring, saves its SSH
// part through e.pubStore, and passes the resulting list of key purposes to
// next.
//
// The list holds the key types of certs in the given order, followed by the
// entries of curKeys that were not updated, in their original order and
// without duplicates. curKeys itself is never modified.
//
// An error is returned when a cert has an empty SSHCert, when storing a
// certificate or its SSH key fails, or when next fails.
func (e *Renew) UpdateKeys(kr keyring.Keyring, curKeys []keyring.KeyPurpose, certs []SignedCert, next func(...keyring.KeyPurpose) error) error {
	e.LogInfo("updating keys")

	// Build the result in a fresh slice. Reusing the backing array of curKeys
	// (curKeys[:0]) would overwrite the caller's data on every append and
	// leave it corrupted, most visibly when an update fails half-way.
	out := make([]keyring.KeyPurpose, 0, len(curKeys)+len(certs))
	err := keyringCall(kr, func(tx keyring.Tx) error {
		// Keys known before the update; entries are removed as they get updated.
		oldKeys := make(map[config.KeyPurpose]struct{}, len(curKeys))
		for _, k := range curKeys {
			oldKeys[k] = struct{}{}
		}

		for _, cert := range certs {
			if len(cert.SSHCert) == 0 {
				return fmt.Errorf("no ssh pub cert for keytype: %s", cert.KeyType)
			}

			if err := tx.SetCertificate(cert.KeyType, cert.Cert); err != nil {
				return fmt.Errorf("failed to update cert for keytype %q: %w", cert.KeyType, err)
			}

			if err := e.pubStore.SaveKey(tx.Name(), cert.KeyType, cert.SSHCert); err != nil {
				return fmt.Errorf("failed to save ssh pub key for keytype %q: %w", cert.KeyType, err)
			}

			delete(oldKeys, cert.KeyType)
			out = append(out, cert.KeyType)
		}

		// Walk curKeys instead of ranging over the map so the untouched keys
		// come out in a deterministic order; deleting each emitted key keeps
		// duplicates in curKeys from being emitted twice.
		for _, k := range curKeys {
			if _, left := oldKeys[k]; !left {
				continue
			}

			delete(oldKeys, k)
			out = append(out, k)
		}

		return nil
	})
	if err != nil {
		return err
	}

	return next(out...)
}

// RestoreKeys re-creates the stored SSH public keys from the certificates on
// the token: for each purpose in restoredKeys it reads the certificate,
// converts its public key to the authorized_keys line format (surrounding
// whitespace trimmed) and saves it through e.pubStore. On success next is
// called with restoredKeys.
//
// The first failure stops the loop and is returned wrapped.
func (e *Renew) RestoreKeys(kr keyring.Keyring, next func(...keyring.KeyPurpose) error) error {
	restoreOne := func(tx keyring.Tx, purpose keyring.KeyPurpose) error {
		stored, err := tx.Certificate(purpose)
		if err != nil {
			return fmt.Errorf("unable to restore key %q: %w", purpose, err)
		}

		pubKey, err := ssh.NewPublicKey(stored.PublicKey)
		if err != nil {
			return fmt.Errorf("failed to parse pub key for keytype %q: %w", purpose, err)
		}

		authorizedKey := bytes.TrimSpace(ssh.MarshalAuthorizedKey(pubKey))
		if err := e.pubStore.SaveKey(tx.Name(), purpose, authorizedKey); err != nil {
			return fmt.Errorf("failed to save ssh pub key for keytype %q: %w", purpose, err)
		}

		return nil
	}

	restoreAll := func(tx keyring.Tx) error {
		for _, purpose := range restoredKeys {
			if err := restoreOne(tx, purpose); err != nil {
				return err
			}
		}

		return nil
	}

	if err := keyringCall(kr, restoreAll); err != nil {
		return err
	}

	return next(restoredKeys...)
}

// parseCertExtraNames reads the certificate stored for keyPurpose from tx and
// returns the result of certutil.ParseExtraNames applied to its subject names.
// A failure to read the certificate is returned wrapped.
func (e *Renew) parseCertExtraNames(tx keyring.Tx, keyPurpose keyring.KeyPurpose) (*certutil.ExtraName, error) {
	stored, err := tx.Certificate(keyPurpose)
	if err != nil {
		return nil, fmt.Errorf("get certificate for %s: %w", keyPurpose, err)
	}

	subjectNames := stored.Subject.Names
	return certutil.ParseExtraNames(subjectNames)
}
