package revoker

import (
	"bytes"
	"context"
	"crypto/md5"
	"crypto/rand"
	"errors"
	"fmt"
	"sort"
	"strconv"
	"time"

	"github.com/cenkalti/backoff/v4"
	"github.com/klauspost/compress/zstd"
	"github.com/stripe/krl"
	"golang.org/x/crypto/ssh"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/security/libs/go/xlock"
	"a.yandex-team.ru/security/skotty/service/internal/auditlog"
	"a.yandex-team.ru/security/skotty/service/internal/db"
	"a.yandex-team.ru/security/skotty/service/internal/mailer"
	"a.yandex-team.ru/security/skotty/service/internal/models"
	"a.yandex-team.ru/security/skotty/service/internal/signer"
	"a.yandex-team.ru/security/skotty/service/internal/staff"
	"a.yandex-team.ru/security/skotty/service/internal/storage"
	"a.yandex-team.ru/yt/go/yterrors"
)

const (
	// KRLPath is the storage path the zstd-compressed KRL is uploaded to.
	KRLPath = "krl/all.zst"
	// DriftPeriod is the drift passed to Process on scheduled (timer) runs.
	DriftPeriod = 30 * time.Minute

	// tickInternal is the wall-clock interval between scheduled runs.
	tickInternal = time.Hour
)

var (
	// errConflictLock is returned by Process when the lock cannot be taken
	// because of a concurrent YT transaction lock conflict.
	errConflictLock = errors.New("conflict YT lock")

	// Compile-time check that *S3Revoker implements Revoker.
	_ Revoker = (*S3Revoker)(nil)
)

// S3Revoker revokes tokens (removing their legacy SSH keys from Staff and
// marking them revoked in the DB) and publishes a signed, zstd-compressed KRL
// of revoked certificates to S3. Work is done by a background loop started in
// NewRevoker and stopped by Shutdown.
type S3Revoker struct {
	db        *db.DB             // revoked tokens/certs lookups, token revocation
	lock      xlock.Locker       // held for the duration of each Process pass
	zstd      *zstd.Encoder      // compresses the KRL before upload
	mailer    *mailer.Mailer     // notifies about revoked tokens
	caStore   *signer.CAStorage  // source of CA KRL signers
	auditLog  *auditlog.AuditLog // optional, nil disables audit records
	s3Store   *storage.Storage   // KRL upload destination
	staff     *staff.Client      // removes legacy SSH keys of revoked tokens
	revokeNow chan struct{}      // RevokeNow -> loop signal
	log       log.Logger         // defaults to a no-op logger
	ctx       context.Context    // lifetime of the background loop
	cancelCtx context.CancelFunc // cancels ctx, called by Shutdown
	closed    chan struct{}      // closed when the background loop exits
}

// NewRevoker builds an S3Revoker from the given options and starts its
// background loop. WithDB, WithS3Store, WithCAStore and WithMailer are
// mandatory; the logger and the locker default to no-op implementations.
// Call Shutdown to stop the background loop.
func NewRevoker(opts ...Option) (*S3Revoker, error) {
	encoder, err := zstd.NewWriter(nil, zstd.WithEncoderConcurrency(1), zstd.WithEncoderLevel(zstd.SpeedDefault))
	if err != nil {
		return nil, fmt.Errorf("failed to create zstd encoder: %w", err)
	}

	ctx, cancelCtx := context.WithCancel(context.Background())

	// The loop is not started on the error paths below, so release the
	// context there instead of leaking it.
	fail := func(err error) (*S3Revoker, error) {
		cancelCtx()
		return nil, err
	}

	// revokeNow must be a real channel: with a nil channel the non-blocking
	// send in RevokeNow always falls through to default and the loop never
	// receives anything. A buffer of one keeps a request issued while a pass
	// is running (instead of dropping it) and coalesces further requests.
	r := &S3Revoker{
		zstd:      encoder,
		log:       &nop.Logger{},
		lock:      &xlock.NopLocker{},
		revokeNow: make(chan struct{}, 1),
		ctx:       ctx,
		cancelCtx: cancelCtx,
		closed:    make(chan struct{}),
	}

	for _, opt := range opts {
		if err := opt(r); err != nil {
			return fail(err)
		}
	}

	if r.db == nil {
		return fail(errors.New("revoker can't work w/o db, please pass WithDB option"))
	}

	if r.s3Store == nil {
		return fail(errors.New("revoker can't work w/o s3Store, please pass WithS3Store option"))
	}

	if r.caStore == nil {
		return fail(errors.New("revoker can't work w/o caStore, please pass WithCAStore option"))
	}

	if r.mailer == nil {
		return fail(errors.New("revoker can't work w/o mailer, please pass WithMailer option"))
	}

	go r.loop()
	return r, nil
}

