package main

import (
	"errors"
	"fmt"
	"io"
	"os"
	"strings"

	"github.com/spf13/cobra"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/security/skotty/wsl2proxy/internal/logger"
	"a.yandex-team.ru/security/skotty/wsl2proxy/internal/pipeproxy"
)

// forwardCmd bridges this process's stdin/stdout to a pipeproxy session and
// forwards every stream accepted on that session to the address given as the
// single positional argument (each stream is handled by forwardConn).
//
// The command returns nil when the session reports io.EOF. It returns an error
// if logger setup or session creation fails, or if accepting streams fails too
// many times in a row.
var forwardCmd = &cobra.Command{
	Use:          "forward pipe-to-forward",
	SilenceUsage: true,
	Short:        "forward stdin/stdout to pipe",
	RunE: func(_ *cobra.Command, args []string) error {
		// maxAcceptFailures bounds how many AcceptStream errors in a row are
		// tolerated. A broken session can keep failing immediately; retrying
		// forever without a bound would spin the CPU and flood the log.
		const maxAcceptFailures = 100

		if len(args) != 1 {
			return errors.New("must provide pipe to forward")
		}

		targetAddr := args[0]

		if err := logger.InitLogger(log.InfoLevel, "stderr"); err != nil {
			return fmt.Errorf("setup logger: %w", err)
		}

		// The session runs over this process's stdin (read side) and stdout (write side).
		conn := &pipeproxy.PipeConn{
			Reader: os.Stdin,
			Writer: os.Stdout,
		}
		// Best-effort cleanup on exit; a close error is not actionable here.
		defer func() { _ = conn.Close() }()

		session, err := pipeproxy.NewSession(conn)
		if err != nil {
			return fmt.Errorf("unable to create pipe session: %w", err)
		}

		failures := 0
		for {
			stream, err := session.AcceptStream()
			if err != nil {
				// io.EOF is treated as a normal end of the session
				// (presumably stdin was closed by the peer).
				if errors.Is(err, io.EOF) {
					logger.Warn("stream closed")
					return nil
				}

				failures++
				if failures >= maxAcceptFailures {
					return fmt.Errorf("unable to accept stream after %d consecutive failures: %w", failures, err)
				}

				logger.Warn("unable to accept stream", log.Error(err))
				continue
			}
			// A successful accept means the session is healthy again.
			failures = 0

			// Each stream is served in its own goroutine; forwardConn closes it.
			go forwardConn(stream, targetAddr)
		}
	},
}

// forwardConn dials addr and proxies data between stream and the dialed
// connection until pipeproxy.Proxy returns. Both the stream and the dialed
// connection are closed on return.
//
// addr is split at its first ':' into a scheme and the remaining address, and
// both parts are passed to dialAddr. If addr contains no ':', the whole value
// is used as the scheme and the address part is empty.
//
// Failures are logged rather than returned because this function is run in
// its own goroutine per accepted stream.
func forwardConn(stream io.ReadWriteCloser, addr string) {
	// Best-effort cleanup; a close error is not actionable here.
	defer func() { _ = stream.Close() }()

	scheme, target, _ := strings.Cut(addr, ":")
	conn, err := dialAddr(scheme, target)
	if err != nil {
		// Log the requested address so the failure can be traced back to the CLI argument.
		logger.Error("unable to dial pipe", log.String("addr", addr), log.Error(err))
		return
	}
	defer func() { _ = conn.Close() }()

	if err := pipeproxy.Proxy(stream, conn); err != nil {
		logger.Warn("proxy failed", log.String("addr", addr), log.Error(err))
	}
}
