package ssh

import (
	"bytes"
	"fmt"
	"log"
	"net"
	"os"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"time"

	"code.justin.tv/dta/rtevent"
	"code.justin.tv/release/courier/pkg/structs"
	"golang.org/x/crypto/ssh"
)

const (
	// DefaultNumDialRetries is the number of dial attempts NewClient makes
	// when options.Retries is less than 1.
	DefaultNumDialRetries = 3
	// DefaultSshDialTimeout is the dial timeout, in seconds, used when
	// options.SshDialTimeout is less than 1.
	DefaultSshDialTimeout = 20
)

var (
	// rtclient publishes deploy events; it is created lazily by
	// SshCommandRunner.RunRemoteHost.
	rtclient *rtevent.Client
	// timestampRe is compiled in init.
	timestampRe *regexp.Regexp

	// Supported SSH client configuration per https://twitchtv.atlassian.net/wiki/display/SEC/Configuring+OpenSSH
	sshConfig = ssh.Config{
		KeyExchanges: []string{
			"curve25519-sha256@libssh.org",
			"ecdh-sha2-nistp521",
			"ecdh-sha2-nistp256",
			"ecdh-sha2-nistp384",
			"diffie-hellman-group-exchange-sha256",
		},
		// Golang only supports a limited set of ciphers, these are the ciphers that meet our requirements
		Ciphers: []string{
			"aes256-ctr",
			"aes192-ctr",
			"aes128-ctr",
		},
		// Golang only supports a limited set of macs, these are the macs that meet our requirements
		MACs: []string{
			"hmac-sha2-256",
		},
	}

	// hostKeys is the list of host key algorithms the client accepts from
	// servers, used as ssh.ClientConfig.HostKeyAlgorithms. One name per line
	// so that misspelled algorithm names (which would silently never match)
	// are easy to spot.
	hostKeys = []string{
		"ssh-ed25519-cert-v01@openssh.com",
		"ssh-rsa-cert-v01@openssh.com",
		"ssh-ed25519",
		"ssh-rsa",
		"ecdsa-sha2-nistp521-cert-v01@openssh.com",
		"ecdsa-sha2-nistp384-cert-v01@openssh.com",
		"ecdsa-sha2-nistp256-cert-v01@openssh.com",
		"ecdsa-sha2-nistp521",
		"ecdsa-sha2-nistp384",
		"ecdsa-sha2-nistp256",
	}
)

// sharedAgent is the package-wide agent handle: NewClient opens it and
// sshClientConfig takes its public-key signers for authentication.
var sharedAgent = new(Agent)

func init() {
	// timestampRe is designed to match the default timestamp output by
	// log.Print* this is then used to strip the leading timestamp from the
	// output of copies of remotely run couriers.
	timestampRe = regexp.MustCompile(`\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2} `)

	sshConfig.SetDefaults()
}

// NewClient dials host:port over TCP and returns an ssh client authenticated
// as user.
//
// It first opens sharedAgent, whose signers are used for public-key
// authentication, then makes up to options.Retries dial attempts
// (DefaultNumDialRetries when options.Retries is less than 1). The first
// successful client is returned; if every attempt fails, the error from the
// last attempt is wrapped in the returned error.
func NewClient(user, host string, port int, options *structs.Options) (*ssh.Client, error) {
	attempts := options.Retries
	if attempts < 1 {
		attempts = DefaultNumDialRetries
	}

	if err := sharedAgent.Open(); err != nil {
		return nil, err
	}
	// Named clientConfig so it does not shadow the package-level sshConfig,
	// which is a different type (ssh.Config).
	clientConfig, err := sshClientConfig(user, options)
	if err != nil {
		return nil, err
	}

	addr := net.JoinHostPort(host, strconv.Itoa(port))
	var lastErr error
	for i := 0; i < attempts; i++ {
		client, dialErr := ssh.Dial("tcp", addr, clientConfig)
		if dialErr == nil {
			return client, nil
		}
		lastErr = dialErr
	}

	return nil, fmt.Errorf("error connecting to ssh: %w", lastErr)
}

