package scenario

import (
	"context"
	"crypto/x509"
	"errors"
	"fmt"
	"os"
	"strings"
	"time"

	"github.com/cenkalti/backoff/v4"

	"a.yandex-team.ru/security/skotty/libs/certutil"
	"a.yandex-team.ru/security/skotty/libs/skotty"
	"a.yandex-team.ru/security/skotty/skotty/internal/config"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring"
)

// Enroll is the enrollment scenario: it requests an enrollment, waits for it
// to be authorized and issued, and writes the issued keys.
//
// It embeds *Certificates, so the fields and methods of that type are
// promoted onto Enroll.
type Enroll struct {
	*Certificates
}

// NewEnroll returns an Enroll whose embedded Certificates is built by
// NewCertificates from the given options.
func NewEnroll(opts ...Option) *Enroll {
	certs := NewCertificates(opts...)
	return &Enroll{Certificates: certs}
}

// Request sends a RequestEnrollment call describing kr to the skotty service
// and hands the response to next.
//
// The request carries the keyring's serial and token type, the lowercased
// local hostname ("n/a" when it is empty) and a token name of the form
// "<human name> on <hostname>".
//
// It returns an error when the serial cannot be read, when the service call
// fails, or whatever next returns.
func (e *Enroll) Request(kr keyring.Keyring, next func(*skotty.RequestEnrollmentRsp) error) error {
	e.LogInfo("request enrollment")

	// Best effort: the os.Hostname error is intentionally dropped, an empty
	// name simply falls back to the placeholder.
	host := "n/a"
	if name, _ := os.Hostname(); name != "" {
		host = strings.ToLower(name)
	}

	tokenSerial, err := kr.Serial()
	if err != nil {
		return fmt.Errorf("failed to get token serial: %w", err)
	}

	svc := e.skottyService()
	enrollReq := &skotty.RequestEnrollmentReq{
		TokenSerial: tokenSerial,
		Hostname:    host,
		TokenType:   kr.TokenType(),
		TokenName:   fmt.Sprintf("%s on %s", kr.HumanName(), host),
	}

	enrollRsp, err := svc.RequestEnrollment(context.Background(), enrollReq)
	if err != nil {
		return err
	}

	return next(enrollRsp)
}

// WaitAndIssue issues the enrollment identified by req.EnrollID and passes the
// result to next.
//
// Inside a single keyring call it reads the attestation certificate and
// collects the requested certificates for req.Certs via csrFromTokenCerts.
// It then calls IssueEnrollment, retrying at a constant 5 second interval for
// as long as the service answers with ServiceErrorUnauthorizedRequest; any
// other outcome ends the retry loop. The calls use context.Background(), so
// the wait has no deadline or cancellation of its own.
//
// Errors from the keyring call, the service, parseIssuedCerts or next are
// returned unchanged.
func (e *Enroll) WaitAndIssue(kr keyring.Keyring, req IssueEnrollmentReq, next func(*IssuedEnrollment) error) error {
	e.LogWait("wait auth and issue certificates")

	var (
		attestCert *x509.Certificate
		requested  []skotty.RequestedCertificate
	)
	collect := func(tx keyring.Tx) (err error) {
		if attestCert, err = tx.AttestationCertificate(); err != nil {
			return err
		}

		requested, err = csrFromTokenCerts(tx, req.Certs)
		return err
	}
	if err := keyringCall(kr, collect); err != nil {
		return err
	}

	var issued *skotty.IssueEnrollmentRsp
	attempt := func() error {
		var err error
		issued, err = e.skottyService().IssueEnrollment(context.Background(), req.EnrollID, &skotty.IssueEnrollmentReq{
			AuthToken:       req.AuthToken,
			TokenType:       kr.TokenType(),
			AttestationCert: certutil.CertToPem(attestCert),
			Certificates:    requested,
		})

		// Only an "unauthorized" answer is worth another attempt.
		var svcErr *skotty.ServiceError
		unauthorized := errors.As(err, &svcErr) && svcErr.Code == skotty.ServiceErrorUnauthorizedRequest
		if unauthorized {
			return err
		}

		// NOTE(review): a nil err is passed to backoff.Permanent as well; this
		// relies on backoff.Retry reporting that as success — confirm for the
		// vendored backoff version.
		return backoff.Permanent(err)
	}
	if err := backoff.Retry(attempt, backoff.NewConstantBackOff(5*time.Second)); err != nil {
		return err
	}

	certs, err := parseIssuedCerts(req.Certs, issued.Certificates)
	if err != nil {
		return err
	}

	enrollment := &IssuedEnrollment{
		RenewToken:    issued.RenewToken,
		ExpiresAt:     time.Unix(issued.ExpiresAt, 0),
		Certs:         certs,
		StaffUploaded: issued.KeysUpdated,
		StaffErr:      issued.KeysUpdateErr,
		EnrollInfo: config.EnrollInfo{
			EnrollID:    issued.EnrollInfo.EnrollID,
			TokenSerial: issued.EnrollInfo.TokenSerial,
			User:        issued.EnrollInfo.User,
		},
	}

	return next(enrollment)
}

// UpdateKeys writes the signed certificates into the keyring and saves their
// SSH certificates to the public key store, then calls next with the key
// purposes that were processed, in the order of certs.
//
// All writes happen inside a single keyring call. Processing stops at the
// first failure: a certificate without SSHCert, a SetCertificate error or a
// SaveKey error aborts the update and next is not called. Otherwise the
// result of next is returned.
func (e *Enroll) UpdateKeys(kr keyring.Keyring, certs []SignedCert, next func(...keyring.KeyPurpose) error) error {
	e.LogInfo("updating keys")

	out := make([]keyring.KeyPurpose, len(certs))
	err := keyringCall(kr, func(tx keyring.Tx) error {
		for i, cert := range certs {
			// Check before touching the keyring so a cert with no SSH part
			// is never written.
			if len(cert.SSHCert) == 0 {
				return fmt.Errorf("no ssh pub cert for keytype: %s", cert.KeyType)
			}

			if err := tx.SetCertificate(cert.KeyType, cert.Cert); err != nil {
				return fmt.Errorf("failed to update cert for keytype %q: %w", cert.KeyType, err)
			}

			// The public store entry is keyed by the keyring name and key type.
			if err := e.pubStore.SaveKey(tx.Name(), cert.KeyType, cert.SSHCert); err != nil {
				return fmt.Errorf("failed to save ssh pub key for keytype %q: %w", cert.KeyType, err)
			}

			out[i] = cert.KeyType
		}

		return nil
	})
	if err != nil {
		return err
	}

	return next(out...)
}
