package ussh

import (
	"a.yandex-team.ru/solomon/libs/go/hosts"
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"os"
	"os/exec"
	"strings"
)

// YubikeyAction is invoked when pssh prompts for a yubikey touch while issuing
// a new session certificate. Returning a non-nil error aborts processing of
// pssh's stderr. Presumably it performs or confirms the touch automatically —
// TODO confirm against callers.
type YubikeyAction = func() error

// psshClient is a minimal ssh client that runs remote commands by shelling out
// to the `pssh` command-line tool (one pssh process per command).
type psshClient struct {
	// yubikeyAction, if non-nil, is called when pssh asks for a yubikey touch.
	// If nil, the prompt is echoed to os.Stderr so the user can touch the key.
	yubikeyAction YubikeyAction
}

// psshConnection targets a single host. It keeps no open session: every Run
// spawns a separate pssh process for that host.
type psshConnection struct {
	// client is the owning client; its stderr/yubikey handling is reused by Run.
	client  *psshClient
	// address identifies the target host; address.Name is passed to pssh.
	address *hosts.Address
}

// NewPSSHClient returns a Client whose connections execute commands through the
// `pssh` CLI. yubikeyAction may be nil; in that case yubikey prompts from pssh
// are echoed to os.Stderr for the user instead of being handled automatically.
func NewPSSHClient(yubikeyAction YubikeyAction) Client {
	client := new(psshClient)
	client.yubikeyAction = yubikeyAction
	return client
}

// Connect returns a Connection bound to address. Nothing is dialed here — each
// Run call starts its own pssh process — so the returned error is always nil.
func (pssh *psshClient) Connect(address *hosts.Address) (Connection, error) {
	conn := &psshConnection{
		address: address,
		client:  pssh,
	}
	return conn, nil
}

// Close is a no-op: psshClient holds nothing but an optional callback, and every
// command runs in its own short-lived pssh process.
func (pssh *psshClient) Close() {
}

// Run executes cmd on the connection's host by invoking
// `pssh run --format json <cmd> <host>` and decoding the JSON object pssh
// prints on stdout.
//
// On success it returns the remote stdout and stderr (whitespace-trimmed)
// formatted as "stdout:\n<stdout>\nstderr:\n<stderr>". An error is returned when
// pssh cannot be started, its JSON response cannot be decoded, its own stderr
// contains error lines (see readStderr), pssh exits unsuccessfully, or the
// remote command reports a non-zero exit_status.
//
// Whatever the outcome, both pipes are drained and the pssh process is waited
// for, so neither the helper goroutine nor the child process outlives the call.
func (c psshConnection) Run(cmd string) ([]byte, error) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	command := exec.CommandContext(
		ctx,
		"pssh",
		"run",
		"--format",
		"json",
		cmd,
		c.address.Name,
	)

	stdout, err := command.StdoutPipe()
	if err != nil {
		return nil, fmt.Errorf("can't get stdout: %w", err)
	}

	stderr, err := command.StderrPipe()
	if err != nil {
		return nil, fmt.Errorf("can't get stderr: %w", err)
	}

	if err := command.Start(); err != nil {
		return nil, fmt.Errorf("can't start pssh: %w", err)
	}

	// Buffered so the goroutine can always deliver its result and exit, even
	// though the receive below only happens after stdout has been consumed.
	done := make(chan error, 1)
	go func() {
		stderrErr := c.client.readStderr(stderr)
		// readStderr may stop before EOF (e.g. a failed yubikey action); keep
		// draining so pssh never blocks writing to a full stderr pipe.
		_, _ = io.Copy(io.Discard, stderr)
		done <- stderrErr
	}()

	r := &struct {
		Host       string `json:"host"`
		Stdout     string `json:"stdout"`
		Stderr     string `json:"stderr"`
		Error      string `json:"error"`
		ExitStatus int32  `json:"exit_status"`
	}{}

	decodeErr := json.NewDecoder(stdout).Decode(r)
	if decodeErr != nil {
		// The response is unusable: kill pssh instead of waiting for it to finish.
		cancel()
	}

	// Consume anything after the JSON document so pssh cannot block on a full
	// stdout pipe; all pipe reads must complete before Wait closes the pipes.
	_, _ = io.Copy(io.Discard, stdout)
	stderrErr := <-done
	waitErr := command.Wait()

	if decodeErr != nil {
		return nil, fmt.Errorf("can't unmarshal response: %w", decodeErr)
	}
	if stderrErr != nil {
		return nil, fmt.Errorf("command finished with error: %w", stderrErr)
	}
	if waitErr != nil {
		return nil, fmt.Errorf("pssh finished with error: %w", waitErr)
	}

	r.Stdout = strings.TrimSpace(r.Stdout)
	r.Stderr = strings.TrimSpace(r.Stderr)

	if r.ExitStatus != 0 {
		return nil, fmt.Errorf(
			"command finished with error: %v %v",
			r.ExitStatus,
			r.Error,
		)
	}

	var b bytes.Buffer
	b.WriteString("stdout:\n")
	b.WriteString(r.Stdout)
	b.WriteString("\nstderr:\n")
	b.WriteString(r.Stderr)
	return b.Bytes(), nil
}