// sshClientConfig builds the *ssh.ClientConfig used to dial a host.
//
// An empty user defaults to "deploy". An options.SshDialTimeout below 1
// defaults to DefaultSshDialTimeout; the timeout is interpreted in seconds.
// Authentication uses the signers of sharedAgent, which must already be set up
// (NewClient opens it); otherwise an error is returned.
func sshClientConfig(user string, options *structs.Options) (*ssh.ClientConfig, error) {
	if !sharedAgent.IsSetup() {
		return nil, fmt.Errorf("ssh-agent must be opened prior to call")
	}

	if user == "" {
		user = "deploy"
	}

	timeout := options.SshDialTimeout
	if timeout < 1 {
		timeout = DefaultSshDialTimeout
	}

	// NOTE(review): ssh.InsecureIgnoreHostKey accepts any server host key, so
	// connections are not protected against man-in-the-middle attacks.
	// Consider a known_hosts-based HostKeyCallback.
	return &ssh.ClientConfig{
		Config:            sshConfig,
		User:              user,
		Auth:              []ssh.AuthMethod{ssh.PublicKeysCallback(sharedAgent.Signers)},
		HostKeyAlgorithms: hostKeys,
		HostKeyCallback:   ssh.InsecureIgnoreHostKey(),
		Timeout:           time.Duration(timeout) * time.Second,
	}, nil
}

// NewSession connects to host as user and opens a new session on the
// resulting client. host may carry a ":port" suffix; the port defaults to 22.
//
// NOTE(review): the underlying *ssh.Client is not returned, so callers cannot
// close it. Closing the returned session does not close the TCP connection.
// Use NewClient plus (*ssh.Client).NewSession when the connection must be
// released.
func NewSession(user, host string, options *structs.Options) (*ssh.Session, error) {
	h, p := splitPort(host)
	// Named client, not ssh, so the imported ssh package is not shadowed.
	client, err := NewClient(user, h, p, options)
	if err != nil {
		return nil, err
	}

	session, err := client.NewSession()
	if err != nil {
		// Don't leak the connection when no session could be opened on it.
		client.Close()
		return nil, err
	}
	return session, nil
}

// splitPort splits an optional ":port" suffix off host and returns the host
// and port. If there is no port, or it is not a valid integer, the port
// defaults to 22.
//
// IPv6 literals are supported: "[::1]:2222" yields ("::1", 2222). Both "::1"
// and "[::1]" yield ("::1", 22), so net.JoinHostPort can re-bracket them.
func splitPort(host string) (string, int) {
	const defaultPort = 22

	h, portStr, err := net.SplitHostPort(host)
	if err != nil {
		// No port present (or an unbracketed IPv6 literal): treat the whole
		// value as the host. Drop the brackets from a bare "[ipv6]" so they
		// are not doubled by net.JoinHostPort later.
		if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
			host = host[1 : len(host)-1]
		}
		return host, defaultPort
	}

	p, err := strconv.Atoi(portStr)
	if err != nil {
		return h, defaultPort
	}
	return h, p
}

// RunRemoteCommand runs cmd in session and returns its combined stdout and
// stderr output.
//
// The local proxy settings are forwarded by prefixing the command with
// HTTP_PROXY=... and NO_PROXY=... assignments. For each one, the lowercase
// variable (http_proxy / no_proxy) takes precedence over the uppercase one
// when both are set. Unset values are forwarded as empty.
//
// Each value is single-quoted so that spaces or shell metacharacters in it
// cannot split or alter the remote command line.
//
// NOTE(review): assumes the remote shell is POSIX-compatible (sh/bash), which
// the VAR=value prefix syntax already requires.
func RunRemoteCommand(session *ssh.Session, cmd string) ([]byte, error) {
	// fromEnv returns the lowercase variable if set, else the uppercase one.
	fromEnv := func(upper, lower string) string {
		if v := os.Getenv(lower); v != "" {
			return v
		}
		return os.Getenv(upper)
	}
	// shellQuote wraps s in single quotes, escaping embedded single quotes.
	shellQuote := func(s string) string {
		return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
	}

	proxy := fromEnv("HTTP_PROXY", "http_proxy")
	noProxy := fromEnv("NO_PROXY", "no_proxy")

	return session.CombinedOutput(fmt.Sprintf("HTTP_PROXY=%s NO_PROXY=%s %v", shellQuote(proxy), shellQuote(noProxy), cmd))
}

// MakeCommandString appends flags to cmd as command-line parameters. A nil
// flags map returns cmd unchanged. Each flag is rendered as described on
// convertFlagsToParameterString.
func MakeCommandString(cmd string, flags map[string]*string) string {
	if flags == nil {
		return cmd
	}
	return cmd + convertFlagsToParameterString(flags)
}

