package main

import (
	"bytes"
	"context"
	"fmt"
	"os"
	"os/signal"
	"path/filepath"
	"strconv"
	"strings"
	"syscall"
	"time"

	"github.com/shirou/gopsutil/v3/process"
	"github.com/spf13/cobra"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/security/libs/go/daemon"
	"a.yandex-team.ru/security/libs/go/ioatomic"
	"a.yandex-team.ru/security/skotty/wsl2proxy/internal/config"
	"a.yandex-team.ru/security/skotty/wsl2proxy/internal/listener"
	"a.yandex-team.ru/security/skotty/wsl2proxy/internal/logger"
	"a.yandex-team.ru/security/skotty/wsl2proxy/internal/shellutil"
	"a.yandex-team.ru/security/skotty/wsl2proxy/internal/version"
)

const (
	// daemonizeTimeout is the timeout of the context handed to
	// daemon.StartChild when the parent spawns the background child.
	daemonizeTimeout = time.Minute

	// shutdownDeadline is how long doServe waits for ListenAndServe to
	// return after SIGINT/SIGTERM before giving up.
	shutdownDeadline = time.Minute
)

// startArgs holds the parsed flag values of the "start" command.
var startArgs struct {
	Config     string // --config / -c: config path passed to config.Load
	Foreground bool   // --foreground / -f: run in the current process instead of daemonizing
}

// init registers the persistent flags of the "start" command and binds them
// to the startArgs fields.
func init() {
	startCmd.PersistentFlags().StringVarP(&startArgs.Config, "config", "c", "", "config path")
	startCmd.PersistentFlags().BoolVarP(&startArgs.Foreground, "foreground", "f", false, "starts in the foreground mode")
}

// startCmd implements "wsl2proxy start".
//
// Flow:
//   - If the pid file already names an existing process, the listeners spec
//     (shellutil.Listeners) is written to stdout and the command returns.
//   - Without --foreground, the invoking process is the daemon parent: it
//     spawns a child via startChild and, once that returns successfully,
//     writes the listeners spec to stdout and exits.
//   - The daemon child (or the process itself in --foreground mode) writes
//     the pid file, starts serving via doServe and blocks until serving ends.
//     The child reports success or failure to its parent through
//     d.NotifyStarted / d.NotifyError.
var startCmd = &cobra.Command{
	Use:          "start",
	SilenceUsage: true,
	Short:        "starts wsl2proxy",
	RunE: func(_ *cobra.Command, _ []string) error {
		cfg, err := config.Load(startArgs.Config)
		if err != nil {
			return fmt.Errorf("unable to load config: %w", err)
		}

		// MkdirAll is a no-op when the directory already exists.
		if err := os.MkdirAll(cfg.RuntimeDir, 0700); err != nil {
			return fmt.Errorf("unable to create runtime dir: %w", err)
		}

		if err := logger.InitLogger(cfg.LogLevel, cfg.LogPath); err != nil {
			return fmt.Errorf("setup logger: %w", err)
		}

		spec := cfg.ListenersSpec()
		pidFile := proxyPidPath(cfg.RuntimeDir)

		// In foreground mode and in the daemon parent, readiness is reported
		// by printing the listeners spec, and errors are only returned to
		// cobra. The daemon child replaces both callbacks below.
		notifyStart := func() {
			_, _ = os.Stdout.Write(shellutil.Listeners(spec))
		}
		notifyError := func(error) {}

		if checkPID(pidFile) {
			// We are already started, nothing to do.
			// NOTE(review): this check also runs in the daemon child before it
			// knows it is the child. A child that takes this path writes to its
			// own stdout instead of notifying the parent. Confirm that
			// d.StartChild copes with that (e.g. through daemonizeTimeout).
			notifyStart()
			return nil
		}

		if !startArgs.Foreground {
			d := daemon.NewDaemon("wsl2proxy")
			if d.IsParent() {
				if err := startChild(d); err != nil {
					return err
				}

				notifyStart()
				return nil
			}

			// From here on we are the daemon child: report to the parent.
			notifyStart = func() {
				if err := d.NotifyStarted(nil); err != nil {
					logger.Error("failed to notify parent", log.Error(err))
				}
			}

			notifyError = func(err error) {
				if nerr := d.NotifyError(err); nerr != nil {
					logger.Error("failed to notify parent", log.Error(nerr))
				}
			}
		}

		pid := strconv.Itoa(os.Getpid())
		if err := ioatomic.WriteFile(pidFile, strings.NewReader(pid), 0644); err != nil {
			err = fmt.Errorf("unable to create pid file: %w", err)
			// Tell the daemon parent (if any) about the failure instead of
			// leaving it waiting for a child that will never report back.
			notifyError(err)
			return err
		}

		defer func() { _ = os.Remove(pidFile) }()

		waitCh, err := doServe(spec)
		if err != nil {
			notifyError(err)
			return err
		}

		logger.Info("proxy started", log.String("version", version.Full()))
		notifyStart()

		// Block until doServe reports that serving is over.
		<-waitCh
		return nil
	},
}