// Close is a no-op: a psshConnection keeps no open session, since each Run
// starts and waits for its own pssh process.
func (c psshConnection) Close() {
}

// readStdout reads stdout line by line until EOF and returns every non-empty
// line with its trailing "\n" removed. Each collected line is also logged.
//
// If reading fails with anything other than io.EOF, the lines collected so far
// are returned together with the wrapped read error (similar to io.ReadAll),
// rather than silently dropping the error.
func (pssh *psshClient) readStdout(stdout io.Reader) ([]string, error) {
	var out []string

	reader := bufio.NewReader(stdout)

	for {
		line, err := reader.ReadString('\n')
		line = strings.TrimRight(line, "\n")

		// ReadString may return a final fragment together with an error, so
		// record the line before inspecting err.
		if len(line) != 0 {
			out = append(out, line)
			log.Printf("[pssh.readStdout] '%v'\n", line)
		}
		if err == io.EOF {
			return out, nil
		}
		if err != nil {
			return out, fmt.Errorf("pssh.readStdout. ReadString: %w", err)
		}
	}
}

// readStderr consumes pssh's stderr until EOF, handling the interactive
// authentication chatter and collecting error lines.
//
// Each line is whitespace-trimmed and then handled as follows:
//   - a line mentioning the federation-id authentication banner is ignored;
//   - a line starting with "Issuing new session certificate" is expected to be
//     followed by the "Please touch yubikey" prompt, read as exactly
//     len(prompt) raw bytes (presumably because pssh prints it without a
//     newline — TODO confirm). If yubikeyAction is set it is called; otherwise
//     the certificate line and the prompt are echoed to os.Stderr for the user.
//     The following line is then consumed and logged unless it is "OK";
//   - any other line that does not start with "Completed " and contains
//     "ERROR" or "Remote exited without signal" is collected.
//
// It returns the collected lines joined with "\n" as a single error, or a
// wrapped error if reading fails or yubikeyAction returns an error.
func (pssh *psshClient) readStderr(stderr io.Reader) error {
	const (
		federationNotice  = "You are going to be authenticated via federation-id"
		certificateNotice = "Issuing new session certificate"
		prompt            = "Please touch yubikey"
	)

	reader := bufio.NewReader(stderr)

	var errs []string

readLoop:
	for {
		line, err := reader.ReadString('\n')
		line = strings.TrimSpace(line)

		switch {
		case strings.Contains(line, federationNotice):
			// Informational banner, not an error. Falls through to the read-error
			// checks below so EOF on this line is still honoured.

		case strings.HasPrefix(line, certificateNotice):
			buf := make([]byte, len(prompt))
			if _, err := io.ReadFull(reader, buf); err != nil {
				if err == io.EOF {
					break readLoop
				}
				return fmt.Errorf("pssh.readStderr. io.ReadFull: %w", err)
			}

			if pssh.yubikeyAction != nil {
				if err := pssh.yubikeyAction(); err != nil {
					return fmt.Errorf("[pssh.readStderr] yubikeyAction: %w", err)
				}
			} else {
				// No automation configured: show the prompt so the user can touch the key.
				fmt.Fprintln(os.Stderr, line)
				fmt.Fprintln(os.Stderr, prompt)
			}

			ack, ackErr := reader.ReadString('\n')
			if ackErr != nil && ackErr != io.EOF {
				return fmt.Errorf("pssh.readStderr. can't read OK: %w", ackErr)
			}

			if ack = strings.TrimSpace(ack); ack != "OK" {
				log.Printf("[pssh.readStderr] %v\n", ack)
			}

			if ackErr == io.EOF {
				break readLoop
			}
			continue

		case !strings.HasPrefix(line, "Completed ") &&
			(strings.Contains(line, "ERROR") ||
				strings.Contains(line, "Remote exited without signal")):
			errs = append(errs, line)
		}

		if err == io.EOF {
			break readLoop
		}

		if err != nil {
			return fmt.Errorf("pssh.readStderr. ReadString: %w", err)
		}
	}

	if len(errs) != 0 {
		return errors.New(strings.Join(errs, "\n"))
	}

	return nil
}
