package main

import (
	"bytes"
	"crypto/rand"
	"errors"
	"fmt"
	"io"
	"log/syslog"
	"net"
	"os"
	"runtime"
	"strings"

	"github.com/stripe/krl"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
	"google.golang.org/protobuf/proto"

	"a.yandex-team.ru/security/skotty/skotty/pkg/psudo"
)

// defaultCAPath is the file of trusted sudo CA public keys (parsed in
// authorized_keys format) used when no ca_file= module option is given.
var defaultCAPath = "/etc/ssh/trusted_sudo_ca"

// defaultKRLPath is the key revocation list consulted when no krl_file=
// module option is given.
var defaultKRLPath = "/etc/ssh/revoked_keys"

// SudoRsp is the reply obtained from the skotty agent's "skotty-sudo"
// extension (assembled by callAgent) and checked by checkSign.
type SudoRsp struct {
	Nonce     []byte           // random challenge that was sent to the agent
	Cert      *ssh.Certificate // user certificate returned by the agent
	Signature *ssh.Signature   // agent's signature over Nonce, checked against Cert
}

// pamLog formats a message with fmt.Sprintf and writes it to syslog at
// warning priority on the auth facility, tagged "skotty-sudo".
//
// Logging is best-effort: if the syslog daemon cannot be reached, or the
// write fails, the message is silently dropped.
func pamLog(format string, args ...interface{}) {
	l, err := syslog.New(syslog.LOG_AUTH|syslog.LOG_WARNING, "skotty-sudo")
	if err != nil {
		return
	}
	// syslog.New opens a new connection on every call; close it so each log
	// line does not leak a file descriptor in the host process.
	defer func() { _ = l.Close() }()

	_ = l.Warning(fmt.Sprintf(format, args...))
}

// userCall runs fn with the effective UID temporarily switched to uid, and
// switches it back afterwards.
//
// The switch only happens when the process runs with elevated rights (real
// UID differs from the effective UID, or the effective UID is root);
// otherwise fn is simply called. Only the effective UID is changed; the
// effective GID is left untouched, which should be fine for most cases but is
// worth calling out.
//
// It returns an error if dropping privileges fails (fn is then not called),
// otherwise fn's error, or an error if restoring the original effective UID
// fails after fn returned successfully.
func userCall(uid int, fn func() error) (err error) {
	origEUID := os.Geteuid()
	if os.Getuid() == origEUID && origEUID != 0 {
		return fn()
	}

	if !seteuid(uid) {
		pamLog("error dropping privs from %d to %d", origEUID, uid)
		return fmt.Errorf("can't drop privileges from uid %d to uid %d", origEUID, uid)
	}
	defer func() {
		if !seteuid(origEUID) {
			pamLog("error resetting uid to %d", origEUID)
			// Fail closed: the caller must not continue as if the original
			// effective UID had been restored.
			if err == nil {
				err = fmt.Errorf("can't restore effective uid %d", origEUID)
			}
		}
	}()

	return fn()
}

// agentSock connects to the ssh-agent Unix socket named by $SSH_AUTH_SOCK.
// The dial is performed through userCall, i.e. with the effective UID
// switched to uid, so the socket is opened with uid's permissions.
//
// The caller owns the returned connection and must close it.
func agentSock(uid int) (net.Conn, error) {
	authSock := os.Getenv("SSH_AUTH_SOCK")
	if authSock == "" {
		return nil, errors.New("no env[SSH_AUTH_SOCK]")
	}

	var conn net.Conn
	err := userCall(uid, func() error {
		var dialErr error
		conn, dialErr = net.Dial("unix", authSock)
		return dialErr
	})
	if err != nil {
		// userCall may fail after the dial itself succeeded (e.g. while
		// restoring privileges); don't leak the socket in that case.
		if conn != nil {
			_ = conn.Close()
		}
		return nil, fmt.Errorf("dial ssh-agent fail: %w", err)
	}

	return conn, nil
}

