package watcher

import (
	"context"
	"fmt"
	"time"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/security/skotty/robossh/internal/agentkey"
	"a.yandex-team.ru/security/skotty/robossh/internal/issuer"
	"a.yandex-team.ru/security/skotty/robossh/internal/keystore"
	"a.yandex-team.ru/security/skotty/robossh/internal/logger"
)

const (
	// tickInternal is how often Run wakes up to clean up hidden certificates
	// and decide whether the current ones need re-issuing.
	// NOTE(review): the name looks like a typo for "tickInterval".
	tickInternal = 30 * time.Minute
	// certTTL is the assumed certificate lifetime. isCertsExpired uses it only
	// to approximate ValidAfter when a key reports a zero ValidAfter time.
	certTTL      = 24 * time.Hour
)

// Watcher periodically re-issues SSH certificates through an issuer.Issuer and
// rotates them in a keystore.Store. Freshly issued keys are added, and the
// previously current ones are marked stale. Stale keys are removed from the
// store on a later tick.
type Watcher struct {
	// issuer produces new certificates; NewWatcher defaults it to issuer.NopIssuer.
	issuer      issuer.Issuer
	// keys is the store that certificates are added to and removed from.
	keys        *keystore.Store
	// ctx bounds the lifetime of Run; cancelCtx (called by Shutdown) cancels it.
	// NOTE(review): storing a context in a struct is non-idiomatic; kept for
	// API compatibility.
	ctx         context.Context
	cancelCtx   context.CancelFunc
	// currentKeys holds fingerprints of the most recently issued certificates.
	currentKeys map[string]struct{}
	// oldKeys holds fingerprints of certificates marked stale, awaiting
	// removal by CleanupCertificates.
	oldKeys     map[string]struct{}
	// stopped is closed when Run returns.
	stopped     chan struct{}
}

// NewWatcher builds a Watcher that rotates certificates in keys. Without
// options it uses issuer.NopIssuer; each Option is applied in order after
// the defaults are set.
//
// The Watcher owns an internal context that Shutdown cancels. NewWatcher does
// not start Run; the caller does that separately.
func NewWatcher(keys *keystore.Store, opts ...Option) *Watcher {
	w := new(Watcher)
	w.ctx, w.cancelCtx = context.WithCancel(context.Background())
	w.issuer = &issuer.NopIssuer{}
	w.keys = keys
	w.currentKeys = map[string]struct{}{}
	w.oldKeys = map[string]struct{}{}
	w.stopped = make(chan struct{})

	for i := range opts {
		opts[i](w)
	}
	return w
}

// Run blocks until the Watcher's context is cancelled (see Shutdown).
//
// On every tick it first removes the certificates hidden during the previous
// rotation. It then issues new certificates if isCertsExpired reports that it
// is time. Update errors are logged, not returned. w.stopped is closed when
// Run exits.
func (w *Watcher) Run() {
	defer close(w.stopped)

	t := time.NewTicker(tickInternal)
	defer t.Stop()

	done := w.ctx.Done()
	for {
		select {
		case <-done:
			return
		case <-t.C:
			w.CleanupCertificates()

			if !w.isCertsExpired() {
				logger.Info("skip certs update: all certs up to date")
				continue
			}

			if err := w.UpdateCertificates(w.ctx); err != nil {
				logger.Error("can't update certificates", log.Error(err))
			}
		}
	}
}

// CleanupCertificates drops every fingerprint queued in w.oldKeys by a
// previous UpdateCertificates call. It removes the matching key from the
// keystore and logs each one for which Remove reports success. The
// pending-removal set is empty afterwards.
func (w *Watcher) CleanupCertificates() {
	for fingerprint := range w.oldKeys {
		// Deleting the current key while ranging over a map is permitted in Go.
		removed := w.keys.Remove(fingerprint)
		delete(w.oldKeys, fingerprint)
		if !removed {
			continue
		}
		logger.Info("removed old certificate", log.String("fingerprint", fingerprint))
	}
}

