package main

import (
	_ "embed"
	"errors"
	"fmt"
	"log"
	"net"
	"os"

	"golang.org/x/crypto/ssh"

	"a.yandex-team.ru/security/skotty/libs/skottyca"
)

// hostKey holds the contents of the ssh_host_rsa_key file from this package's
// directory, embedded into the binary at build time. main parses it with
// ssh.ParsePrivateKey and registers it as the server's host key. Replace the
// file with your own freshly generated key before reusing this example.
//
//go:embed ssh_host_rsa_key
var hostKey []byte

// fatalf formats a printf-style message, appends a newline, writes it to
// standard error and terminates the process with exit status 1.
// Write errors are deliberately ignored: the process is exiting anyway.
// Deferred functions are not run because os.Exit is used.
func fatalf(format string, args ...interface{}) {
	message := fmt.Sprintf(format+"\n", args...)
	_, _ = os.Stderr.WriteString(message)
	os.Exit(1)
}

func main() {
	ca, err := skottyca.NewCA(skottyca.WithKind(skottyca.CAKindSecure, skottyca.CAKindInsecure))
	if err != nil {
		fatalf("unable to create skotty ca: :%v", err)
	}

	sshConfig := &ssh.ServerConfig{
		ServerVersion: "SSH-2.0-SkottyIntegrationExample",
		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
			// Самая важная часть
			checker := ssh.CertChecker{
				// Верим только нужным CA
				IsUserAuthority: ca.IsUserAuthority,
				// С учетом отозванности сертификатов
				IsRevoked: ca.IsRevoked,
				// В случае, если к нам пришли не с сертификатом - можно замутить какой-то фоллбэк в обычный ключи
				UserKeyFallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
					return nil, errors.New("only SSH certificates are supported")
				},
			}

			// И аутентифицируем пользователя по сертификату
			return checker.Authenticate(conn, key)
		},
	}

	// Не забудьте сгенерить свои ключики :)
	private, err := ssh.ParsePrivateKey(hostKey)
	if err != nil {
		fatalf("failed to parse private key: %v", err)
	}

	sshConfig.AddHostKey(private)

	listener, err := net.Listen("tcp", ":2022")
	if err != nil {
		fatalf("failed to listen on *:2022: %v", err)
	}

	// Accept all connections
	log.Printf("listening on %s", ":2022")
	for {
		tcpConn, err := listener.Accept()
		if err != nil {
			log.Printf("failed to accept incoming connection (%s)", err)
			continue
		}

		// Before use, a handshake must be performed on the incoming net.Conn.
		sshConn, chans, reqs, err := ssh.NewServerConn(tcpConn, sshConfig)
		if err != nil {
			log.Printf("failed to handshake (%s)", err)
			continue
		}

		// Check remote address
		log.Printf("new ssh connection from %s (%s)", sshConn.RemoteAddr(), sshConn.ClientVersion())

		go func(reqs <-chan *ssh.Request) {
			// Print incoming out-of-band Requests
			for req := range reqs {
				log.Printf("recieved out-of-band request: %+v\n", req)
			}
		}(reqs)

		// Accept all channels
		go func(chans <-chan ssh.NewChannel) {
			// Service the incoming Channel channel.
			for newChannel := range chans {
				if t := newChannel.ChannelType(); t != "session" {
					_ = newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
					continue
				}

				conn, requests, err := newChannel.Accept()
				if err != nil {
					log.Printf("could not accept channel: %v\n", err)
					continue
				}

				handleShell := func(req *ssh.Request) error {
					defer func() { _ = conn.Close() }()

					if err := req.Reply(true, nil); err != nil {
						return err
					}

					if _, err := conn.Write([]byte("good boy\r\n\r\n")); err != nil {
						return fmt.Errorf("failed to write response: %w", err)
					}

					return nil
				}

				for req := range requests {
					log.Printf("new request %s\n", req.Type)
					var err error
					switch req.Type {
					case "shell", "exec":
						err = handleShell(req)
					}

					if err != nil {
						log.Printf("failed to handle client request: %v\n", err)
						_ = req.Reply(false, nil)
						continue
					}

					_ = req.Reply(true, nil)
				}
			}
		}(chans)
	}
}