// callAgent asks the agent behind sock to sign a fresh random nonce through
// the "skotty-sudo" ssh-agent extension. The request also carries this
// machine's hostname.
//
// It returns the nonce together with the user certificate and signature from
// the agent's reply. Nothing is verified here beyond basic well-formedness;
// the caller is expected to validate the result (see checkSign).
func callAgent(sock net.Conn) (*SudoRsp, error) {
	// nonceSize is the length in bytes of the random challenge to be signed.
	const nonceSize = 32

	hostname, err := os.Hostname()
	if err != nil {
		return nil, fmt.Errorf("can't get hostname: %w", err)
	}

	req := &psudo.SudoReq{
		Hostname: hostname,
		Nonce:    make([]byte, nonceSize),
	}
	// io.ReadFull either fills the whole nonce or returns an error, so the
	// nonce is always exactly nonceSize bytes past this point.
	if _, err := io.ReadFull(rand.Reader, req.Nonce); err != nil {
		return nil, fmt.Errorf("failed to generate nonce: %w", err)
	}

	reqBytes, err := proto.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("can't marshal skotty sudo request: %w", err)
	}

	skottyAgent := agent.NewClient(sock)
	_, _ = os.Stderr.WriteString("Touch yubikey to authenticate...\n")

	rspBytes, err := skottyAgent.Extension("skotty-sudo", reqBytes)
	if err != nil {
		return nil, fmt.Errorf("can't call skotty sudo extension: %w", err)
	}

	var rsp psudo.SudoRsp
	if err := proto.Unmarshal(rspBytes, &rsp); err != nil {
		return nil, fmt.Errorf("failed to read skotty response: %w", err)
	}

	if len(rsp.PubKey) == 0 {
		return nil, errors.New("no pubkey")
	}

	// Signature is a nested message and may be nil when the agent omits it;
	// dereferencing it unchecked would panic inside the host process.
	if rsp.Signature == nil {
		return nil, errors.New("no signature")
	}

	userPubKey, err := ssh.ParsePublicKey(rsp.PubKey)
	if err != nil {
		return nil, fmt.Errorf("failed to parse user pub key: %w", err)
	}

	userCert, ok := userPubKey.(*ssh.Certificate)
	if !ok {
		return nil, errors.New("invalid user pub key")
	}

	return &SudoRsp{
		Nonce: req.Nonce,
		Cert:  userCert,
		Signature: &ssh.Signature{
			Format: rsp.Signature.Format,
			Blob:   rsp.Signature.Blob,
			Rest:   rsp.Signature.Rest,
		},
	}, nil
}

// checkSign verifies a skotty sudo response for username.
//
// It loads the trusted CA public keys from caPath (authorized_keys format)
// and requires that signedRsp.Cert:
//   - is a user certificate,
//   - is signed by one of those CA keys,
//   - names at least one principal,
//   - passes ssh.CertChecker.CheckCert for username (principal list,
//     validity window, critical options and CA signature),
//
// and that signedRsp.Signature is a valid signature over signedRsp.Nonce made
// with the certificate's key. Any failure is returned as an error.
func checkSign(signedRsp *SudoRsp, username, caPath string) error {
	if signedRsp == nil || signedRsp.Cert == nil || signedRsp.Signature == nil {
		return errors.New("incomplete sudo response")
	}

	caBytes, err := os.ReadFile(caPath)
	if err != nil {
		return fmt.Errorf("failed to read ca: %w", err)
	}

	// ssh.ParseAuthorizedKey skips invalid lines on its own, so an error here
	// only means no further key was found in the remaining input.
	var caPubs []ssh.PublicKey
	for in := caBytes; len(in) > 0; {
		pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(in)
		if err != nil {
			break
		}
		caPubs = append(caPubs, pubKey)
		in = rest
	}
	if len(caPubs) == 0 {
		return fmt.Errorf("no trusted CA keys found in %s", caPath)
	}

	checker := &ssh.CertChecker{
		IsUserAuthority: func(auth ssh.PublicKey) bool {
			authBytes := auth.Marshal()
			for _, k := range caPubs {
				if bytes.Equal(authBytes, k.Marshal()) {
					return true
				}
			}
			return false
		},
	}

	cert := signedRsp.Cert
	if cert.CertType != ssh.UserCert {
		return fmt.Errorf("cert has unsupported type %d", cert.CertType)
	}

	// CheckCert does not consult IsUserAuthority, so check the CA explicitly.
	if !checker.IsUserAuthority(cert.SignatureKey) {
		return errors.New("certificate signed by unrecognized authority")
	}

	// ssh.CertChecker treats a certificate without principals as valid for
	// every user, which would allow sudo as anyone. OpenSSH refuses such user
	// certificates for TrustedUserCAKeys; do the same.
	if len(cert.ValidPrincipals) == 0 {
		return errors.New("certificate has no principals")
	}

	if err := checker.CheckCert(username, cert); err != nil {
		return fmt.Errorf("user certificate verification failed: %w", err)
	}

	if err := cert.Verify(signedRsp.Nonce, signedRsp.Signature); err != nil {
		return fmt.Errorf("signature verification failed: %w", err)
	}

	return nil
}

