package agent

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"time"

	"golang.org/x/crypto/ssh/agent"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/library/go/core/log/ctxlog"
	"a.yandex-team.ru/security/skotty/skotty/internal/socket"
)

// listener accepts SSH agent client connections on a single socket of the
// configured kind and serves each connection with handler.
//
// NOTE(review): ctx, closeFn and done are written by ListenAndServe and read by
// Shutdown without any synchronization; Shutdown presumably must only be called
// once ListenAndServe has started — confirm with callers.
type listener struct {
	// handler answers agent protocol requests; serveConn wraps it in a
	// loggableHandler for every connection.
	handler SSHAgent
	// kind selects which socket implementation ListenAndServe creates.
	kind socket.Kind
	// uid is the agent owner's UID; serveConn rejects peers whose UID is
	// neither this value, 0, nor -1.
	uid int
	// addr is the socket address used by the unix, cygwin and pipe kinds.
	addr string
	log  log.Logger
	// ctx is derived from the context given to ListenAndServe and is passed to
	// Accept.
	ctx context.Context
	// closeFn cancels ctx; Shutdown invokes it.
	closeFn context.CancelFunc
	// done is closed when ListenAndServe returns.
	done chan struct{}
}

// ListenAndServe creates a socket of kind l.kind, starts listening on it and
// serves every accepted connection in its own goroutine via serveConn.
// l.addr is passed to the unix, cygwin and pipe socket constructors; the
// Pageant and dummy kinds take no address.
//
// It blocks until the socket reports it was closed or the listener context is
// cancelled, either through ctx or through Shutdown. In that case it returns
// nil. It returns an error for an unsupported kind, or when the socket cannot
// be created or cannot start listening. l.done is closed on return so that
// Shutdown can wait for it.
func (l *listener) ListenAndServe(ctx context.Context) error {
	l.done = make(chan struct{})
	defer close(l.done)

	l.ctx, l.closeFn = context.WithCancel(ctx)
	// Release the derived context even when we exit on our own (e.g. socket
	// creation failure); a later Shutdown calling the cancel func again is harmless.
	defer l.closeFn()

	var sock socket.Socket
	var err error
	switch l.kind {
	case socket.KindUnix:
		sock, err = socket.NewUnixSocket(l.addr)
	case socket.KindCygwin:
		sock, err = socket.NewCygwinSocket(l.addr)
	case socket.KindPipe:
		sock, err = socket.NewPipeSocket(l.addr)
	case socket.KindPageantWindow:
		sock, err = socket.NewPageantWindow()
	case socket.KindPageantPipe:
		sock, err = socket.NewPageantPipe()
	case socket.KindDummy:
		sock = socket.NewDummySocket()
	default:
		return fmt.Errorf("unsupported socket kind: %s", l.kind)
	}

	if err != nil {
		return fmt.Errorf("failed to create socket of kind %q: %w", l.kind, err)
	}

	if err := sock.Listen(); err != nil {
		return fmt.Errorf("listen socket of kind %q failed: %w", l.kind, err)
	}
	defer func() { _ = sock.Close() }()

	l.log.Info("listen", log.String("kind", l.kind.String()), log.String("addr", sock.Addr()))
	defer l.log.Info("agent stopped")

	// temporary matches errors that mark themselves as transient (the net.Error
	// convention); such errors are retried after a short pause.
	type temporary interface {
		Temporary() bool
	}
	const temporaryErrorBackoff = 1 * time.Second

	for {
		c, err := sock.Accept(l.ctx)
		if err != nil {
			if errors.Is(err, net.ErrClosed) || errors.Is(err, io.ErrClosedPipe) || l.ctx.Err() != nil {
				return nil
			}

			// errors.As also finds a temporary error wrapped by the socket layer.
			var t temporary
			if errors.As(err, &t) && t.Temporary() {
				l.log.Warn("temporary accept error", log.Error(err), log.Duration("wait", temporaryErrorBackoff))

				// Back off, but wake up immediately if the listener is shut down.
				timer := time.NewTimer(temporaryErrorBackoff)
				select {
				case <-l.ctx.Done():
					timer.Stop()
					return nil
				case <-timer.C:
				}
				continue
			}

			l.log.Error("failed to accept connection", log.Error(err))
			continue
		}

		go l.serveConn(c)
	}
}

// Shutdown cancels the listener context that ListenAndServe hands to Accept.
// It then waits until ListenAndServe has returned (l.done is closed) or until
// ctx is done, whichever happens first. If ctx ends the wait, ctx.Err() is
// logged. If ListenAndServe has not set up its context yet, Shutdown does
// nothing.
//
// NOTE(review): closeFn and done are read here without synchronization against
// ListenAndServe — confirm callers never race the two.
func (l *listener) Shutdown(ctx context.Context) {
	cancel := l.closeFn
	if cancel == nil {
		return
	}
	cancel()

	select {
	case <-ctx.Done():
		l.log.Error("failed to close listener", log.Error(ctx.Err()))
	case <-l.done:
	}
}

// serveConn runs the SSH agent protocol over one accepted connection and
// closes the connection when done.
//
// A peer is rejected when its UID is not -1, not root (0), and not the agent's
// own UID (l.uid). Otherwise requests go through a loggableHandler whose
// context carries the peer's UID, PID, name and command line.
//
// A panic raised while serving is recovered and logged. Without this, a panic
// in this goroutine would crash the whole process.
func (l *listener) serveConn(conn socket.Conn) {
	// Registered first so it runs last, after conn has been closed.
	defer func() {
		if rv := recover(); rv != nil {
			l.log.Error("panic recovered", log.Any("msg", rv))
		}
	}()

	defer func() { _ = conn.Close() }()

	peer := conn.Peer()
	if peer.UID != -1 && peer.UID != 0 && peer.UID != l.uid {
		l.log.Error("connection from another user (except root) is prohibited",
			log.Int("agent_uid", l.uid),
			log.Any("peer", peer))
		return
	}

	handler := &loggableHandler{
		ctx: ctxlog.WithFields(context.Background(),
			log.Int("peer_uid", peer.UID),
			log.Int("peer_pid", peer.PID),
			log.String("peer_name", peer.Name),
			log.String("peer_cmdline", peer.Cmd),
		),
		log:     l.log,
		handler: l.handler,
	}

	// ServeAgent returns once reading from or writing to conn fails. io.EOF is
	// a normal client disconnect. A nil error is also treated as a clean
	// finish rather than logged as a failure.
	if err := agent.ServeAgent(handler, conn); err != nil && !errors.Is(err, io.EOF) {
		l.log.Error("agent client connection ended with error",
			log.Error(err),
			log.Any("peer", peer))
	}
}
