package listener

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"path/filepath"
	"sync"
	"time"

	"golang.org/x/sync/errgroup"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/nop"
	"a.yandex-team.ru/security/skotty/wsl2proxy/internal/pipeproxy"
)

// Listener exposes unix sockets and forwards every connection accepted on
// them, through a pipe proxy session, to the destination configured for that
// socket.
type Listener struct {
	// Configuration, set up by NewListener.
	proxy      *pipeproxy.PipeProxy
	log        log.Logger
	onListenFn func(spec Spec)

	// Sockets bound by ListenAndServe and the accept loops serving them.
	listeners []net.Listener
	listenWg  sync.WaitGroup

	// mu serializes access to shared state, notably forwards: the proxy
	// sessions cached per destination address.
	mu       sync.Mutex
	forwards map[string]*pipeproxy.Session
}

// NewListener builds a Listener configured by opts.
//
// Recognised options set the logger, the proxy binary (with its arguments) and
// a callback that ListenAndServe invokes once every socket is bound. Any other
// Option implementation is rejected with an error.
//
// When no proxy binary is given, the path of the current executable with an
// ".exe" suffix is used, invoked with the single argument "forward".
func NewListener(opts ...Option) (*Listener, error) {
	out := &Listener{
		log:        &nop.Logger{},
		forwards:   make(map[string]*pipeproxy.Session),
		onListenFn: func(_ Spec) {},
	}

	var proxyBinary string
	var proxyArgs []string
	for _, opt := range opts {
		switch v := opt.(type) {
		case loggerOption:
			out.log = v.log
		case proxyBinaryOption:
			proxyBinary, proxyArgs = v.binary, v.args
		case onListenOption:
			out.onListenFn = v.onListenFn
		default:
			return nil, fmt.Errorf("unsupported option: %T", v)
		}
	}

	if proxyBinary == "" {
		// Fall back to a sibling binary named after the current executable
		// plus ".exe".
		// NOTE(review): presumably the Windows build of this tool, reachable
		// through WSL interop — confirm.
		selfPath, err := os.Executable()
		if err != nil {
			return nil, fmt.Errorf("unable to determine current executable: %w", err)
		}

		proxyBinary = selfPath + ".exe"
		proxyArgs = []string{"forward"}
	}

	out.proxy = pipeproxy.NewPipeProxy(proxyBinary, proxyArgs...)
	return out, nil
}

// ListenAndServe binds a unix socket for every pair in spec.Pairs. It forwards
// connections accepted on pair.Src to pair.Dst through the pipe proxy. The
// onListen callback fires once every socket is bound.
//
// The call blocks until all accept loops have exited. An accept loop exits
// when its listener is closed, which Close does.
//
// The whole spec is validated before any socket is bound. An invalid pair
// therefore never leaves sockets from earlier pairs listening behind an error
// return. Pairs must have non-empty, distinct Src and Dst values. No two pairs
// may share a Src, because binding the second would remove the first one's
// socket file.
//
// NOTE(review): cancelling ctx does not stop the accept loops by itself. ctx
// is only handed down to the proxy sessions started for accepted connections.
// If binding a later pair fails, sockets already bound for earlier pairs keep
// serving until Close.
func (l *Listener) ListenAndServe(ctx context.Context, spec Spec) error {
	seenSrc := make(map[string]struct{}, len(spec.Pairs))
	for i, pair := range spec.Pairs {
		if pair.Src == "" {
			return fmt.Errorf("empty src for %d pair in spec", i)
		}

		if pair.Dst == "" {
			return fmt.Errorf("empty dst for %d pair in spec", i)
		}

		if pair.Src == pair.Dst {
			return fmt.Errorf("invalid sockets pair: %s (src) == %s (dst)", pair.Src, pair.Dst)
		}

		if _, dup := seenSrc[pair.Src]; dup {
			return fmt.Errorf("duplicate src %s for %d pair in spec", pair.Src, i)
		}
		seenSrc[pair.Src] = struct{}{}
	}

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	g, ctx := errgroup.WithContext(ctx)
	for _, pair := range spec.Pairs {
		fn, err := l.listenFn(ctx, pair)
		if err != nil {
			return err
		}

		g.Go(fn)
	}

	l.onListenFn(spec)
	return g.Wait()
}