// loop is the background worker started by NewRevoker.
//
// At every tickInternal wall-clock boundary it runs Process(DriftPeriod).
// On a RevokeNow signal it runs Process(0), retrying failures with
// exponential backoff. It returns, and thereby closes r.closed, once r.ctx
// is cancelled.
func (r *S3Revoker) loop() {
	defer close(r.closed)

	for {
		// Sleep until the next multiple of tickInternal (e.g. the top of the
		// hour), not a fixed interval from now.
		untilNextTick := time.Until(time.Now().Add(tickInternal).Truncate(tickInternal))
		timer := time.NewTimer(untilNextTick)

		select {
		case <-r.ctx.Done():
			timer.Stop()
			// Returning is required: Done stays readable after cancellation,
			// so continuing would spin forever and never close r.closed,
			// which Shutdown waits on.
			return

		case <-timer.C:
			// errConflictLock is expected when the lock is held elsewhere.
			if err := r.Process(DriftPeriod); err != nil && !errors.Is(err, errConflictLock) {
				r.log.Error("can't process certs revocations", log.Error(err))
			}

		case <-r.revokeNow:
			timer.Stop()

			// Retry failed zero-drift passes (lock conflicts, transient DB/S3
			// errors). Retries end on success, when the exponential backoff
			// gives up, or when r.ctx is cancelled.
			err := backoff.RetryNotify(
				func() error { return r.Process(0) },
				backoff.WithContext(backoff.NewExponentialBackOff(), r.ctx),
				func(err error, sleep time.Duration) {
					r.log.Warn("unable to revoke revoked now certs", log.Duration("sleep", sleep), log.Error(err))
				},
			)

			// A failure caused by shutdown is not worth an error log.
			if err != nil && r.ctx.Err() == nil {
				r.log.Error("can't process certs revocations", log.Error(err))
			}
		}
	}
}

// RevokeNow asks the background loop for an immediate revocation pass with
// zero drift. It never blocks. If the request cannot be handed over right
// away, it is dropped.
func (r *S3Revoker) RevokeNow() {
	var request struct{}
	select {
	case r.revokeNow <- request:
	default:
		// No receiver and no buffer space available: drop the request.
	}
}

