package diag

import (
	"encoding/base64"
	"errors"
	"fmt"
	"os"
	"strconv"
	"strings"
	"time"

	"golang.org/x/crypto/ssh/agent"

	"a.yandex-team.ru/library/go/core/buildinfo"
	"a.yandex-team.ru/security/libs/go/pcsc"
	"a.yandex-team.ru/security/skotty/skotty/internal/config"
	"a.yandex-team.ru/security/skotty/skotty/internal/paths"
	"a.yandex-team.ru/security/skotty/skotty/internal/procutil"
	"a.yandex-team.ru/security/skotty/skotty/internal/socket"
	"a.yandex-team.ru/security/skotty/skotty/internal/version"
	"a.yandex-team.ru/security/skotty/skotty/pkg/sshutil"
	"a.yandex-team.ru/security/skotty/skotty/pkg/sshutil/sshclient"
)

var (
	// errNotSupported is a sentinel error meaning "this diagnostic step is not
	// supported". It is not referenced in this file; presumably it is returned
	// by platform-specific collector implementations elsewhere in the package
	// (verify).
	errNotSupported = errors.New("not supported")
)

// collector describes a single named diagnostic step.
type collector struct {
	// name is the human-readable title of the step.
	name string

	// fn performs the step, writing whatever it gathers to out. A non-nil
	// error presumably marks the step as failed (verify against the caller).
	fn func(out *Writer) error
}

// collectors returns every diagnostic step in a fixed order. Each entry pairs
// a display title with the function that gathers that piece of information;
// callers presumably execute them in slice order (verify).
func collectors() []collector {
	return []collector{
		{name: "Skotty version", fn: collectSkottyVersion},
		{name: "OS Version", fn: collectOS},
		{name: "Skotty info", fn: collectSkottyInfo},
		{name: "Skotty yubikey list", fn: collectSkottyYubikeyList},
		{name: "List USB devices", fn: collectUSBDevices},
		{name: "Find Yubikey vendor devices", fn: checkYubikeyVendorDevices},
		{name: "Collect SSH env", fn: collectSSHEnv},
		{name: "Check PCSC", fn: checkPCSC},
		{name: "Collect PCSCD process", fn: collectPCSCD},
		{name: "Collect SSH signs", fn: collectSSHSigns},
	}
}

// collectSkottyVersion writes the version string compiled into this binary.
// It also writes the build time, but only when the embedded build timestamp
// parses as a base-10 count of Unix seconds. This step never fails.
func collectSkottyVersion(out *Writer) error {
	out.Writeln("Skotty Version:", version.Full())

	unixSec, parseErr := strconv.ParseInt(buildinfo.Info.BuildTimestamp, 10, 64)
	if parseErr != nil {
		// Missing or malformed timestamp: best-effort, just omit the line.
		return nil
	}

	out.Writeln("Build at:", time.Unix(unixSec, 0).String())
	return nil
}

// checkPCSC verifies that a PC/SC client can be created and that it passes
// the client's compatibility check. Nothing is written to the report; the
// outcome is conveyed only through the returned error.
func checkPCSC(_ *Writer) error {
	client, err := pcsc.NewClient()
	if err != nil {
		return err
	}
	// Closing is best-effort: a close failure says nothing about compatibility.
	defer func() { _ = client.Close() }()

	compatErr := client.CheckCompatibility()
	if compatErr != nil {
		return compatErr
	}
	return nil
}

// collectSSHEnv writes every process environment entry whose "KEY=value"
// form starts with "SSH_". Entries appear in the order os.Environ returns
// them. This step never fails.
func collectSSHEnv(out *Writer) error {
	for _, kv := range os.Environ() {
		if !strings.HasPrefix(kv, "SSH_") {
			continue
		}
		out.Writeln(kv)
	}
	return nil
}

// collectSkottyInfo records the output of the external command `skotty info`.
// The executable is passed to procutil.RunCommand by bare name, so it may not
// be the binary currently running (NOTE(review): presumably resolved via
// PATH — confirm).
func collectSkottyInfo(out *Writer) error {
	const bin = "skotty"
	return collectCommand(out, bin, "info")
}