// startChild spawns the background child through d and logs its PID.
//
// The spawn runs under a context that expires after daemonizeTimeout.
// Presumably d.StartChild blocks until the child calls NotifyStarted or
// NotifyError (TODO confirm against the daemon library).
func startChild(d *daemon.Daemon) error {
	spawnCtx, stop := context.WithTimeout(context.Background(), daemonizeTimeout)
	defer stop()

	proc, err := d.StartChild(spawnCtx)
	if err != nil {
		return fmt.Errorf("start child: %w", err)
	}

	pid := proc.PID()
	logger.Info("child started", log.Int("pid", pid))
	return nil
}

// doServe starts serving spec on a new listener. It returns once every socket
// is being listened, i.e. once the listener's OnListen callback has fired.
//
// The returned channel is closed when serving is over. That happens either
// because ListenAndServe returned on its own, or because SIGINT/SIGTERM was
// received. In the signal case the listener is shut down and doServe waits at
// most shutdownDeadline for ListenAndServe to return.
//
// An error is returned if the listener cannot be created, or if
// ListenAndServe returns before OnListen was called.
func doServe(spec listener.Spec) (<-chan struct{}, error) {
	// listenCh is closed by the listener once all sockets are being listened.
	listenCh := make(chan struct{})
	lis, err := listener.NewListener(
		listener.WithLogger(logger.L),
		listener.WithOnListen(func(_ listener.Spec) {
			close(listenCh)
		}),
	)
	if err != nil {
		return nil, fmt.Errorf("unable to create listener: %w", err)
	}

	ctx, cancel := context.WithCancel(context.Background())

	shutdown := func() {
		cancel()
		lis.Close()
	}

	// serveErr is written before serveCh is closed, so reading it after a
	// receive from serveCh is race-free.
	var serveErr error
	serveCh := make(chan struct{})
	go func() {
		defer close(serveCh)

		if serveErr = lis.ListenAndServe(ctx, spec); serveErr != nil {
			logger.Error("listen failed", log.Error(serveErr))
		}
	}()

	// Wait for the sockets to be ready. Don't block forever if
	// ListenAndServe gives up before OnListen is ever called (presumably
	// when a socket cannot be opened).
	select {
	case <-listenCh:
	case <-serveCh:
		shutdown()
		if serveErr != nil {
			return nil, fmt.Errorf("listen: %w", serveErr)
		}
		return nil, fmt.Errorf("listener stopped before all sockets were listened")
	}
	logger.Info("all sockets are being listened")

	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	waitCh := make(chan struct{})
	go func() {
		// Deferred calls run in reverse order. The context is released and
		// signal delivery is stopped before waitCh is closed, on every path.
		defer close(waitCh)
		defer signal.Stop(sigChan)
		defer cancel()

		select {
		case <-sigChan:
			logger.Info("received TERM signal, will shutting down")
			shutdown()

			select {
			case <-time.After(shutdownDeadline):
				logger.Warn("can't stop in shutdown deadline")
			case <-serveCh:
			}
		case <-serveCh:
		}
	}()

	return waitCh, nil
}

// checkPID reports whether pidFile holds the PID of a currently existing
// process. It returns false in these cases:
//   - the file is missing or unreadable;
//   - the content is not a decimal integer;
//   - the value does not fit in an int32 or is not positive;
//   - the existence check fails.
//
// Only the existence of the PID is checked. It does not verify that the
// process is actually wsl2proxy, so a stale pid file whose PID has been reused
// by another process is reported as running.
func checkPID(pidFile string) bool {
	rawPid, err := os.ReadFile(pidFile)
	if err != nil {
		return false
	}

	// Parse with an explicit 32-bit size because process.PidExists takes an
	// int32. A plain int32(...) conversion of an Atoi result would silently
	// wrap an out-of-range value into some unrelated PID.
	pid, err := strconv.ParseInt(string(bytes.TrimSpace(rawPid)), 10, 32)
	if err != nil || pid <= 0 {
		return false
	}

	// Best effort: a failed lookup is treated as "not running".
	ok, _ := process.PidExists(int32(pid))
	return ok
}

// proxyPidPath returns the path of the proxy pid file ("proxy.pid") inside
// runtimeDir.
func proxyPidPath(runtimeDir string) string {
	const pidFileName = "proxy.pid"
	return filepath.Join(runtimeDir, pidFileName)
}
