package main

import (
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"os/signal"
	"path/filepath"
	"strings"
	"syscall"
	"time"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/security/libs/go/daemon"
	"a.yandex-team.ru/security/skotty/libs/skotty"
	"a.yandex-team.ru/security/skotty/robossh/internal/agent"
	"a.yandex-team.ru/security/skotty/robossh/internal/issuer"
	"a.yandex-team.ru/security/skotty/robossh/internal/keystore"
	"a.yandex-team.ru/security/skotty/robossh/internal/logger"
	"a.yandex-team.ru/security/skotty/robossh/internal/sshutil"
	"a.yandex-team.ru/security/skotty/robossh/internal/version"
	"a.yandex-team.ru/security/skotty/robossh/internal/watcher"
)

// Timeouts and intervals governing the agent's lifecycle.
const (
	// daemonizeTimeout bounds the context doParentWork passes to d.StartChild.
	daemonizeTimeout time.Duration = 5 * time.Minute
	// shutdownDeadline bounds stopping the agent and the certificate watcher.
	shutdownDeadline time.Duration = time.Minute
	// updateTimeout bounds a single certificate update request.
	updateTimeout time.Duration = time.Minute
	// parentCheckPeriod is how often watchParent compares os.Getppid() with the recorded parent.
	parentCheckPeriod time.Duration = 10 * time.Second
)

// startFlags holds the command-line options of the "start" command.
type startFlags struct {
	// bindAddress is the socket path the agent listens on. When empty, doStart
	// fills it with a path inside a freshly created temporary directory.
	bindAddress string
	// foregroundMode skips daemonization and runs the agent in the current process.
	foregroundMode bool
	// ttl is the key lifetime, parsed with sshutil.ParseDuration; empty means
	// the zero duration is passed to agent.WithKeyLifetime.
	ttl string
	// withCerts enables certificate issuing via newWatcher; requires env[ROBOSSH_TOKEN].
	withCerts bool
	// certTypes lists requested CA names, parsed with skotty.CertType.FromString.
	certTypes []string
	// nArgs is the command (argv) to exec with SSH_AUTH_SOCK/SSH_AGENT_PID set
	// once the daemonized agent is up; when empty the agent script is printed instead.
	nArgs []string
}

// childInfo is the payload the daemonized child reports through
// d.NotifyStarted and the parent decodes with child.Unmarshal.
// NOTE(review): fields are exported, presumably so the daemon library can
// serialize them — confirm before renaming.
type childInfo struct {
	// PID is the agent (child) process ID, exported as SSH_AGENT_PID.
	PID int
	// SocketPath is the agent socket, exported as SSH_AUTH_SOCK.
	SocketPath string
}

// doStart is the entry point of the "start" command.
//
// Outside foreground mode the process is daemonized. The parent side is
// delegated to doParentWork, and the child side reports its socket or failure
// back to the parent. In foreground mode the agent script is written to
// stdout directly. In both modes the call blocks until the agent stops.
func doStart(flags startFlags) error {
	// Validate before daemonizing so the parent fails fast too.
	if flags.withCerts {
		switch token, present := os.LookupEnv("ROBOSSH_TOKEN"); {
		case !present:
			return errors.New("--with-certs requires env[ROBOSSH_TOKEN], but it's not provided")
		case token == "":
			return errors.New("--with-certs requires env[ROBOSSH_TOKEN], but it's empty")
		}
	}

	// Foreground defaults: print the agent script ourselves; errors are only
	// surfaced through the return value.
	var (
		parentPID int
		onStarted = func(socket string) {
			_, _ = os.Stdout.Write(sshutil.SSHAgentScript(socket, os.Getpid()))
		}
		onFailed = func(error) {}
	)

	if !flags.foregroundMode {
		d := daemon.NewDaemon("robossh-agent")
		if d.IsParent() {
			return doParentWork(d, flags)
		}

		// We are the daemonized child: report progress to the waiting parent.
		parentPID = d.Ppid()
		onStarted = func(socket string) {
			info := childInfo{PID: os.Getpid(), SocketPath: socket}
			if notifyErr := d.NotifyStarted(info); notifyErr != nil {
				logger.Error("failed to notify parent", log.Error(notifyErr))
			}
		}
		onFailed = func(cause error) {
			if notifyErr := d.NotifyError(cause); notifyErr != nil {
				logger.Error("failed to notify parent", log.Error(notifyErr))
			}
		}
	}

	fail := func(err error) error {
		onFailed(err)
		return err
	}

	if flags.bindAddress == "" {
		dir, err := os.MkdirTemp("", "robossh-*")
		if err != nil {
			return fail(fmt.Errorf("can't create temporary folder for socket: %w", err))
		}
		// The directory lives until doStart returns, i.e. after the agent stopped.
		defer func() { _ = os.RemoveAll(dir) }()

		flags.bindAddress = filepath.Join(dir, fmt.Sprintf("agent.%d", os.Getppid()))
	}

	done, err := doChildWork(parentPID, flags)
	if err != nil {
		return fail(err)
	}

	logger.Info("agent started", log.String("version", version.Full()))
	onStarted(flags.bindAddress)

	<-done
	return nil
}

