package crash

import (
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"os/signal"
	"path/filepath"
	"strings"
	"syscall"
	"time"

	"go.uber.org/atomic"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/security/skotty/skotty/internal/logger"
)

const (
	// HistoryLimit is the number of saved crash reports (per Saver name) at
	// which cleanUp starts pruning; once reached, HistoryLimit/2 reports are
	// removed, first in filename order.
	HistoryLimit = 10

	// RestartPeriod is how long StartChild waits before restarting a child
	// that exited with an unexpected code.
	RestartPeriod = time.Second

	// NeedRestartExitCode (0x7d) is the exit code a child uses to request an
	// immediate restart, without the RestartPeriod delay.
	NeedRestartExitCode = 125
)

// Saver runs the current program as a supervised child process, restarting
// it when it exits unexpectedly and keeping the stderr output of crashed runs.
type Saver struct {
	// stdout and stderr, when non-nil, are attached to the child's stdout and
	// (in addition to the per-run stderr file) to the child's stderr.
	stdout, stderr io.Writer

	// name is used in the environment marker variable, the stderr file name
	// and the crash report prefix.
	name string

	// outputDir holds the live stderr file and saved crash reports.
	outputDir string

	// onRestart is consulted before every restart; false stops the loop.
	onRestart func() bool

	// findExecutable resolves the path of the binary to launch as the child.
	findExecutable func() (string, error)

	// args are the command-line arguments passed to the child.
	args []string
}

// NewSaver creates a Saver that keeps its stderr and crash files in out.
//
// Defaults: the name is "unnamed", restarts are always allowed, the child
// binary is located with os.Executable, and the child receives the current
// process's arguments without the program name. Each option is then called
// in order with the new Saver.
func NewSaver(out string, opts ...Option) *Saver {
	// os.Args usually starts with the program name, but it can be empty when
	// the process was exec'ed with an empty argv; os.Args[1:] would then panic.
	args := os.Args
	if len(args) > 0 {
		args = args[1:]
	}

	r := &Saver{
		outputDir:      out,
		name:           "unnamed",
		onRestart:      func() bool { return true },
		findExecutable: os.Executable,
		args:           args,
	}

	for _, opt := range opts {
		opt(r)
	}

	return r
}

// IsParent reports whether this process is the supervising parent. That is
// the case when the crash-saver marker variable (see envName) is unset or
// empty; startChild sets it to "yes" in the child's environment.
func (s *Saver) IsParent() bool {
	marker := os.Getenv(s.envName())
	return len(marker) == 0
}

// StartChild launches the executable returned by findExecutable as a
// supervised child and keeps restarting it until it exits with an expected
// code or onRestart returns false.
//
// Each run writes the child's stderr to "<outputDir>/<name>.stderr". If the
// child exits with an unexpected code and that file is non-empty, the file is
// renamed to "crash.<name>.<timestamp>.stderr" and old crash files are pruned.
//
// Exit code handling:
//   - NeedRestartExitCode: restart immediately. The executable path is
//     re-resolved first; if that fails, the previous path is reused.
//   - 0, 1 and 0x40010004 ("force shutdown" on Windows): returned as-is.
//   - anything else: logged as a crash, restarted after RestartPeriod.
//
// When onRestart declines a restart, the last exit code is returned. If the
// executable, the output directory, the stderr file or the child process
// cannot be set up, StartChild returns 1 and an error.
func (s *Saver) StartChild() (int, error) {
	exe, err := s.findExecutable()
	if err != nil {
		return 1, fmt.Errorf("can't determine executable: %w", err)
	}

	// MkdirAll also creates missing parents and returns nil when the
	// directory already exists.
	if err := os.MkdirAll(s.outputDir, 0o700); err != nil {
		return 1, fmt.Errorf("unable to create logs dir: %w", err)
	}

	stderrFilename := filepath.Join(s.outputDir, fmt.Sprintf("%s.stderr", s.name))
	for {
		// Truncated on every run: only the latest run's stderr is kept here.
		stderr, err := os.Create(stderrFilename)
		if err != nil {
			return 1, fmt.Errorf("unable to create stderr file: %w", err)
		}

		exitCode, err := s.startChild(exe, stderr)
		_ = stderr.Close()
		if err != nil {
			return 1, err
		}

		switch exitCode {
		case NeedRestartExitCode:
			// special exit code that means "restart me"
			if !s.onRestart() {
				logger.Info("restarts externally stopped")
				return exitCode, nil
			}

			newExe, err := s.findExecutable()
			if err != nil {
				logger.Error("can't determine executable", log.Error(err))
			} else {
				exe = newExe
			}

			logger.Info("child exited with 'need restart' exit code and will be restarted",
				log.String("exe", exe))
			continue
		case 0, 1, 0x40010004:
			// expected exit codes; 0x40010004 is "force shutdown" on windows
			return exitCode, nil
		}

		if stat, err := os.Stat(stderrFilename); err == nil && stat.Size() > 0 {
			crashName := fmt.Sprintf("crash.%s.%s.stderr", s.name, time.Now().Format("2006-01-02T15-04-05"))
			savedStderrFilename := filepath.Join(s.outputDir, crashName)
			if err := os.Rename(stderrFilename, savedStderrFilename); err != nil {
				logger.Warn("unable to save crash stderr",
					log.String("from", stderrFilename),
					log.String("to", savedStderrFilename),
					log.Error(err))
				// The output was not moved, so point the log at the file that
				// actually holds it (until the next run truncates it).
				savedStderrFilename = stderrFilename
			}

			logger.Error("child exited with unexpected exit code and will be restarted",
				log.Duration("sleep", RestartPeriod),
				log.Int("code", exitCode),
				log.String("stderr_file", savedStderrFilename))

			// Pruning is best-effort: report a failure but keep supervising.
			if err := s.cleanUp(); err != nil {
				logger.Warn("unable to clean up old crash reports", log.Error(err))
			}
		} else {
			logger.Error("child exited with unexpected exit code and will be restarted",
				log.Duration("sleep", RestartPeriod),
				log.Int("code", exitCode))
		}

		time.Sleep(RestartPeriod)
		if !s.onRestart() {
			logger.Info("restarts externally stopped")
			return exitCode, nil
		}
	}
}

