package supervisor

import (
	"context"
	"fmt"
	"os"
	"strings"
	"time"

	"golang.org/x/crypto/ssh"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/security/libs/go/sectools"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring/pubstore"
	"a.yandex-team.ru/security/skotty/skotty/internal/ui"
	"a.yandex-team.ru/security/skotty/skotty/internal/version"
)

// Timing parameters used by AgentChecker. The "Internal" suffix in
// tickInternal and verCheckInternal is a typo for "Interval"; the names are
// left unchanged because they may be referenced elsewhere in the package.
const (
	tickInternal      = 20 * time.Minute   // how often Start wakes up to see whether a check is due
	verCheckInternal  = 2 * 24 * time.Hour // minimal age of the last successful new-version check before re-checking
	certCheckInterval = 30 * time.Minute   // minimal age of the last successful certificate check before re-checking
	notifyTTL         = 10 * time.Minute   // TTL argument passed to every NotifyAndForget call
)

// AgentChecker runs periodic background checks for the agent: certificate
// expiration (only when a CertChecker is configured) and availability of a
// newer version. Build it with NewAgentChecker, run Start (which blocks until
// the context is cancelled) and stop it with Shutdown.
type AgentChecker struct {
	lastVerCheck   time.Time          // tick time of the last successful version check
	lastCertsCheck time.Time          // tick time of the last successful certificate check
	cc             *CertChecker       // certificate checker; nil disables certificate checks
	notify         ui.Notifier        // user notifier; nil means results are only logged
	log            log.Logger         // logger; NewAgentChecker defaults it to a no-op logger
	ctx            context.Context    // lifetime context; Start returns once it is cancelled
	cancelCtx      context.CancelFunc // cancels ctx; invoked by Shutdown
}

// NewAgentChecker builds an AgentChecker configured by opts.
//
// A loggerOption replaces the default no-op logger, a certsCheckerOption
// enables certificate expiration checks for the given keyring/keys, and a
// notifierOption sets the notifier used to alert the user. Options of any
// other type are ignored; when the same option kind is given several times,
// the last one wins. The checker owns a cancellable context that Shutdown
// cancels.
func NewAgentChecker(opts ...Option) *AgentChecker {
	ctx, cancel := context.WithCancel(context.Background())
	checker := &AgentChecker{
		log:       &nop.Logger{},
		ctx:       ctx,
		cancelCtx: cancel,
	}

	for _, o := range opts {
		switch opt := o.(type) {
		case loggerOption:
			checker.log = opt.logger
		case certsCheckerOption:
			checker.cc = &CertChecker{
				pubStore:    opt.pubStore,
				keys:        opt.keys,
				keyringName: opt.keyringName,
			}
		case notifierOption:
			checker.notify = opt.notify
		}
	}

	return checker
}

// Start runs the periodic background checks and blocks until Shutdown
// cancels the checker's context. It wakes up every tickInternal and then:
//   - checks certificate expiration when a CertChecker is configured and the
//     last successful certificate check is older than certCheckInterval;
//   - checks for a newer version when the last successful version check is
//     older than verCheckInternal.
//
// A check's timestamp only advances when the check succeeds, so a failed
// check is retried on the next tick. The two checks are evaluated
// independently: previously they shared one switch, so a certificate check
// that kept failing (and therefore stayed "due") prevented the version check
// from ever running.
func (c *AgentChecker) Start() {
	ticker := time.NewTicker(tickInternal)
	defer ticker.Stop()

	for {
		select {
		case <-c.ctx.Done():
			return
		case <-ticker.C:
		}

		now := time.Now()

		if c.cc != nil && now.Sub(c.lastCertsCheck) > certCheckInterval {
			if err := c.checkCertExpires(); err == nil {
				c.lastCertsCheck = now
			}
		}

		if now.Sub(c.lastVerCheck) > verCheckInternal {
			if err := c.checkNewVersion(); err == nil {
				c.lastVerCheck = now
			}
		}
	}
}