// checkRevoked returns an error when the certificate in signedRsp is listed
// as revoked in the KRL stored at krlPath.
//
// The check fails open: if the KRL file cannot be read or parsed, the problem
// is logged via pamLog and nil is returned.
// NOTE(review): OpenSSH's RevokedKeys refuses all keys when the file is
// unreadable; confirm that fail-open is intended here.
func checkRevoked(signedRsp *SudoRsp, krlPath string) error {
	raw, readErr := os.ReadFile(krlPath)
	if readErr != nil {
		pamLog("can't check KRL: read krl file: %v", readErr)
		return nil
	}

	revocations, parseErr := krl.ParseKRL(raw)
	switch {
	case parseErr != nil:
		pamLog("can't check KRL: invalid krl: %v", parseErr)
		return nil
	case revocations.IsRevoked(signedRsp.Cert):
		fingerprint := ssh.FingerprintSHA256(signedRsp.Cert)
		return fmt.Errorf("certificate %s revoked by file %s", fingerprint, krlPath)
	default:
		return nil
	}
}

// authenticate performs skotty sudo authentication for the PAM user in ph.
//
// It connects to the user's ssh-agent (dialled as ph.UID()), obtains a signed
// nonce and certificate from it, and verifies them against the CA keys in
// caPath for ph.User(). When krlPath is non-empty, it also checks the
// certificate against the KRL at krlPath. A nil return means authentication
// succeeded; the success is logged to syslog.
func authenticate(ph *PamHandle, caPath, krlPath string) error {
	sock, err := agentSock(ph.UID())
	if err != nil {
		return err
	}
	defer func() { _ = sock.Close() }()

	rsp, err := callAgent(sock)
	if err != nil {
		return err
	}

	if err := checkSign(rsp, ph.User(), caPath); err != nil {
		return err
	}

	// An empty krlPath (e.g. "krl_file=" with no value) skips the revocation check.
	if krlPath != "" {
		if err := checkRevoked(rsp, krlPath); err != nil {
			return err
		}
	}

	// ssh.CertChecker accepts certificates without principals, so don't
	// index ValidPrincipals blindly: an out-of-range panic would take down
	// the host process after a successful authentication.
	principal := ""
	if len(rsp.Cert.ValidPrincipals) > 0 {
		principal = rsp.Cert.ValidPrincipals[0]
	}
	pamLog("Authentication succeeded for %q (cert %q, %d)", ph.User(), principal, rsp.Cert.Serial)
	return nil
}

// pamAuthenticate parses the module options in argv and then runs
// authenticate for ph.
//
// Each option has the form key=value. Only the first '=' separates key from
// value, so values may themselves contain '='. Supported keys:
//
//	ca_file   path to the trusted CA public keys (default defaultCAPath)
//	krl_file  path to the key revocation list (default defaultKRLPath);
//	          an empty value disables the revocation check
//
// Options without '=' and unknown keys are logged and ignored.
func pamAuthenticate(ph *PamHandle, argv []string) error {
	// NOTE(review): presumably meant to limit OS threads running Go code while
	// the effective UID is switched; GOMAXPROCS does not pin goroutines to a
	// thread, so confirm the intent.
	runtime.GOMAXPROCS(1)

	caPath := defaultCAPath
	krlPath := defaultKRLPath
	for _, arg := range argv {
		opt := strings.SplitN(arg, "=", 2)
		if len(opt) != 2 {
			pamLog("invalid option: %s", arg)
			continue
		}

		key, value := opt[0], opt[1]
		switch key {
		case "ca_file":
			caPath = value
			pamLog("ca_file set to: %s", caPath)
		case "krl_file":
			krlPath = value
			pamLog("krl_file set to: %s", krlPath)
		default:
			pamLog("unsupported option: %s", key)
		}
	}

	return authenticate(ph, caPath, krlPath)
}

// main is intentionally empty.
// NOTE(review): presumably this package is built as a shared object (e.g.
// -buildmode=c-shared) and loaded by PAM. That build mode requires package
// main with a main function, which is never run; confirm the build mode.
func main() {}