// Process runs a single revocation pass while holding r.lock:
//
//  1. For each token returned by db.LookupRevokedTokens for the cutoff
//     now-driftPeriod, its legacy SSH keys are removed from Staff, the token
//     is marked revoked in the DB, and a mail is sent. An audit record is
//     also written when an audit log is configured.
//  2. A KRL is built from the revoked certificates, excluding legacy and
//     renew certificates. It has one section per CA and is signed by that
//     CA's KRL signer.
//  3. The KRL is zstd-compressed and uploaded to S3 at KRLPath.
//
// Failures while revoking individual tokens are logged and do not abort the
// pass. Process returns errConflictLock when acquiring the lock fails with a
// YT concurrent-transaction lock conflict. It returns a wrapped error when
// acquiring the lock, generating the KRL, or uploading it fails.
func (r *S3Revoker) Process(driftPeriod time.Duration) error {
	tx, err := r.lock.Lock(r.ctx)
	if err != nil {
		if yterrors.ContainsErrorCode(err, yterrors.CodeConcurrentTransactionLockConflict) {
			r.log.Info("conflict YT lock", log.String("error", err.Error()))
			return errConflictLock
		}

		return fmt.Errorf("lock acquire: %w", err)
	}
	defer func() {
		if err := tx.Unlock(); err != nil {
			r.log.Warn("can't release lock", log.Error(err))
		}
	}()

	// Cutoff (unix seconds) for the revoked tokens/certs lookups. The loop
	// passes DriftPeriod on scheduled runs and 0 for RevokeNow.
	maxRevokedAt := time.Now().Add(-driftPeriod).Unix()

	revokeTokens := func() error {
		tokens, err := r.db.LookupRevokedTokens(r.ctx, maxRevokedAt)
		if err != nil {
			return err
		}

		for _, token := range tokens {
			// Per-token closure so the deferred cancel of the mail timeout
			// runs at the end of each iteration, and errors skip only this
			// token.
			func() {
				tfID := models.TFID(token.User, token.ID, token.EnrollID)
				r.log.Info("starts token revoking", log.String("tfid", tfID))

				if token.ID == "" || token.User == "" || token.EnrollID == "" {
					r.log.Error("invalid revoked token", log.Any("token", token))
					return
				}

				certs, err := r.db.LookupActiveTokenSSHKeysQuery(r.ctx, tfID)
				if err != nil {
					r.log.Error("unable to list token certs", log.String("tfid", tfID), log.Error(err))
					return
				}

				// Revoke ONLY legacy keys in Staff. Legacy certificates are
				// excluded from the KRL built below.
				var fingerprints []string
				for _, cert := range certs {
					if cert.CertType != models.CertTypeLegacy {
						continue
					}

					fingerprints = append(fingerprints, cert.SSHFingerprint)
				}

				r.log.Info("drop keys from the staff", log.String("tfid", tfID), log.Strings("fingerprints", fingerprints))
				if err := r.staff.RemoveSSHKey(r.ctx, token.User, fingerprints...); err != nil {
					// NOTE(review): the token is still marked revoked in the DB
					// below, so these keys are presumably not retried on later
					// passes. Confirm this best-effort behaviour is intended.
					r.log.Error("unable to revoke staff keys",
						log.String("tfid", tfID),
						log.Strings("fingerprints", fingerprints),
						log.Error(err))
				}

				r.log.Info("revoke token in DB", log.String("tfid", tfID))
				if err := r.db.RevokeToken(r.ctx, token); err != nil {
					r.log.Error("can't revoke token", log.String("tfid", tfID), log.Error(err))
					return
				}

				ctx, cancel := context.WithTimeout(r.ctx, 10*time.Second)
				defer cancel()

				r.mailer.TokenRevoked(ctx, token)
				if r.auditLog != nil {
					r.auditLog.Log(token.ID, token.EnrollID, "token and all its certificates have been revoked")
				}
				r.log.Info("token revoked", log.String("tfid", tfID))
			}()
		}

		return nil
	}

	generateKRL := func() ([]byte, error) {
		// CA fingerprint -> serials of revoked certificates issued by that CA.
		revokedCerts := make(map[string][]uint64)

		// The lookup is paged: the last certificate seen is passed back to
		// fetch the next page, and a short page ends the iteration.
		var lastCert models.Certificate
		for {
			certs, err := r.db.LookupRevokedCerts(r.ctx, maxRevokedAt, lastCert)
			if err != nil {
				return nil, fmt.Errorf("certs lookup failed: %w", err)
			}

			for _, cert := range certs {
				lastCert = cert

				// Filter by type before parsing, so skipped certificates never
				// produce serial-parsing errors.
				if cert.CertType == models.CertTypeRenew || cert.CertType == models.CertTypeLegacy {
					continue
				}

				serial, err := strconv.ParseUint(cert.Serial, 10, 64)
				if err != nil {
					r.log.Error("invalid cert serial", log.String("serial", cert.Serial))
					continue
				}

				revokedCerts[cert.CAFingerprint] = append(revokedCerts[cert.CAFingerprint], serial)
			}

			if len(certs) < db.ListLimit {
				break
			}
		}

		// Index the KRL signers of every known CA by public-key fingerprint.
		caSigners := make(map[string]ssh.Signer)
		for _, ca := range r.caStore.AllCAs() {
			for _, s := range ca.KRLSigners {
				caSigners[ssh.FingerprintSHA256(s.PublicKey())] = s
			}
		}

		// Emit sections (and signers) in a stable CA order, so the KRL
		// layout does not depend on map iteration order.
		caFingerprints := make([]string, 0, len(revokedCerts))
		for caFp := range revokedCerts {
			caFingerprints = append(caFingerprints, caFp)
		}
		sort.Strings(caFingerprints)

		out := krl.KRL{
			GeneratedDate: uint64(time.Now().Unix()),
		}

		var signers []ssh.Signer
		for _, caFp := range caFingerprints {
			s, ok := caSigners[caFp]
			if !ok {
				r.log.Errorf("no KRL signer for CA fingerprint: %q", caFp)
				continue
			}

			sshSerials := krl.KRLCertificateSerialList(revokedCerts[caFp])
			sort.Slice(sshSerials, func(i, j int) bool {
				return sshSerials[i] < sshSerials[j]
			})

			signers = append(signers, s)
			out.Sections = append(out.Sections, &krl.KRLCertificateSection{
				CA: s.PublicKey(),
				Sections: []krl.KRLCertificateSubsection{
					&sshSerials,
				},
			})
		}

		return out.Marshal(rand.Reader, signers...)
	}

	// data is the marshaled KRL. The parameter is not named krl because that
	// would shadow the krl package.
	uploadKRL := func(data []byte) error {
		compressed := r.zstd.EncodeAll(data, make([]byte, 0, len(data)))
		md5Hash := md5.Sum(compressed)

		_, err := r.s3Store.Upload(r.ctx, &storage.UploadReq{
			Path:        KRLPath,
			ContentType: "application/krl+zst",
			MD5Sum:      md5Hash[:],
			Data:        bytes.NewReader(compressed),
		})
		return err
	}

	r.log.Info("revoke tokens")
	if err := revokeTokens(); err != nil {
		r.log.Error("revoke tokens fail", log.Error(err))
	}

	r.log.Info("generate krl")
	krlBytes, err := generateKRL()
	if err != nil {
		return fmt.Errorf("generate krl fail: %w", err)
	}

	r.log.Info("upload krl")
	if err := uploadKRL(krlBytes); err != nil {
		return fmt.Errorf("failed to upload krl: %w", err)
	}

	r.log.Info("done")
	return nil
}

// Shutdown stops the background loop by cancelling the revoker's context. It
// then waits until the loop has exited (r.closed is closed) or ctx is done,
// whichever comes first.
func (r *S3Revoker) Shutdown(ctx context.Context) {
	r.cancelCtx()

	// The caller's ctx bounds how long we block.
	select {
	case <-r.closed:
	case <-ctx.Done():
	}
}