// listenFn binds a unix socket at pair.Src and returns the accept loop to run
// for it. The loop forwards each accepted connection to pair.Dst over a proxy
// stream obtained from getProxy. It returns nil once the listener has been
// closed, and removes the socket file on the way out.
//
// Other Accept failures (for example, running out of file descriptors) are
// retried with an exponential backoff capped at one second. Without it, a
// persistent error would turn the loop into a busy spin that floods the log.
func (l *Listener) listenFn(ctx context.Context, pair ListenPair) (func() error, error) {
	const (
		minAcceptBackoff = 5 * time.Millisecond
		maxAcceptBackoff = time.Second
	)

	// Drop a stale socket left behind by a previous run. A missing path is
	// the common case, hence the ignored error.
	// NOTE(review): this removes whatever exists at pair.Src, not only sockets.
	_ = os.Remove(pair.Src)
	if err := os.MkdirAll(filepath.Dir(pair.Src), 0o700); err != nil {
		return nil, fmt.Errorf("failed to create unix socket folder: %w", err)
	}

	udsListener, err := net.Listen("unix", pair.Src)
	if err != nil {
		return nil, fmt.Errorf("unable to listen on %q: %w", pair.Src, err)
	}

	logger := log.With(l.log, log.String("src", pair.Src), log.String("dst", pair.Dst))

	// Close may read the slice from another goroutine, so guard the append.
	l.mu.Lock()
	l.listeners = append(l.listeners, udsListener)
	l.mu.Unlock()

	l.listenWg.Add(1)
	return func() error {
		defer func() {
			_ = os.Remove(pair.Src)
			l.listenWg.Done()
		}()

		var backoff time.Duration
		for {
			conn, err := udsListener.Accept()
			if err != nil {
				if errors.Is(err, net.ErrClosed) {
					logger.Info("connection closed")
					return nil
				}

				if backoff == 0 {
					backoff = minAcceptBackoff
				} else {
					backoff *= 2
				}
				if backoff > maxAcceptBackoff {
					backoff = maxAcceptBackoff
				}

				logger.Warn("unable to accept incoming connection", log.Error(err))
				time.Sleep(backoff)
				continue
			}
			backoff = 0

			proxy, err := l.getProxy(ctx, pair.Dst)
			if err != nil {
				logger.Warn("unable to create proxy", log.Error(err))
				_ = conn.Close()
				continue
			}

			go func() {
				if err := pipeproxy.Proxy(proxy, conn); err != nil {
					logger.Warn("proxy failed", log.Error(err))
				}
			}()
		}
	}, nil
}

// getProxy opens a new stream to addr.
//
// Proxy sessions are cached per address in l.forwards. A live session is
// reused. A dead one is closed, evicted and replaced by a session freshly
// started with ctx.
//
// The whole lookup-or-start sequence runs under l.mu, so concurrent callers
// cannot start two sessions for the same address.
func (l *Listener) getProxy(ctx context.Context, addr string) (io.ReadWriteCloser, error) {
	l.mu.Lock()
	defer l.mu.Unlock()

	if cached, found := l.forwards[addr]; found {
		if cached.IsAlive() {
			return cached.OpenStream()
		}

		// The cached session is no longer alive: release it before
		// starting a replacement.
		_ = cached.Close()
		delete(l.forwards, addr)
	}

	started, err := l.proxy.Start(ctx, addr)
	if err != nil {
		return nil, fmt.Errorf("unable to start proxy: %w", err)
	}

	l.forwards[addr] = started
	return started.OpenStream()
}

// Close stops every accept loop started by ListenAndServe. It then waits for
// those loops to exit, which also removes their socket files, and finally
// closes all cached proxy sessions. Close may be called more than once.
func (l *Listener) Close() {
	// Take a snapshot under the lock, but do not hold the lock while waiting.
	// Accept loops may be blocked in getProxy, which needs l.mu to make
	// progress.
	l.mu.Lock()
	listeners := l.listeners
	l.listeners = nil
	l.mu.Unlock()

	for _, ln := range listeners {
		_ = ln.Close()
	}

	// Once the listeners are closed, each accept loop returns and signals
	// listenWg. After this point no loop can start new proxy sessions.
	l.listenWg.Wait()

	l.mu.Lock()
	defer l.mu.Unlock()
	for addr, sess := range l.forwards {
		_ = sess.Close()
		delete(l.forwards, addr)
	}
}
