package supervisor

import (
	"context"
	"errors"
	"fmt"
	"net"

	"golang.org/x/sys/unix"
	"google.golang.org/grpc/credentials"
)

// Compile-time assertion that *unixSocketCredentials satisfies
// credentials.TransportCredentials.
var _ credentials.TransportCredentials = (*unixSocketCredentials)(nil)

// unixSocketCredentials is a transport credentials implementation that
// authorizes unix-socket peers by comparing their kernel-reported uid/gid
// with the expected values.
type unixSocketCredentials struct {
	// uid is the user ID the peer must have; -1 disables the uid check.
	uid int
	// gid is the group ID the peer must have; -1 disables the gid check.
	gid int
	// serverName is the name reported by Info and set by OverrideServerName.
	serverName string
}

// newUnixSocketCredentials returns credentials that accept only peers whose
// uid and gid match the given values. Passing -1 for either value skips the
// corresponding check. The server name defaults to "localhost".
func newUnixSocketCredentials(uid, gid int) *unixSocketCredentials {
	return &unixSocketCredentials{
		uid:        uid,
		gid:        gid,
		serverName: "localhost",
	}
}

// ClientHandshake always fails: these credentials implement the server side
// only, so no client-side handshake is performed on the connection.
func (u *unixSocketCredentials) ClientHandshake(_ context.Context, _ string, _ net.Conn) (net.Conn, credentials.AuthInfo, error) {
	unsupported := errors.New("ClientHandshake is not supported by unixSocketCredentials")
	return nil, nil, unsupported
}

// ServerHandshake authorizes an incoming connection by reading the peer
// credentials of the socket (SO_PEERCRED) and comparing them with the
// configured uid and gid. A configured value of -1 disables the respective
// check.
//
// It returns the connection unchanged together with u as the AuthInfo on
// success. It returns an error when c is not a *net.UnixConn, when the peer
// credentials cannot be read, or when they do not match.
func (u *unixSocketCredentials) ServerHandshake(c net.Conn) (net.Conn, credentials.AuthInfo, error) {
	uc, ok := c.(*net.UnixConn)
	if !ok {
		return nil, nil, errors.New("unixSocketCredentials only supports unix socket")
	}

	// Go through the raw connection instead of uc.File(): File() hands back a
	// duplicated descriptor that the caller must close, and leaving it open
	// leaks one file descriptor per handshake.
	raw, err := uc.SyscallConn()
	if err != nil {
		return nil, nil, fmt.Errorf("unixSocketCredentials: failed to retrieve connection underlying fd: %w", err)
	}

	var (
		pcred   *unix.Ucred
		credErr error
	)
	// Control runs the callback with the live descriptor; the getsockopt
	// result is carried out through the captured variables.
	ctrlErr := raw.Control(func(fd uintptr) {
		pcred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
	})
	if ctrlErr != nil {
		return nil, nil, fmt.Errorf("unixSocketCredentials: failed to retrieve connection underlying fd: %w", ctrlErr)
	}
	if credErr != nil {
		return nil, nil, fmt.Errorf("unixSocketCredentials: failed to retrieve socket peer credentials: %w", credErr)
	}

	if (u.uid != -1 && uint32(u.uid) != pcred.Uid) || (u.gid != -1 && uint32(u.gid) != pcred.Gid) {
		return nil, nil, errors.New("unixSocketCredentials: invalid credentials")
	}

	return c, u, nil
}

// Info describes these credentials: the security protocol is the fixed
// string "unix-socket-peer-creds" and the server name is the currently
// configured one.
func (u *unixSocketCredentials) Info() credentials.ProtocolInfo {
	var info credentials.ProtocolInfo
	info.SecurityProtocol = "unix-socket-peer-creds"
	info.ServerName = u.serverName
	return info
}

// Clone returns an independent copy of the credentials carrying the same
// uid, gid and server name.
func (u *unixSocketCredentials) Clone() credentials.TransportCredentials {
	dup := *u // value copy duplicates every field
	return &dup
}

// OverrideServerName replaces the stored server name, which is the value
// subsequently reported by Info. It never fails and always returns nil.
func (u *unixSocketCredentials) OverrideServerName(serverName string) error {
	u.serverName = serverName
	return nil
}

// AuthType returns the name of the authentication mechanism. Having this
// method lets the credentials value itself be returned as the AuthInfo from
// ServerHandshake.
func (u *unixSocketCredentials) AuthType() string {
	const authType = "unix-socket-peer-creds"
	return authType
}