// UpdateCertificates issues a fresh set of certificates and adds them to the
// keystore. It then marks the previously current ones (those not re-issued
// in this round) as stale. Stale fingerprints are queued in w.oldKeys and
// stay in the store until the next CleanupCertificates call removes them.
//
// If issuing or converting any certificate fails, the error is returned
// before the keystore or the watcher's bookkeeping is modified.
func (w *Watcher) UpdateCertificates(ctx context.Context) error {
	logger.Info("starts update certificates")

	logger.Info("issue certificates")
	sshKeys, err := w.issuer.Certificates(ctx)
	if err != nil {
		return fmt.Errorf("issue certs: %w", err)
	}

	// Convert everything up front so that one bad certificate aborts the
	// whole rotation before the keystore is touched.
	agentKeys := make([]*agentkey.Key, 0, len(sshKeys))
	for _, sshKey := range sshKeys {
		key, err := agentkey.NewKey(sshKey)
		if err != nil {
			return fmt.Errorf("create agent key %q: %w", sshKey.Comment, err)
		}
		agentKeys = append(agentKeys, key)
	}

	logger.Info("add new certificates")
	newKeys := make(map[string]struct{}, len(agentKeys))
	for _, key := range agentKeys {
		fp := key.Fingerprint()
		w.keys.Add(key)
		// A fingerprint that is current again must not be deleted by a
		// cleanup that was queued in an earlier rotation.
		delete(w.oldKeys, fp)
		newKeys[fp] = struct{}{}
		logger.Info("added certificate", log.String("name", key.Name()), log.String("fingerprint", fp))
	}

	logger.Info("hide old certificates")
	w.keys.Range(func(key *agentkey.Key) bool {
		fp := key.Fingerprint()
		if _, wasCurrent := w.currentKeys[fp]; !wasCurrent {
			return true
		}
		// The issuer may return a certificate whose fingerprint we already
		// had. It belongs to the new set and must stay visible; hiding it
		// would get it removed on the next tick.
		if _, isNew := newKeys[fp]; isNew {
			return true
		}

		w.oldKeys[fp] = struct{}{}
		key.SetStale(true)
		logger.Info("hide certificate", log.String("name", key.Name()), log.String("fingerprint", fp))
		return true
	})

	w.currentKeys = newKeys

	logger.Info("certificates updated")
	return nil
}

// isCertsExpired reports whether the current certificates should be re-issued.
//
// It returns true when any of these holds:
//   - no certificate is being tracked, for example because the initial issue
//     failed or the issuer returned none;
//   - a tracked certificate is no longer present in the keystore;
//   - a tracked certificate has passed the midpoint of its validity window.
//
// Keys with a zero ValidBefore are skipped. A zero ValidAfter is
// approximated as ValidBefore-certTTL.
func (w *Watcher) isCertsExpired() bool {
	// With nothing tracked there is no expiry to observe. Without this check
	// Run would report "up to date" forever and never retry a failed issue.
	if len(w.currentKeys) == 0 {
		return true
	}

	for fp := range w.currentKeys {
		key, ok := w.keys.Get(fp)
		if !ok {
			// The certificate vanished from the store; re-issue it.
			return true
		}

		validBefore := key.ValidBefore()
		if validBefore.IsZero() {
			continue
		}

		validAfter := key.ValidAfter()
		if validAfter.IsZero() {
			validAfter = validBefore.Add(-certTTL)
		}

		halfLife := validBefore.Sub(validAfter) / 2
		if time.Since(validAfter) >= halfLife {
			return true
		}
	}

	return false
}

// Shutdown cancels the Watcher's internal context, which makes Run return.
// It then waits until Run has exited or ctx is done, whichever happens first.
// If Run was never started, it waits until ctx is done.
func (w *Watcher) Shutdown(ctx context.Context) {
	w.cancelCtx()

	deadline := ctx.Done()
	select {
	case <-w.stopped:
	case <-deadline:
	}
}
