package sshutil

import (
	"bytes"
	"fmt"
	"os"
	"path/filepath"
	"strconv"

	"a.yandex-team.ru/security/skotty/skotty/pkg/osutil/sysenv"
)

// SSHAgentScript returns shell commands that point SSH clients at the skotty
// agent listening on socketPath. The syntax is chosen from the basename of
// $SHELL: fish, csh/tcsh, or POSIX sh for anything else (including an unset
// $SHELL). pid is passed through to the shell-specific generator.
//
// An empty socketPath yields a lone newline, i.e. a no-op script.
func SSHAgentScript(socketPath string, pid int) []byte {
	if socketPath == "" {
		return []byte{'\n'}
	}

	switch filepath.Base(os.Getenv("SHELL")) {
	case "fish":
		return fishAgentScript(socketPath, pid)
	case "csh", "tcsh":
		// tcsh is csh-compatible: it has no `export`, only `setenv`, so it
		// must not get the Bourne-style script.
		return cshAgentScript(socketPath, pid)
	default:
		return bourneAgentScript(socketPath, pid)
	}
}

// bourneAgentScript renders the agent environment as POSIX sh commands
// (assignment followed by `export`), suitable for `eval "$(...)"`.
// SSH_AGENT_PID is emitted only when pid is positive.
func bourneAgentScript(socketPath string, pid int) []byte {
	out := bytes.NewBuffer(nil)
	out.WriteString("SSH_AGENT_LAUNCHER='skotty'; export SSH_AGENT_LAUNCHER;\n")
	out.WriteString("SSH_AUTH_SOCK=")
	// Quote for sh, not with strconv.Quote: Go escapes (\t, \xNN, \uNNNN) are
	// not decoded by the shell, and an unescaped $ or ` would be expanded.
	writeShDoubleQuoted(out, socketPath)
	out.WriteString("; export SSH_AUTH_SOCK;\n")

	if pid > 0 {
		out.WriteString("SSH_AGENT_PID=")
		out.WriteString(strconv.Itoa(pid))
		out.WriteString("; export SSH_AGENT_PID;\n")
	}

	return out.Bytes()
}

// writeShDoubleQuoted writes s to out as a POSIX sh double-quoted word.
// Inside "...", only \ " $ and ` keep a special meaning, so only those bytes
// are backslash-escaped; every other byte (spaces, tabs, newlines, non-ASCII)
// is copied verbatim and taken literally by the shell. For plain ASCII paths
// this produces the same text strconv.Quote would.
func writeShDoubleQuoted(out *bytes.Buffer, s string) {
	out.WriteByte('"')
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c == '\\' || c == '"' || c == '$' || c == '`' {
			out.WriteByte('\\')
		}
		out.WriteByte(c)
	}
	out.WriteByte('"')
}

// fishAgentScript renders the agent environment as fish `set -gx` commands.
// SSH_AGENT_PID is emitted only when pid is positive.
func fishAgentScript(socketPath string, pid int) []byte {
	out := bytes.NewBuffer(nil)
	out.WriteString("set -gx SSH_AGENT_LAUNCHER 'skotty';\n")
	out.WriteString("set -gx SSH_AUTH_SOCK ")
	// Quote for fish, not with strconv.Quote: fish does not decode Go escapes
	// and expands $ inside double quotes.
	writeFishDoubleQuoted(out, socketPath)
	out.WriteString(";\n")

	if pid > 0 {
		out.WriteString("set -gx SSH_AGENT_PID ")
		out.WriteString(strconv.Itoa(pid))
		out.WriteString(";\n")
	}

	return out.Bytes()
}

// writeFishDoubleQuoted writes s to out as a fish double-quoted string.
// Inside fish double quotes only \ " and $ are special. A backslash before
// any other character is kept literally, so no other byte may be escaped;
// everything else is copied verbatim.
func writeFishDoubleQuoted(out *bytes.Buffer, s string) {
	out.WriteByte('"')
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c == '\\' || c == '"' || c == '$' {
			out.WriteByte('\\')
		}
		out.WriteByte(c)
	}
	out.WriteByte('"')
}