// collectSkottyYubikeyList records the output of the external command
// `skotty yubikey list`. Like collectSkottyInfo, the executable is named
// rather than given as a path (NOTE(review): presumably resolved via PATH).
func collectSkottyYubikeyList(out *Writer) error {
	const bin = "skotty"
	return collectCommand(out, bin, []string{"yubikey", "list"}...)
}

// collectCommand runs the external command name with args and records the
// result in out:
//   - first the command line,
//   - then the captured stdout,
//   - then the captured stderr.
//
// The command line is rendered shell-style. Each argument is wrapped in single
// quotes, and an embedded single quote is escaped as '\'' so the arguments can
// be copied into a POSIX shell unchanged. When args is empty only name is
// printed. (The previous rendering showed a bogus empty argument `''` in that
// case, and produced broken quoting for arguments containing ').
//
// Errors:
//   - An error from procutil.RunCommand is returned unchanged, and nothing
//     beyond the command line is written.
//   - A non-zero exit code is reported as an error after the output has been
//     recorded.
func collectCommand(out *Writer, name string, args ...string) error {
	cmdline := name
	for _, arg := range args {
		// Close the quote, emit an escaped quote, reopen: the standard POSIX
		// way to embed ' inside a single-quoted word.
		cmdline += " '" + strings.ReplaceAll(arg, "'", `'\''`) + "'"
	}
	out.Writeln("run command: ", cmdline)

	r, err := procutil.RunCommand(name, args...)
	if err != nil {
		return err
	}

	out.Writeln("stdout:")
	out.WriteString(r.Stdout)

	// stdout may not end with a newline; make sure the stderr header starts
	// on its own line.
	out.Writeln("\nstderr:")
	out.WriteString(r.Stderr)

	if r.ExitCode != 0 {
		return fmt.Errorf("unexpected exit code: %d", r.ExitCode)
	}

	return nil
}

// loadConfig loads the skotty configuration from the location reported by
// paths.Config. strict is forwarded unchanged to config.Load, which defines
// its exact effect. Errors from either step are returned unwrapped.
func loadConfig(strict bool) (*config.Config, error) {
	cfgFile, pathErr := paths.Config()
	if pathErr != nil {
		return nil, pathErr
	}
	return config.Load(cfgFile, strict)
}

// collectSSHSigns checks that the agent behind skotty's default socket can
// produce valid signatures. It connects to that socket and lists the keys it
// offers. For each key it:
//   - writes the key type, fingerprint and public key;
//   - asks the agent to sign a fixed probe payload;
//   - verifies the signature against that same key.
//
// Per-key sign or verify failures are written to out, marked with "[!!!]",
// and do not stop the loop. Only setup failures (config, socket lookup, dial,
// key listing) are returned as errors.
func collectSSHSigns(out *Writer) error {
	cfg, err := loadConfig(true)
	if err != nil {
		return fmt.Errorf("failed to load config: %w", err)
	}

	client := sshclient.BestClient()
	sockName := client.SocketName(socket.NameDefault)
	sock, err := cfg.Socket(sockName)
	if err != nil {
		return fmt.Errorf("unable to get default agent socket: %w", err)
	}

	conn, err := dialSock(sock.Path)
	if err != nil {
		return fmt.Errorf("error connecting to agent: %w", err)
	}
	defer func() { _ = conn.Close() }()

	agentClient := agent.NewClient(conn)
	keys, err := agentClient.List()
	if err != nil {
		return fmt.Errorf("error listing agent keys: %w", err)
	}

	// Arbitrary probe payload; only the sign/verify round trip matters.
	payload := []byte("kek")

	out.Writeln("agent keys:")
	for _, k := range keys {
		out.Writefln("- %s %s", k.Type(), sshutil.Fingerprint(k))
		out.Writefln("\t* pub: %s", k)

		sig, signErr := agentClient.Sign(k, payload)
		if signErr != nil {
			out.Writeln("\t* [!!!] sign failed:", signErr)
			continue
		}

		if verifyErr := k.Verify(payload, sig); verifyErr != nil {
			out.Writefln("\t* [!!!] unable to check sign %s %s: %v",
				sig.Format,
				base64.StdEncoding.EncodeToString(sig.Blob),
				verifyErr,
			)
			continue
		}

		out.Writefln("\t* sign checked")
	}

	return nil
}
