package sshrequest

import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
	"os/exec"
	"strings"

	"golang.org/x/crypto/ssh"
)

// Fallback values used by RemoteRequest's accessors when the corresponding
// field is left at its zero value.
const (
	// DefaultPort is the port GetPort reports when Port is zero.
	DefaultPort = 22
	// DefaultHost is the host GetHost reports when Hostname is empty.
	DefaultHost = "localhost"
)

// RemoteRequest describes a command line and the host/port it targets.
type RemoteRequest struct {
	Hostname string // target host; empty means DefaultHost (see GetHost)
	Port     int    // port dialed for SSH; zero means DefaultPort (see GetPort)
	Command  string // command line to execute
}

// NewRemoteRequest returns a RemoteRequest that targets hostname and runs cmd.
// Port is left at zero, so GetPort reports DefaultPort.
func NewRemoteRequest(hostname, cmd string) RemoteRequest {
	return RemoteRequest{Hostname: hostname, Command: cmd}
}

// NewRemouteRequest is the original, misspelled name of NewRemoteRequest and
// behaves identically.
//
// Deprecated: use NewRemoteRequest instead.
func NewRemouteRequest(hostname, cmd string) RemoteRequest {
	return NewRemoteRequest(hostname, cmd)
}

// GetPort reports the port stored in the request, or DefaultPort when Port
// is unset (zero). Any non-zero value, including a negative one, is returned
// as is.
func (rr RemoteRequest) GetPort() int {
	if port := rr.Port; port != 0 {
		return port
	}
	return DefaultPort
}

// GetHost reports the target host, or DefaultHost when Hostname is empty.
func (rr RemoteRequest) GetHost() string {
	if host := rr.Hostname; host != "" {
		return host
	}
	return DefaultHost
}

// GetCmd reports the command line exactly as stored in the request.
func (rr RemoteRequest) GetCmd() (cmd string) {
	cmd = rr.Command
	return cmd
}

// prepareCommand splits command into an argument vector for exec.Command
// using a small subset of POSIX shell word splitting:
//
//   - runs of unquoted whitespace (space, tab, CR, LF) separate arguments;
//   - text between single quotes is kept literally, whitespace included;
//   - quoted and unquoted text not separated by whitespace form a single
//     argument, so --opt='a b' yields "--opt=a b";
//   - an empty quoted string ('') yields an empty argument.
//
// An unterminated quote extends to the end of command. Double quotes and
// backslash escapes are not interpreted. It returns nil when command holds
// no arguments.
func prepareCommand(command string) (result []string) {
	var (
		word    strings.Builder
		inQuote bool // currently inside a '...' section
		inWord  bool // a word has started, possibly still empty (after '')
	)
	flush := func() {
		if inWord {
			result = append(result, word.String())
			word.Reset()
			inWord = false
		}
	}
	for _, r := range command {
		switch {
		case r == '\'':
			inQuote = !inQuote
			// Even '' starts a word, so an empty quoted argument survives.
			inWord = true
		case !inQuote && strings.ContainsRune(" \t\r\n", r):
			flush()
		default:
			word.WriteRune(r)
			inWord = true
		}
	}
	flush()
	return result
}

// RemoteExecute runs request's command and returns what it wrote to stdout.
//
// When the target host (request.GetHost(), so an empty Hostname counts too)
// is exactly "localhost", the command is tokenized with prepareCommand and
// executed directly on this machine through os/exec, without a shell. That
// output is returned unmodified.
//
// Otherwise an SSH connection is opened to host:port using the client
// configuration from PrepareSSHConnect, a pseudo-terminal is requested, and
// the command runs in a remote session. Every "\r" is removed from its output
// and trailing newlines are trimmed.
//
// A non-nil error is returned if the command is empty, cannot be started,
// its output cannot be read, or it exits unsuccessfully. For a failed exit,
// the output collected so far is returned alongside the error. That error
// wraps *exec.ExitError locally or the ssh session's Wait error remotely.
func RemoteExecute(request RemoteRequest) (output []byte, err error) {
	if request.GetHost() == "localhost" {
		return runLocalCommand(request.GetCmd())
	}
	return runSSHCommand(request)
}

// runLocalCommand tokenizes command and executes it on this machine.
func runLocalCommand(command string) ([]byte, error) {
	argv := prepareCommand(command)
	if len(argv) == 0 {
		return nil, fmt.Errorf("empty command %q", command)
	}
	return exec.Command(argv[0], argv[1:]...).Output()
}

// runSSHCommand executes request's command over SSH inside a pseudo-terminal
// and returns its normalized stdout.
func runSSHCommand(request RemoteRequest) ([]byte, error) {
	address := fmt.Sprintf("%s:%d", request.GetHost(), request.GetPort())
	client, err := ssh.Dial("tcp", address, PrepareSSHConnect())
	if err != nil {
		return nil, err
	}
	// Close errors are ignored: by the time these run the result (or a more
	// relevant error) has already been determined.
	defer func() { _ = client.Close() }()

	session, err := client.NewSession()
	if err != nil {
		return nil, err
	}
	defer func() { _ = session.Close() }()

	modes := ssh.TerminalModes{
		ssh.ECHO:          0, // disable echo so input is not mixed into the output
		ssh.TTY_OP_ISPEED: 14400,
		ssh.TTY_OP_OSPEED: 14400,
	}
	if err := session.RequestPty("xterm", 80, 40, modes); err != nil {
		return nil, fmt.Errorf("request for pseudo terminal failed: %w", err)
	}

	var stdout io.Reader
	if stdout, err = session.StdoutPipe(); err != nil {
		return nil, fmt.Errorf("stdoutPipe error: %w", err)
	}

	if err := session.Start(request.GetCmd()); err != nil {
		return nil, fmt.Errorf("failed to start shell: %w", err)
	}

	output, err := ioutil.ReadAll(stdout)
	if err != nil {
		return nil, fmt.Errorf("reading output of %q: %w", request.GetCmd(), err)
	}
	// Wait reaps the remote command; a non-zero exit status surfaces here.
	waitErr := session.Wait()

	// PTY output typically uses "\r\n" line endings. Strip every "\r" first so
	// the trim below removes all trailing blank lines, not just the last "\n".
	output = bytes.ReplaceAll(output, []byte("\r"), []byte(""))
	output = bytes.TrimRight(output, "\n")

	if waitErr != nil {
		return output, fmt.Errorf("remote command %q: %w", request.GetCmd(), waitErr)
	}
	return output, nil
}

// Wrap prints err to standard output with a "[wrapper] error: " prefix.
// A nil err prints nothing, so Wrap can be called unconditionally on the
// result of an operation.
func Wrap(err error) {
	if err == nil {
		return
	}
	fmt.Printf("[wrapper] error: %s\n", err)
}