// cshAgentScript renders the agent environment as csh/tcsh `setenv`
// commands. SSH_AGENT_PID is appended only for a positive pid.
//
// NOTE(review): the socket path is quoted with strconv.Quote (Go string
// syntax). csh does not decode Go escapes, and it still expands `$` inside
// double quotes, so unusual paths may not round-trip. Confirm whether
// csh-specific quoting is needed.
func cshAgentScript(socketPath string, pid int) []byte {
	var buf bytes.Buffer
	buf.WriteString("setenv SSH_AGENT_LAUNCHER 'skotty';\n")
	buf.WriteString("setenv SSH_AUTH_SOCK " + strconv.Quote(socketPath) + ";\n")
	if pid <= 0 {
		return buf.Bytes()
	}

	buf.WriteString("setenv SSH_AGENT_PID " + strconv.Itoa(pid) + ";\n")
	return buf.Bytes()
}

// ReplaceAuthSock makes the path named by the current SSH_AUTH_SOCK point at
// socketPath, by swapping it for a symlink to socketPath.
//
// It reports whether a replacement happened. It returns (false, nil), doing
// nothing, when:
//   - SSH_AUTH_SOCK is unset or empty;
//   - the path does not exist;
//   - the path is neither a socket nor a symlink;
//   - the path is already a symlink whose target is exactly socketPath.
//
// Sockets are always replaced.
//
// The swap creates a sibling temporary symlink "<name>.skotty_<pid>" and
// renames it over the original path. On POSIX systems rename(2) replaces the
// destination atomically.
func ReplaceAuthSock(socketPath string) (bool, error) {
	curPath := os.Getenv("SSH_AUTH_SOCK")
	if curPath == "" {
		return false, nil
	}

	fi, err := os.Lstat(curPath)
	if err != nil {
		if os.IsNotExist(err) {
			return false, nil
		}
		return false, err
	}

	switch mode := fi.Mode(); {
	case mode&os.ModeSymlink != 0:
		// Skip the swap if the symlink already points at our socket.
		dst, err := os.Readlink(curPath)
		if err != nil {
			return false, err
		}
		if dst == socketPath {
			return false, nil
		}
	case mode&os.ModeSocket != 0:
		// A real socket (e.g. another agent's): always take it over.
	default:
		// Regular files, directories etc. are left untouched.
		return false, nil
	}

	dir, sockName := filepath.Split(curPath)
	tmpPath := filepath.Join(dir, fmt.Sprintf("%s.skotty_%d", sockName, os.Getpid()))

	// An earlier run with the same pid may have died between Symlink and
	// Rename. Clear its leftover so Symlink does not fail with EEXIST.
	if err := os.Remove(tmpPath); err != nil && !os.IsNotExist(err) {
		return false, fmt.Errorf("remove stale temporary symlink: %w", err)
	}
	if err := os.Symlink(socketPath, tmpPath); err != nil {
		return false, fmt.Errorf("create symlink: %w", err)
	}

	if err := os.Rename(tmpPath, curPath); err != nil {
		// Best effort: don't leave the temporary symlink next to the socket.
		// The rename error is the one worth reporting.
		_ = os.Remove(tmpPath)
		return false, fmt.Errorf("replace socket: %w", err)
	}

	return true, nil
}

// ExportAuthSock hands SSH_AGENT_LAUNCHER ("skotty"), SSH_AUTH_SOCK
// (socketPath) and SSH_AGENT_PID (pid in decimal) to sysenv.Export and
// returns its error.
//
// NOTE(review): unlike the shell-script generators in this file, this exports
// SSH_AGENT_PID even when pid <= 0 (as "0" or a negative number). Confirm this
// is intended.
func ExportAuthSock(socketPath string, pid int) error {
	vars := make(map[string]string, 3)
	vars["SSH_AGENT_LAUNCHER"] = "skotty"
	vars["SSH_AUTH_SOCK"] = socketPath
	vars["SSH_AGENT_PID"] = strconv.Itoa(pid)
	return sysenv.Export(vars)
}