// doChildWork runs the agent itself. It parses the key lifetime, optionally
// issues certificates through newWatcher, starts serving on the socket, and
// spawns a supervisor goroutine.
//
// The supervisor reacts to three events:
//   - SIGINT/SIGTERM, or the parent (as reported by watchParent(ppid, flags))
//     going away: it stops the agent and the watcher.
//   - SIGHUP: it refreshes certificates.
//   - The listener's wait channel firing: it exits.
//
// The returned channel is closed when the supervisor exits.
//
// If flags.bindAddress is empty, the socket is placed in a fresh temporary
// directory. That directory is removed when the supervisor exits, not when
// this function returns.
func doChildWork(ppid int, flags startFlags) (<-chan struct{}, error) {
	var keysTTL time.Duration
	if flags.ttl != "" {
		var err error
		keysTTL, err = sshutil.ParseDuration(flags.ttl)
		if err != nil {
			return nil, fmt.Errorf("invalid keys lifetime: %w", err)
		}
	}

	keys := keystore.NewStore()
	watch, err := newWatcher(keys, flags)
	if err != nil {
		return nil, err
	}

	// releaseSocketDir removes the fallback socket directory. It must run when
	// the agent stops: a plain defer here would delete the socket right after
	// the listener starts, because this function returns as soon as it serves.
	releaseSocketDir := func() {}
	socketPath := flags.bindAddress
	if socketPath == "" {
		tmpDir, err := os.MkdirTemp("", "robossh-*")
		if err != nil {
			return nil, fmt.Errorf("create temporary folder: %w", err)
		}
		releaseSocketDir = func() {
			_ = os.RemoveAll(tmpDir)
		}
		socketPath = filepath.Join(tmpDir, fmt.Sprintf("agent.%d", os.Getppid()))
	}

	instance := agent.NewAgent(keys, agent.WithKeyLifetime(keysTTL))

	shutdown := func() {
		ctx, cancel := context.WithTimeout(context.Background(), shutdownDeadline)
		defer cancel()

		instance.Stop(ctx)
		if watch != nil {
			watch.Shutdown(ctx)
		}
	}

	updateCerts := func() {
		if watch == nil {
			logger.Error("no certs required - nothing to do")
			return
		}

		ctx, cancel := context.WithTimeout(context.Background(), updateTimeout)
		defer cancel()

		err := watch.UpdateCertificates(ctx)
		if err != nil {
			logger.Error("can't update certificates", log.Error(err))
		}
	}

	listenWaitCh, err := instance.ListenAndServe(socketPath)
	if err != nil {
		releaseSocketDir()
		return nil, fmt.Errorf("failed to start: %w", err)
	}

	// Start the background watcher only once the agent is serving, so a failed
	// listen does not leave watch.Run running with nobody to shut it down.
	if watch != nil {
		go func() {
			logger.Info("certs watcher started")
			watch.Run()
		}()
	}

	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)

	// Subscribe to parent liveness exactly once. Calling watchParent inside the
	// select would start a new polling goroutine (and restart the check) on
	// every loop iteration, e.g. after each SIGHUP.
	parentGone := watchParent(ppid, flags)

	waitCh := make(chan struct{})
	go func() {
		defer close(waitCh)
		defer releaseSocketDir()

		for {
			select {
			case <-parentGone:
				logger.Info("parent has died, will shutting down")
				shutdown()
				return
			case sig := <-sigChan:
				switch sig {
				case syscall.SIGHUP:
					logger.Info("received HUP signal, will update certificates")
					updateCerts()
				default:
					logger.Info("received TERM signal, will shutting down")
					shutdown()
					return
				}
			case <-listenWaitCh:
				return
			}
		}
	}()

	return waitCh, nil
}

// doParentWork runs in the original (parent) process when daemonizing. It
// starts the agent child and waits for it to report its socket and PID. Then:
//   - if no command was given (flags.nArgs empty), it prints the agent script
//     to stdout;
//   - otherwise it replaces itself via syscall.Exec with flags.nArgs, with
//     SSH_AUTH_SOCK and SSH_AGENT_PID pointing at the new agent.
//
// If the hand-off to the command fails after the child has started, the
// child's process group is sent SIGTERM so no orphaned agent is left behind.
func doParentWork(d *daemon.Daemon, flags startFlags) error {
	ctx, cancel := context.WithTimeout(context.Background(), daemonizeTimeout)
	defer cancel()

	child, err := d.StartChild(ctx)
	if err != nil {
		return fmt.Errorf("start child: %w", err)
	}

	// killChild terminates the agent's process group. It is best effort
	// because it only runs on error paths.
	killChild := func() {
		_ = syscall.Kill(-child.PID(), syscall.SIGTERM)
	}

	var info childInfo
	if err := child.Unmarshal(&info); err != nil {
		killChild()
		return fmt.Errorf("unmarshal child info: %w", err)
	}

	if len(flags.nArgs) == 0 {
		_, err = os.Stdout.Write(sshutil.SSHAgentScript(info.SocketPath, info.PID))
		return err
	}

	exe, err := exec.LookPath(flags.nArgs[0])
	if err != nil {
		killChild()
		return fmt.Errorf("can't find child binary: %w", err)
	}

	env := childCommandEnv(os.Environ(), info.SocketPath, info.PID)

	// syscall.Exec only returns when it fails to replace the process image.
	if err := syscall.Exec(exe, flags.nArgs, env); err != nil {
		killChild()
		return err
	}
	return nil
}