// Shutdown cancels the checker's context, which makes Start return. The same
// context is passed to the version lookup in checkNewVersion, so an
// in-flight lookup may be interrupted as well.
func (c *AgentChecker) Shutdown() {
	stop := c.cancelCtx
	stop()
}

// checkNewVersion asks the sectools service whether the running build is the
// latest one. If it is not, it logs the new version and, when a notifier is
// configured, shows an alert to the user.
//
// The release channel may be overridden via the SKOTTY_CHANNEL environment
// variable; when sectools.ParseChannel rejects the value, no channel option is
// passed to the client.
//
// It returns an error only when the version lookup itself fails, so that
// Start retries on the next tick. A failed notification is logged but treated
// as success, so the user is not re-notified on every tick.
func (c *AgentChecker) checkNewVersion() error {
	current := version.Full()
	svcOpts := []sectools.Option{
		sectools.WithPreferFastURL(),
		sectools.WithCurrentVersion(current),
	}

	// Named "channel" (not "c") so the receiver is not shadowed.
	if channel, err := sectools.ParseChannel(os.Getenv("SKOTTY_CHANNEL")); err == nil {
		svcOpts = append(svcOpts, sectools.WithChannel(channel))
	}

	svc := sectools.NewClient(version.ToolName, svcOpts...)
	isLatest, latestVersion, err := svc.IsLatestVersion(c.ctx, current)
	if err != nil {
		c.log.Warn("can't check new version", log.Error(err))
		return err
	}

	if isLatest {
		return nil
	}

	c.log.Info("new version available",
		log.String("current", current),
		log.String("latest", latestVersion))

	if c.notify == nil {
		return nil
	}

	msg := fmt.Sprintf("new version available: %s", latestVersion)
	if err = c.notify.NotifyAndForget(ui.NotificationKindAlert, msg, notifyTTL); err != nil {
		// Best effort: the version check itself succeeded.
		c.log.Warn("can't send notification about new version", log.Error(err))
	}

	return nil
}

// checkCertExpires asks the configured CertChecker whether the user should be
// reminded that a Skotty certificate is about to expire. If so, it logs the
// reminder and, when a notifier is configured, alerts the user.
//
// It dereferences c.cc, so it must only be called when c.cc is non-nil (Start
// checks this before calling). It returns an error only when the certificate
// check itself fails. A failed notification is logged and otherwise ignored.
func (c *AgentChecker) checkCertExpires() error {
	have, err := c.cc.HaveExpiresSoon()
	if err != nil {
		c.log.Warn("can't check certs expires", log.Error(err))
		return err
	}

	if !have {
		return nil
	}

	c.log.Info("some of certificates are about to expire, renew needed")

	if c.notify == nil {
		return nil
	}

	msg := "some of your certificates are about to expire, please renew them"
	if err = c.notify.NotifyAndForget(ui.NotificationKindAlert, msg, notifyTTL); err != nil {
		// Best effort: the check itself succeeded, so don't make the caller retry.
		c.log.Warn("can't send notification about expiring certificates", log.Error(err))
	}

	return nil
}

// CertChecker inspects the stored public keys of one keyring and decides
// whether the user should be reminded that a Skotty certificate is about to
// expire (see HaveExpiresSoon).
type CertChecker struct {
	pubStore     pubstore.PubStore    // where the keyring's public keys are read from
	lastNotifyAt time.Time            // when HaveExpiresSoon last returned true; rate-limits reminders
	keyringName  string               // keyring whose keys are inspected
	keys         []keyring.KeyPurpose // key purposes to inspect, in order
}

