package ussh

import (
	"a.yandex-team.ru/solomon/libs/go/hosts"
	"fmt"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"
	"log"
	"net"
	"os"
	"strconv"
	"time"
)

// bastionAddr is the cloud bastion host that NewSSHClient dials when asked to
// tunnel connections through a jump host.
const bastionAddr = "bastion.cloud.yandex.net"

// sshPort is the TCP port used for both the bastion and every target host.
const sshPort = 22

// sshConnectTimeout is installed as ssh.ClientConfig.Timeout, i.e. the limit
// on establishing the TCP connection when dialing with ssh.Dial.
const sshConnectTimeout = 5 * time.Second

// sshClient is a minimal SSH client (the Client returned by NewSSHClient).
// It keeps the shared client configuration used for every connection, an
// optional long-lived connection to the bastion host that Connect tunnels
// through, and an optional SSH agent used for authentication and for agent
// forwarding on new connections.
type sshClient struct {
	config        *ssh.ClientConfig
	bastionClient *ssh.Client         // nil when hosts are dialed directly
	sshAgent      agent.ExtendedAgent // nil when no SSH agent is available
}

// NewSSHClient builds an SSH client that logs in as $USER and skips host key
// verification. It authenticates with whatever credentials are found locally:
// the method from newPublicKeysAuth (if non-nil) and then the signers of the
// SSH agent from newSSHAgent (if non-nil).
//
// NOTE(review): x/crypto/ssh only uses the first AuthMethod of each RFC 4252
// method type. If newPublicKeysAuth returns a "publickey" method, the
// agent-backed publickey method appended after it is never tried. Confirm
// and consider merging the signers into a single PublicKeysCallback.
//
// When useBastion is true, a connection to the cloud bastion is opened
// immediately and reused by Connect as a jump host. Failure to reach the
// bastion terminates the process via log.Fatalf.
func NewSSHClient(useBastion bool) Client {
	cfg := &ssh.ClientConfig{
		User:            os.Getenv("USER"),
		Auth:            []ssh.AuthMethod{},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         sshConnectTimeout,
	}
	cfg.SetDefaults()

	keyAuth := newPublicKeysAuth()
	if keyAuth != nil {
		cfg.Auth = append(cfg.Auth, keyAuth)
	}

	agentConn := newSSHAgent()
	if agentConn != nil {
		cfg.Auth = append(cfg.Auth, ssh.PublicKeysCallback(agentConn.Signers))
	}

	client := &sshClient{config: cfg, bastionClient: nil, sshAgent: agentConn}
	if !useBastion {
		return client
	}

	jump, err := ssh.Dial("tcp", net.JoinHostPort(bastionAddr, strconv.Itoa(sshPort)), cfg)
	if err != nil {
		log.Fatalf("cannot connect to cloud bastion: %v", err)
	}
	client.bastionClient = jump
	return client
}

// addrWithPort formats address.IP and port as a "host:port" dial target.
//
// net.JoinHostPort brackets the host only when it contains a colon. IPv6
// literals therefore become "[::1]:22" and IPv4 addresses stay "10.0.0.1:22".
// This holds whether the net.IP is stored in 4-byte or 16-byte (IPv4-mapped)
// form. The previous len(IP) == net.IPv6len check bracketed 16-byte IPv4
// addresses as well.
func addrWithPort(address *hosts.Address, port int) string {
	return net.JoinHostPort(address.IP.String(), strconv.Itoa(port))
}

// Connect opens an SSH connection to address on sshPort using the client's
// shared configuration.
//
// Without a bastion, the host is dialed directly with ssh.Dial. With a
// bastion, a TCP stream to the target is opened through the bastion's SSH
// connection, and the SSH handshake with the target runs over that stream.
// In both cases the result is finished by newConnection, which sets up agent
// forwarding when an agent is available.
//
// Returned errors wrap the underlying cause with %w, so callers can inspect
// it with errors.Is / errors.As.
func (c *sshClient) Connect(address *hosts.Address) (Connection, error) {
	addr := addrWithPort(address, sshPort)

	if c.bastionClient == nil {
		client, err := ssh.Dial("tcp", addr, c.config)
		if err != nil {
			return nil, fmt.Errorf("cannot connect to %s: %w", address.Name, err)
		}
		return c.newConnection(client, address)
	}

	conn, err := c.bastionClient.Dial("tcp", addr)
	if err != nil {
		return nil, fmt.Errorf("cannot connect to %s with bastion: %w", address.Name, err)
	}

	clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, c.config)
	if err != nil {
		return nil, fmt.Errorf("cannot establish ssh session with %s via bastion: %w", address.Name, err)
	}

	return c.newConnection(ssh.NewClient(clientConn, chans, reqs), address)
}

// newConnection wraps an established SSH client for address into a
// Connection.
//
// When the sshClient has an SSH agent, agent forwarding is enabled on client
// first. newConnection takes ownership of client: if forwarding cannot be
// set up, client is closed here before the error is returned. Callers of
// Connect never receive it on that path, so otherwise the underlying
// connection would leak.
func (c *sshClient) newConnection(client *ssh.Client, address *hosts.Address) (Connection, error) {
	sshAgentPresent := c.sshAgent != nil

	if sshAgentPresent {
		if err := agent.ForwardToAgent(client, c.sshAgent); err != nil {
			// Best effort: the forwarding failure is the error worth reporting.
			_ = client.Close()
			return nil, fmt.Errorf("cannot forward ssh agent auth: %w", err)
		}
	}

	return &SSHConnection{
		address:         address,
		sshClient:       client,
		sshAgentPresent: sshAgentPresent,
	}, nil
}

// Close closes the bastion connection if NewSSHClient opened one; otherwise
// it does nothing. Close has no error result, so a failure to close is only
// logged.
func (c *sshClient) Close() {
	if c.bastionClient == nil {
		return
	}
	if err := c.bastionClient.Close(); err != nil {
		log.Printf("cannot close ssh connection to %s: %v", bastionAddr, err)
	}
}
