package agent

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"os"

	"golang.org/x/crypto/ssh/agent"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/security/skotty/libs/netutil"
	"a.yandex-team.ru/security/skotty/libs/sshagent"
	"a.yandex-team.ru/security/skotty/robossh/internal/keystore"
	"a.yandex-team.ru/security/skotty/robossh/internal/logger"
)

// Agent is an SSH agent that serves keys to clients connecting over a unix
// socket (see ListenAndServe).
type Agent struct {
	// handler answers SSH agent protocol requests for every client.
	handler agent.ExtendedAgent

	// ctx is cancelled through cancelCtx (called by Stop) to ask the
	// serving loop started by ListenAndServe to shut down.
	ctx       context.Context
	cancelCtx context.CancelFunc

	// stopped is closed by ListenAndServe; Stop waits on it.
	stopped chan struct{}
}

// Creds identifies a client process of the agent.
//
// NOTE(review): judging by the field names these are the peer's user ID,
// process ID and command name; how they are populated is not visible here.
type Creds struct {
	UID, PID int
	Comm     string
}

// Conn is a network connection paired with the credentials of the client on
// the other end. All net.Conn methods are promoted from the embedded value.
//
// NOTE(review): nothing in this file constructs a Conn; presumably Creds is
// filled from the peer's socket credentials by the code that does — confirm.
type Conn struct {
	net.Conn

	// Creds describes the connected client.
	Creds Creds
}

// NewAgent returns an Agent whose request handler serves the keys in keys.
//
// Options are applied to the handler in order. Only lifetime options are
// recognized; any other Option value is ignored. The returned Agent owns a
// cancellable context which Stop cancels.
func NewAgent(keys *keystore.Store, opts ...Option) *Agent {
	handler := &Handler{keys: keys}
	for _, opt := range opts {
		if lt, ok := opt.(lifetimeOption); ok {
			handler.lifetime = lt.lifetime
		}
	}

	a := &Agent{
		handler: handler,
		stopped: make(chan struct{}),
	}
	a.ctx, a.cancelCtx = context.WithCancel(context.Background())
	return a
}

// ListenAndServe creates a unix socket listener at socketPath and serves SSH
// agent clients on it in the background; each accepted connection is handled
// by serveConn on its own goroutine.
//
// The returned channel is closed once the agent has shut down: after Stop
// cancels the agent context, the listener is closed and socketPath is
// removed, and only then are the returned channel and the agent's internal
// stopped channel (which Stop waits on) closed.
//
// If the listener cannot be created an error is returned and the agent is
// marked as stopped immediately, so a later Stop does not block.
//
// ListenAndServe must be called at most once per Agent, since it closes the
// agent's stopped channel.
func (a *Agent) ListenAndServe(socketPath string) (<-chan struct{}, error) {
	ln, err := newListener(socketPath)
	if err != nil {
		// Nothing is running; report the agent as already stopped so Stop
		// does not wait for its context to expire.
		close(a.stopped)
		return nil, fmt.Errorf("create listener: %w", err)
	}

	conns := make(chan net.Conn)
	go func() {
		for {
			c, err := ln.Accept()
			switch {
			case err == nil:
				// Hand the connection to the serving loop, unless it is
				// shutting down and will never receive it; blocking here
				// would leak this goroutine and the connection.
				select {
				case conns <- c:
				case <-a.ctx.Done():
					_ = c.Close()
					return
				}
			case errors.Is(err, net.ErrClosed), errors.Is(err, io.ErrClosedPipe):
				return
			default:
				logger.Warn("could not accept connection to agent", log.Error(err))
			}
		}
	}()

	waitCh := make(chan struct{})
	go func() {
		// Deferred closes run after the cleanup below, so whoever waits on
		// either channel observes a fully torn-down listener.
		defer close(waitCh)
		defer close(a.stopped)

		for {
			select {
			case <-a.ctx.Done():
				_ = ln.Close()
				_ = os.RemoveAll(socketPath)
				return
			case c := <-conns:
				go a.serveConn(c)
			}
		}
	}()

	return waitCh, nil
}

// serveConn handles a single agent client connection and always closes it
// before returning.
//
// The peer's unix socket credentials are fetched and vetted with checkCreds;
// if either step fails the connection is dropped without being served.
// Otherwise SSH agent requests are answered through a logging wrapper around
// the agent handler until the client goes away. A clean EOF is not logged;
// any other termination error is. A panic on this goroutine is recovered
// and logged instead of crashing the process.
func (a *Agent) serveConn(conn net.Conn) {
	defer func() {
		if rv := recover(); rv != nil {
			logger.Error("panic recovered", log.Any("msg", rv))
		}
	}()
	defer func() { _ = conn.Close() }()

	creds, err := netutil.UnixSocketCreds(conn)
	if err != nil {
		logger.Error("unable to determine client creds", log.Error(err))
		return
	}

	if credsErr := checkCreds(&creds); credsErr != nil {
		logger.Warn(credsErr.Error())
		return
	}

	loggable := &sshagent.LoggableHandler{
		Handler: a.handler,
		Log:     log.With(logger.L, credLogFields(&creds)...),
	}

	serveErr := agent.ServeAgent(loggable, conn)
	if errors.Is(serveErr, io.EOF) {
		// Client hung up cleanly.
		return
	}
	logger.Error("agent client connection ended with error", log.Error(serveErr))
}

// Stop asks the SSH agent to shut down and waits for it.
//
// It cancels the agent's context, which tells the serving loop started by
// ListenAndServe to close the listener and remove the socket file. It then
// blocks until the agent's stopped channel is closed or ctx is done,
// whichever comes first. If ListenAndServe was never called, the stopped
// channel is never closed and Stop returns only once ctx is done. Calling
// Stop more than once is harmless.
func (a *Agent) Stop(ctx context.Context) {
	a.cancelCtx()

	// Bound the wait by the caller's deadline.
	select {
	case <-a.stopped:
	case <-ctx.Done():
	}
}

// newListener creates a unix socket listener reachable at socketPath.
//
// The socket is bound at a temporary path ("<socketPath>.<pid>") and then
// renamed onto socketPath. An existing file at socketPath is therefore
// replaced by the rename instead of being unlinked first.
//
// A leftover file at the temporary path is removed before binding, because
// bind fails on an existing path. Such a file can come from an earlier
// process that had the same PID and did not clean up.
//
// If the rename fails, the listener is closed. Go's unix listeners unlink
// their bound path on Close by default, so the temporary socket file is
// removed too.
func newListener(socketPath string) (net.Listener, error) {
	tmpPath := fmt.Sprintf("%s.%d", socketPath, os.Getpid())

	// Clear a stale leftover so bind does not fail with "address already in use".
	if err := os.Remove(tmpPath); err != nil && !errors.Is(err, os.ErrNotExist) {
		return nil, fmt.Errorf("remove stale socket %q: %w", tmpPath, err)
	}

	ln, err := net.Listen("unix", tmpPath)
	if err != nil {
		return nil, fmt.Errorf("listen: %w", err)
	}

	if err := os.Rename(tmpPath, socketPath); err != nil {
		_ = ln.Close()
		return nil, fmt.Errorf("rename socket %q -> %q: %w", tmpPath, socketPath, err)
	}

	return ln, nil
}