// childCommandEnv derives the environment for the wrapped command from
// environ. ROBOSSH_* variables are dropped because they may carry the robot
// token. Inherited SSH_AUTH_SOCK / SSH_AGENT_PID are replaced rather than
// duplicated: syscall.Exec passes the slice verbatim, and both libc getenv
// and Go's os.Getenv honor the first occurrence of a key, so a stale
// inherited value would otherwise win over the new agent.
func childCommandEnv(environ []string, socketPath string, pid int) []string {
	env := make([]string, 0, len(environ)+2)
	for _, e := range environ {
		switch {
		case strings.HasPrefix(e, "ROBOSSH_"),
			strings.HasPrefix(e, "SSH_AUTH_SOCK="),
			strings.HasPrefix(e, "SSH_AGENT_PID="):
			continue
		}
		env = append(env, e)
	}

	return append(env,
		fmt.Sprintf("SSH_AUTH_SOCK=%s", socketPath),
		fmt.Sprintf("SSH_AGENT_PID=%d", pid),
	)
}

// newWatcher builds the certificate watcher used when --with-certs is set. It
// performs one certificate update up front (bounded by updateTimeout), so a
// failure to issue certificates aborts startup.
//
// It returns (nil, nil) when flags.withCerts is false. The robot token is
// read from env[ROBOSSH_TOKEN]; env[ROBOSSH_UPSTREAM], when non-empty,
// overrides the upstream. Requested cert types are deduplicated while
// preserving their first-seen order.
func newWatcher(keys *keystore.Store, flags startFlags) (*watcher.Watcher, error) {
	if !flags.withCerts {
		return nil, nil
	}

	robocOpts := []skotty.RoboOption{
		skotty.WithClientVersion(version.Full()),
		skotty.WithRoboAuthToken(os.Getenv("ROBOSSH_TOKEN")),
	}

	if upstream := os.Getenv("ROBOSSH_UPSTREAM"); upstream != "" {
		logger.Info("will use custom upstream", log.String("upstream", upstream))
		robocOpts = append(robocOpts, skotty.WithUpstream(upstream))
	}

	certTypes := make([]skotty.CertType, 0, len(flags.certTypes))
	seen := make(map[skotty.CertType]struct{}, len(flags.certTypes))
	for _, typ := range flags.certTypes {
		var certType skotty.CertType
		if err := certType.FromString(typ); err != nil {
			return nil, fmt.Errorf("invalid CA requested: %w", err)
		}

		if _, dup := seen[certType]; dup {
			continue
		}
		// Record the type; without this the duplicate check above never fires.
		seen[certType] = struct{}{}
		certTypes = append(certTypes, certType)
	}

	certsIssuer := issuer.NewSkottyIssuer(
		issuer.WithSkottyRoboClient(skotty.NewRoboClient(robocOpts...)),
		issuer.WithCertType(certTypes...),
	)
	watch := watcher.NewWatcher(keys, watcher.WithIssuer(certsIssuer))

	ctx, cancel := context.WithTimeout(context.Background(), updateTimeout)
	err := watch.UpdateCertificates(ctx)
	cancel()
	if err != nil {
		return nil, fmt.Errorf("can't issue certificates: %w", err)
	}

	return watch, nil
}

// watchParent returns a channel that is closed once os.Getppid() no longer
// equals ppid, polled every parentCheckPeriod.
//
// Watching is enabled only for a daemonized child (ppid != 0, not foreground
// mode) that wraps a command (non-empty flags.nArgs). In that mode the parent
// execs the command in place, so re-parenting signals that the command has
// exited. Otherwise the returned channel is never closed.
func watchParent(ppid int, flags startFlags) <-chan struct{} {
	gone := make(chan struct{})

	enabled := ppid != 0 && !flags.foregroundMode && len(flags.nArgs) > 0
	if !enabled {
		return gone
	}

	go func() {
		ticker := time.NewTicker(parentCheckPeriod)
		defer func() {
			ticker.Stop()
			close(gone)
		}()

		for range ticker.C {
			if os.Getppid() != ppid {
				return
			}
		}
	}()

	return gone
}
