package askpass

import (
	"bytes"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"syscall"
)

type (
	// programResolver locates the ask-pass executable and returns the name or
	// path that should be executed, or an error if none can be found.
	programResolver func() (string, error)

	// Client prompts the user through an external ask-pass program.
	Client struct {
		pr programResolver // finds the ask-pass program to run
	}
)

// NewClient builds a Client whose ask-pass program is located by
// findAskPassProg. The given options are then applied to the new Client in
// the order they were passed; because they run after the default resolver is
// installed, they observe that default.
func NewClient(opts ...ClientOption) *Client {
	c := new(Client)
	c.pr = findAskPassProg

	for i := range opts {
		opts[i](c)
	}
	return c
}

// Confirm asks the user, through the ask-pass program, to approve msg.
//
// The program is run with the "confirm" hint. Confirm returns true only when
// the program exits with status 0 and its whitespace-trimmed output is either
// empty or "yes" in any letter case. A non-zero exit is treated as a refusal
// and yields (false, nil). An error is returned when SSH_ASKPASS_REQUIRE is
// "never" or when the program cannot be located or executed.
func (c *Client) Confirm(msg string) (bool, error) {
	if c.isAskPassDisabled() {
		return false, errors.New("ask-pass is disabled")
	}

	status, reply, err := c.runAskProgram(msg, "confirm")
	switch {
	case err != nil:
		return false, err
	case status != 0:
		return false, nil
	}

	// Mirror OpenSSH, which accepts an empty response or the word "yes" as
	// affirmative:
	// https://github.com/openssh/openssh-portable/blob/8a0848cdd3b25c049332cd56034186b7853ae754/readpass.c#L212-L219
	affirmative := len(reply) == 0 || bytes.EqualFold(reply, []byte("yes"))
	return affirmative, nil
}

// isAskPassDisabled reports whether the environment forbids using an ask-pass
// program, which is the case exactly when SSH_ASKPASS_REQUIRE is "never".
func (c *Client) isAskPassDisabled() bool {
	mode, set := os.LookupEnv("SSH_ASKPASS_REQUIRE")
	return set && mode == "never"
}

// runAskProgram runs the ask-pass program found by c.pr with msg as its only
// argument. It returns the program's exit status and its standard output with
// leading and trailing whitespace removed.
//
// When hint is non-empty it is passed to the program as SSH_ASKPASS_PROMPT
// (for example "confirm"). The child inherits the current environment and
// gets an empty standard input. Its standard error is captured but not
// reported.
//
// A non-zero exit is not an error. The status is returned with nil output and
// a nil error, so callers can treat it as a refusal. An error is returned,
// with a status of 1, only when the program cannot be located or executed.
func (c *Client) runAskProgram(msg, hint string) (int, []byte, error) {
	prog, err := c.pr()
	if err != nil {
		return 1, nil, fmt.Errorf("unable to find ask-program: %w", err)
	}

	var stdin, stdout, stderr bytes.Buffer
	cmd := exec.Command(prog, msg)
	cmd.Stdin = &stdin
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	cmd.Env = os.Environ()
	if hint != "" {
		// os/exec uses the last value of a duplicated key, so this overrides
		// any SSH_ASKPASS_PROMPT inherited from our own environment.
		cmd.Env = append(cmd.Env, "SSH_ASKPASS_PROMPT="+hint)
	}

	if err := cmd.Run(); err != nil {
		var exitErr *exec.ExitError
		if errors.As(err, &exitErr) {
			// The program ran but exited with a non-zero status.
			if status, ok := exitErr.Sys().(syscall.WaitStatus); ok {
				return status.ExitStatus(), nil, nil
			}
			// Sys() is platform-specific. If it is not a WaitStatus, the
			// program still ran, so report its exit code rather than an
			// execution failure.
			return exitErr.ExitCode(), nil, nil
		}

		// The program could not be started (not found, not executable, ...).
		return 1, nil, fmt.Errorf("failed to execute ssh-askpass %q: %w", cmd, err)
	}

	return 0, bytes.TrimSpace(stdout.Bytes()), nil
}