// HaveExpiresSoon reports whether the user should now be reminded that a
// Skotty certificate is about to expire.
//
// For every configured key purpose, it reads the public key from the pub
// store and considers only SSH certificates whose KeyId starts with
// "skotty:". The following are skipped:
//   - keys that cannot be read;
//   - keys that are not certificates;
//   - certificates that are not Skotty ones;
//   - certificates whose ValidBefore does not fit into int64, such as
//     ssh.CertTimeInfinity.
//
// checkCertExpire decides whether a certificate is expiring soon and how
// often to remind. A reminder is reported at most once per that period,
// tracked in c.lastNotifyAt, which is set to the current time whenever true
// is returned.
//
// It returns an error as soon as a stored public key fails to parse.
func (c *CertChecker) HaveExpiresSoon() (bool, error) {
	now := time.Now()
	for _, keyPurpose := range c.keys {
		sshPubBytes, err := c.pubStore.ReadKey(c.keyringName, keyPurpose)
		if err != nil {
			// NOTE(review): read errors are skipped deliberately, presumably
			// because the key may simply not exist yet; confirm.
			continue
		}

		sshPub, _, _, _, err := ssh.ParseAuthorizedKey(sshPubBytes)
		if err != nil {
			return false, fmt.Errorf("failed to parse ssh pub key %q: %w", keyPurpose, err)
		}

		sshCert, ok := sshPub.(*ssh.Certificate)
		if !ok {
			continue
		}

		if !strings.HasPrefix(sshCert.KeyId, "skotty:") {
			// not a Skotty certificate
			continue
		}

		// ValidBefore is a uint64. Values >= 2^63 (notably ssh.CertTimeInfinity)
		// would wrap to a negative Unix time below and make a never-expiring
		// certificate look expired, triggering a reminder every 30 minutes.
		if sshCert.ValidBefore >= 1<<63 {
			continue
		}

		validAfter := time.Unix(int64(sshCert.ValidAfter), 0)
		validBefore := time.Unix(int64(sshCert.ValidBefore), 0)
		notifyPeriod, expiresSoon := checkCertExpire(now, validAfter, validBefore)
		if !expiresSoon {
			continue
		}

		if now.Sub(c.lastNotifyAt) < notifyPeriod {
			// already reminded within this certificate's reminder period
			continue
		}

		c.lastNotifyAt = now
		return true, nil
	}

	return false, nil
}

// checkCertExpire decides whether a certificate valid from validAfter until
// validBefore counts as "expiring soon" at the moment now. If it does, it
// also returns the minimal interval between reminders.
//
// The warning windows depend on the certificate's total lifetime:
//   - up to 48h (short-lived): every 30m within the last 1h, every 1h within
//     the last 3h;
//   - up to 31 days (semi long-lived): every 30m within the last 6h, every 1h
//     within the last 24h;
//   - longer (regular): as semi long-lived, plus every 12h within the last
//     7 days.
//
// A certificate that is already past validBefore falls into the narrowest
// window. When no window applies, it returns (0, false).
func checkCertExpire(now, validAfter, validBefore time.Time) (time.Duration, bool) {
	// reminder is one warning window: once no more than "within" is left
	// until expiry, remind every "every". Windows are listed narrowest first,
	// and the first matching one wins.
	type reminder struct {
		within time.Duration
		every  time.Duration
	}

	lifetime := validBefore.Sub(validAfter)
	remaining := lifetime - now.Sub(validAfter)

	var schedule []reminder
	switch {
	case lifetime <= 48*time.Hour:
		// short-lived certs
		schedule = []reminder{
			{within: time.Hour, every: 30 * time.Minute},
			{within: 3 * time.Hour, every: time.Hour},
		}
	case lifetime <= 31*24*time.Hour:
		// semi long-lived certs
		schedule = []reminder{
			{within: 6 * time.Hour, every: 30 * time.Minute},
			{within: 24 * time.Hour, every: time.Hour},
		}
	default:
		// regular certs
		schedule = []reminder{
			{within: 6 * time.Hour, every: 30 * time.Minute},
			{within: 24 * time.Hour, every: time.Hour},
			{within: 7 * 24 * time.Hour, every: 12 * time.Hour},
		}
	}

	for _, r := range schedule {
		if remaining <= r.within {
			return r.every, true
		}
	}

	return 0, false
}