// startChild runs exe once with s.args and waits for it to finish.
//
// The child inherits the parent's environment plus "<envName()>=yes" and the
// parent's stdin. Its stdout goes to s.stdout when one is configured. Its
// stderr always goes to the given writer, and is also copied to s.stderr when
// one is configured. While the child runs, every signal delivered to the
// parent is relayed to it.
//
// The result is the child's exit code. The result is 0 when the child exits
// cleanly, or when it exits with an error after the parent received
// SIGINT/SIGTERM/SIGKILL. A non-nil error means the child could not be
// started, or its termination could not be turned into an exit code.
func (s *Saver) startChild(exe string, stderr io.Writer) (int, error) {
	child := exec.Command(exe, s.args...)
	child.Env = append(os.Environ(), fmt.Sprintf("%s=yes", s.envName()))
	child.SysProcAttr = sysProcAttr()
	child.Stdin = os.Stdin
	if s.stdout != nil {
		child.Stdout = s.stdout
	}
	child.Stderr = stderr
	if s.stderr != nil {
		child.Stderr = io.MultiWriter(s.stderr, stderr)
	}

	if err := child.Start(); err != nil {
		return 1, err
	}

	// Notify with no signal list relays every incoming signal.
	relay := make(chan os.Signal, 1)
	signal.Notify(relay)
	defer func() {
		// Stop first so nothing is sent after close; closing ends the relay loop.
		signal.Stop(relay)
		close(relay)
	}()

	// interrupted records that the parent itself was asked to terminate, so
	// the child's resulting failure is not treated as a crash.
	var interrupted atomic.Bool
	go func() {
		for sig := range relay {
			// NOTE(review): SIGKILL normally cannot be caught, so that comparison
			// is likely never true — confirm before relying on it.
			if sig == syscall.SIGINT || sig == syscall.SIGTERM || sig == syscall.SIGKILL {
				interrupted.Store(true)
			}
			_ = child.Process.Signal(sig)
		}
	}()

	waitErr := child.Wait()
	if waitErr == nil || interrupted.Load() {
		return 0, nil
	}

	var exitErr *exec.ExitError
	if errors.As(waitErr, &exitErr) {
		// The child ran and exited with a non-zero status.
		if status, ok := exitErr.Sys().(syscall.WaitStatus); ok {
			return status.ExitStatus(), nil
		}
	}

	return 1, fmt.Errorf("unable to wait child %q: %w", child, waitErr)
}

// cleanUp prunes saved crash reports for this Saver's name. Crash reports are
// the entries in outputDir whose names start with "crash.<name>.".
//
// When HistoryLimit or more reports exist, the first HistoryLimit/2 of them
// in directory-listing order are removed (at least one). StartChild names
// reports "crash.<name>.<timestamp>.stderr" with a fixed-width timestamp, so
// filename order should match age. Removal is best-effort: failures are
// ignored. Only a failure to list the directory is returned.
//
// NOTE(review): a Saver whose name begins with "<name>." shares this prefix,
// so its reports would also be counted and pruned here.
func (s *Saver) cleanUp() error {
	entries, err := os.ReadDir(s.outputDir)
	if err != nil {
		return fmt.Errorf("unable to list output directory: %w", err)
	}

	prefix := fmt.Sprintf("crash.%s.", s.name)
	var reports []string
	for _, entry := range entries {
		if name := entry.Name(); strings.HasPrefix(name, prefix) {
			reports = append(reports, name)
		}
	}

	if len(reports) < HistoryLimit {
		return nil
	}

	quota := HistoryLimit / 2
	for i, name := range reports {
		_ = os.RemoveAll(filepath.Join(s.outputDir, name))
		if i+1 >= quota {
			break
		}
	}

	return nil
}

// envName returns the environment variable that marks a process as a child of
// this Saver. The variable is the upper-cased name, with dashes turned into
// underscores, followed by "_UNDER_CRASH_SAVER".
func (s *Saver) envName() string {
	base := strings.ReplaceAll(s.name, "-", "_")
	return strings.ToUpper(base) + "_UNDER_CRASH_SAVER"
}