// convertFlagsToParameterString renders each flag as " --key" when its value
// is nil, or as " --key \"value\"" (Go %q quoting) otherwise.
//
// Keys are emitted in sorted order. Go randomizes map iteration order, and
// sorting makes the same flags always produce the same command string.
//
// NOTE(review): %q is Go-syntax quoting, not shell quoting (e.g. "$" and
// backticks are not escaped), so values are assumed to be trusted.
func convertFlagsToParameterString(flags map[string]*string) string {
	keys := make([]string, 0, len(flags))
	for k := range flags {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var buffer bytes.Buffer
	for _, k := range keys {
		fmt.Fprintf(&buffer, " --%v", k)
		if v := flags[k]; v != nil {
			fmt.Fprintf(&buffer, " %q", *v)
		}
	}
	return buffer.String()
}

// CleanRemoteOutput trims leading and trailing whitespace from out and splits
// the remainder on "\n". Empty (or all-whitespace) output yields a single
// empty string.
func CleanRemoteOutput(out []byte) []string {
	trimmed := strings.TrimSpace(string(out))
	return strings.Split(trimmed, "\n")
}

// SshCommandRunner runs commands on remote hosts over ssh.
type SshCommandRunner struct{}

// RunRemoteHost takes a cmd and flags and runs it on the specified host via
// ssh as the "deploy" user, logging the combined remote output line by line.
// host may carry a ":port" suffix (default 22). Any error encountered from
// ssh or the command will be returned.
//
// When options.SkadiID is positive, a deploy event describing the outcome is
// published through the package rtevent client after the command finishes.
// Failures to create that client or to send the event are only logged; they
// do not affect the returned error.
func (scr SshCommandRunner) RunRemoteHost(host string, cmd string, options *structs.Options, flags map[string]*string) error {
	const timeFormat = "2006/01/02 15:04:05"

	// err is read by the deferred event publisher below, so every later
	// failure must be assigned to it with "=" rather than shadowed with ":=".
	var err error
	success := false
	cmdString := MakeCommandString(cmd, flags)

	if options.SkadiID > 0 {
		// NOTE(review): rtclient is package-level state initialised here
		// without synchronisation. Confirm RunRemoteHost is never called
		// concurrently, or guard this initialisation with a sync.Once.
		if rtclient == nil {
			if rtclient, err = rtevent.NewPublisherClient(); err != nil {
				log.Printf("Failed to create rtevent publisher client - %v", err)
			}
		}

		if rtclient != nil {
			defer func() {
				sha, _ := options.GetDeployedVersion()
				event := &rtevent.DeployEvent{
					App:         options.Repo,
					Sha:         sha,
					Environment: options.Environment,
					Success:     success,
					Deployer:    "skadi",
					DeployID:    options.SkadiID,
					Desc:        cmdString,
				}
				switch {
				case strings.Contains(cmd, "install"):
					event.Phase = "install-remote"
				case strings.Contains(cmd, "restart"):
					event.Phase = "restart-remote"
				default:
					event.Phase = "unknown-remote"
				}
				if !success && err != nil {
					event.Desc = fmt.Sprintf("%v - %v", err, cmdString)
				}
				if sendErr := rtclient.SendDeployEventOnbehalf(event, host); sendErr != nil {
					log.Printf("Failed to send rtevent - %v, %+v, %+v", sendErr, *event, *options)
				}
			}()
		}
	}

	log.Printf("%s [%v] Running: %v", time.Now().Format(timeFormat), host, cmdString)

	// Dial the client directly instead of going through NewSession so the
	// connection itself can be closed. Closing only the session would leave
	// the TCP connection open after every call.
	h, p := splitPort(host)
	var client *ssh.Client
	if client, err = NewClient("deploy", h, p, options); err != nil {
		log.Printf("%s [%v] %v", time.Now().Format(timeFormat), host, err)
		return err
	}
	defer client.Close()

	var session *ssh.Session
	if session, err = client.NewSession(); err != nil {
		log.Printf("%s [%v] %v", time.Now().Format(timeFormat), host, err)
		return err
	}
	defer session.Close()

	var buf []byte
	buf, err = RunRemoteCommand(session, cmdString)
	for _, line := range CleanRemoteOutput(buf) {
		log.Printf("[%v] %v", host, line)
	}
	if err != nil {
		log.Printf("%s [%v] %v", time.Now().Format(timeFormat), host, err)
		return err
	}

	success = true
	return nil
}
