package config

import (
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"runtime"
	"strings"

	"github.com/heetch/confita"
	"github.com/heetch/confita/backend/file"
	"gopkg.in/yaml.v2"

	"a.yandex-team.ru/library/go/core/log"
	"a.yandex-team.ru/security/skotty/skotty/internal/agent/agentkey"
	"a.yandex-team.ru/security/skotty/skotty/internal/confirm"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring/filering"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring/keychain"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring/yubiring"
	"a.yandex-team.ru/security/skotty/skotty/internal/passutil"
	"a.yandex-team.ru/security/skotty/skotty/internal/paths"
	"a.yandex-team.ru/security/skotty/skotty/internal/pinstore/anypin"
	"a.yandex-team.ru/security/skotty/skotty/internal/socket"
	"a.yandex-team.ru/security/skotty/skotty/pkg/sshutil/sshclient"
)

// DefaultEnrollmentService is the enrollment service URL that Load uses
// unless the SKOTTY_ENROLL_SERVICE environment variable is set.
const DefaultEnrollmentService = "https://skotty.sec.yandex-team.ru"

// ErrSocketNotFound is returned by Config.Socket when no socket has the
// requested name.
var ErrSocketNotFound = errors.New("socket not found")

// Local aliases for types from internal packages that appear in the
// configuration schema.
type (
	KeyPurpose  = keyring.KeyPurpose
	KeyringKind = keyring.Kind
	SocketKind  = socket.Kind
	SecretVal   = passutil.SecretVal
)

// EnrollInfo holds enrollment metadata for a keyring; Keyring serializes it
// under the "enroll_info" key.
//
// NOTE(review): field meanings are inferred from names and tags and should be
// confirmed against the enrollment code: EnrollID is presumably the
// identifier issued by the enrollment service, TokenSerial the serial of the
// enrolled token, and User the user the enrollment belongs to.
type EnrollInfo struct {
	EnrollID    string `yaml:"enrollment_id"`
	TokenSerial string `yaml:"token_serial"`
	User        string `yaml:"user"`
}

// Socket describes one agent socket entry of the configuration.
//
// Config.Build normalizes each entry:
//   - Name must be non-empty and unique.
//   - Kind defaults to socket.KindUnix when left as socket.KindNone.
//   - SameAs and Keys may not both be set.
//   - Path is left untouched when it equals socket.NoPath. Otherwise an
//     empty Path becomes "<sockets dir>/<base of Name>.sock" and a relative
//     Path is joined to the sockets dir. Paths starting with `\\.\pipe\` or
//     `//./pipe/` are treated as absolute. The resulting path must be unique.
//   - A socket.KindPageant entry is replaced by two entries named
//     "<Name>-pipe" (socket.KindPageantPipe) and "<Name>-window"
//     (socket.KindPageantWindow).
//
// NOTE(review): SameAs presumably names another socket whose keys are reused,
// NotifyUsage presumably enables usage notifications, and KeysOrder the order
// in which key kinds are offered — confirm with the agent code.
type Socket struct {
	Name        string             `yaml:"name"`
	Kind        SocketKind         `yaml:"kind"`
	SameAs      string             `yaml:"same_as,omitempty"`
	NotifyUsage bool               `yaml:"notify_usage"`
	Path        string             `yaml:"path"`
	Confirm     ConfirmKind        `yaml:"confirm,omitempty"`
	Keys        []KeyPurpose       `yaml:"keys,omitempty"`
	KeysOrder   []agentkey.KeyKind `yaml:"keys_order,omitempty"`
}

// KeyringYubikey holds the settings of the YubiKey keyring backend.
// Config.NewKeyring passes Serial and PIN to yubiring.NewYubiring when
// Keyring.Kind is keyring.KindYubikey.
type KeyringYubikey struct {
	Serial uint32           `yaml:"serial"`
	PIN    *anypin.Provider `yaml:"pin"`
}

// KeyringKeychain holds the settings of the keychain keyring backend.
// Config.NewKeyring passes Collection to keychain.NewKeychain when
// Keyring.Kind is keyring.KindKeychain.
type KeyringKeychain struct {
	Collection string `yaml:"collection,omitempty"`
}

// KeyringFiles holds the settings of the file-based keyring backend.
// Config.NewKeyring passes BasePath and Passphrase to filering.NewFilering
// when Keyring.Kind is keyring.KindFiles.
type KeyringFiles struct {
	BasePath   string           `yaml:"base_path"`
	Passphrase *anypin.Provider `yaml:"passphrase"`
}

// Keyring configures the key storage backend.
//
// Kind (YAML key "type") selects the backend. Config.NewKeyring supports
// keyring.KindYubikey, keyring.KindKeychain and keyring.KindFiles, and reads
// only the matching Yubikey, Keychain or Files settings. Keys (YAML key
// "available_keys") lists key purposes for this keyring.
//
// NOTE(review): RenewToken is presumably a secret used to renew the
// enrollment described by EnrollInfo — confirm.
type Keyring struct {
	Kind       KeyringKind     `yaml:"type"`
	Yubikey    KeyringYubikey  `yaml:"yubikey,omitempty"`
	Keychain   KeyringKeychain `yaml:"keychain,omitempty"`
	Files      KeyringFiles    `yaml:"files,omitempty"`
	Keys       []KeyPurpose    `yaml:"available_keys"`
	RenewToken SecretVal       `yaml:"renew_token,omitempty"`
	EnrollInfo EnrollInfo      `yaml:"enroll_info"`
}

// Startup holds agent start-up switches. Load defaults ExportAuthSock to true
// and ReplaceAuthSock to true only when runtime.GOOS is "darwin".
//
// NOTE(review): the flags presumably control exporting and replacing the
// SSH_AUTH_SOCK socket at start-up — confirm with the start-up code.
type Startup struct {
	ExportAuthSock  bool `yaml:"export_auth_sock"`
	ReplaceAuthSock bool `yaml:"replace_auth_sock"`
}

// Confirm configures how key usage confirmation is obtained. Load defaults
// Kind to confirm.KindSSHAskPass.
//
// NOTE(review): Program presumably overrides the confirmation program for
// kinds that run one — confirm.
type Confirm struct {
	Kind    confirm.Kind `yaml:"kind"`
	Program string       `yaml:"program,omitempty"`
}

// Config is the complete skotty configuration. The field tags give the YAML
// keys that Save writes.
//
// Defaults set by Load before the config file is applied:
//   - LogLevel: log.InfoString.
//   - AgentLogPath: paths.AgentLog(), or a file in the temp dir if that fails.
//   - EnrollmentService: SKOTTY_ENROLL_SERVICE or DefaultEnrollmentService.
//   - SSHAuthSock: sshclient.BestClient().SocketName(socket.NameDefault).
//   - Startup.ExportAuthSock: true.
//   - Startup.ReplaceAuthSock: true on darwin only.
//   - Confirm.Kind: confirm.KindSSHAskPass.
//
// Build fills SSHKeysPath and CtlSocketPath when they are empty and
// normalizes Sockets. Save omits EnrollmentService when it equals
// DefaultEnrollmentService.
type Config struct {
	LogLevel          string   `yaml:"log_level"`
	AgentLogPath      string   `yaml:"agent_log_path"`
	CtlSocketPath     string   `yaml:"ctl_socket_path"`
	SSHKeysPath       string   `yaml:"ssh_keys_path"`
	Sockets           []Socket `yaml:"sockets,omitempty"`
	SSHAuthSock       string   `yaml:"ssh_auth_sock,omitempty"`
	Startup           Startup  `yaml:"startup"`
	EnrollmentService string   `yaml:"enrollment_service,omitempty"`
	Keyring           Keyring  `yaml:"keyring"`
	Confirm           Confirm  `yaml:"confirm"`
}

// Build fills in default paths and validates and normalizes c.Sockets. It
// modifies c in place and returns it. On error the returned *Config is nil
// and c may already be partially updated.
//
// SSHKeysPath defaults to paths.SSHKeys() and CtlSocketPath to
// paths.CtlSocket() when empty.
//
// Each socket, in order:
//   - must have a non-empty name that is unique among all socket names,
//     including the "<name>-pipe"/"<name>-window" names generated for
//     pageant sockets;
//   - gets socket.KindUnix when its Kind is socket.KindNone;
//   - may not set both 'same_as' and 'keys';
//   - unless its Path is socket.NoPath, gets an empty Path replaced by
//     "<sockets dir>/<base of name>.sock" and a relative Path joined to the
//     sockets dir; the resulting path must be unique;
//   - if it is socket.KindPageant, is expanded into a socket.KindPageantPipe
//     socket named "<name>-pipe" and a socket.KindPageantWindow socket named
//     "<name>-window" with otherwise identical settings.
func (c *Config) Build() (*Config, error) {
	if c.SSHKeysPath == "" {
		keysPath, err := paths.SSHKeys()
		if err != nil {
			return nil, fmt.Errorf("failed to determine ssh keys dir: %w", err)
		}

		c.SSHKeysPath = keysPath
	}

	if c.CtlSocketPath == "" {
		sockPath, err := paths.CtlSocket()
		if err != nil {
			return nil, fmt.Errorf("failed to determine supervisor socket: %w", err)
		}

		c.CtlSocketPath = sockPath
	}

	// Windows named-pipe paths are treated as absolute regardless of what
	// filepath.IsAbs says, so they are never joined to the sockets dir.
	isAbsSocketPath := func(path string) bool {
		return filepath.IsAbs(path) ||
			strings.HasPrefix(path, `\\.\pipe\`) ||
			strings.HasPrefix(path, `//./pipe/`)
	}

	// The sockets dir is only needed for default or relative paths; cache it
	// after the first successful lookup.
	var socksDir string
	socketsDir := func() (string, error) {
		if socksDir == "" {
			dir, err := paths.Sockets()
			if err != nil {
				return "", fmt.Errorf("failed to determine skotty sockets dir: %w", err)
			}
			socksDir = dir
		}
		return socksDir, nil
	}

	socketNames := make(map[string]struct{}, len(c.Sockets))
	claimName := func(name string) error {
		if _, dup := socketNames[name]; dup {
			return fmt.Errorf("duplicate socket name: %s", name)
		}
		socketNames[name] = struct{}{}
		return nil
	}

	socketPaths := make(map[string]struct{}, len(c.Sockets))
	sockets := make([]Socket, 0, len(c.Sockets))
	for i, s := range c.Sockets {
		if s.Name == "" {
			return nil, fmt.Errorf("unnamed socket: %d", i)
		}

		if err := claimName(s.Name); err != nil {
			return nil, err
		}

		if s.Kind == socket.KindNone {
			s.Kind = socket.KindUnix
		}

		if s.SameAs != "" && len(s.Keys) > 0 {
			return nil, fmt.Errorf("can't use 'same_as' and 'keys' options on the same socket: %s", s.Name)
		}

		if s.Path != socket.NoPath {
			// An empty path is never absolute, so it also lands here.
			if !isAbsSocketPath(s.Path) {
				dir, err := socketsDir()
				if err != nil {
					return nil, err
				}

				if s.Path == "" {
					s.Path = filepath.Join(dir, filepath.Base(s.Name)+".sock")
				} else {
					s.Path = filepath.Join(dir, s.Path)
				}
			}

			if _, dup := socketPaths[s.Path]; dup {
				return nil, fmt.Errorf("duplicate socket path: %s", s.Path)
			}
			socketPaths[s.Path] = struct{}{}
		}

		if s.Kind != socket.KindPageant {
			sockets = append(sockets, s)
			continue
		}

		// Special case: a pageant socket is duplicated into a pipe socket and a
		// window socket. Their generated names must not collide with any other
		// socket name, or Config.Socket would silently pick one of them.
		pipe, window := s, s
		pipe.Name, pipe.Kind = s.Name+"-pipe", socket.KindPageantPipe
		window.Name, window.Kind = s.Name+"-window", socket.KindPageantWindow
		for _, derived := range [...]string{pipe.Name, window.Name} {
			if err := claimName(derived); err != nil {
				return nil, err
			}
		}

		sockets = append(sockets, pipe, window)
	}

	c.Sockets = sockets
	return c, nil
}

// NewKeyring constructs the keyring backend selected by c.Keyring.Kind. It
// passes the backend-specific settings (Yubikey, Keychain or Files) to that
// backend's constructor. Any other kind yields an "unsupported keyring"
// error.
func (c *Config) NewKeyring() (keyring.Keyring, error) {
	kr := c.Keyring

	switch kr.Kind {
	case keyring.KindYubikey:
		yk := kr.Yubikey
		return yubiring.NewYubiring(yk.Serial, yk.PIN)
	case keyring.KindKeychain:
		return keychain.NewKeychain(kr.Keychain.Collection)
	case keyring.KindFiles:
		files := kr.Files
		return filering.NewFilering(files.BasePath, files.Passphrase)
	}

	return nil, fmt.Errorf("unsupported keyring: %s", kr.Kind)
}

// Socket returns the first entry of c.Sockets whose Name equals name. It
// returns ErrSocketNotFound when there is none.
func (c *Config) Socket(name string) (Socket, error) {
	for idx := range c.Sockets {
		if c.Sockets[idx].Name != name {
			continue
		}
		return c.Sockets[idx], nil
	}

	return Socket{}, ErrSocketNotFound
}

// Load returns the effective configuration, already passed through Build.
//
// It starts from built-in defaults:
//   - enrollment service from SKOTTY_ENROLL_SERVICE, else
//     DefaultEnrollmentService;
//   - info log level;
//   - agent log at paths.AgentLog(), or a file directly in os.TempDir() if
//     that fails;
//   - SSHAuthSock from the best available SSH client;
//   - auth-sock export on, and auth-sock replacement on darwin only;
//   - ssh-askpass confirmation.
//
// It then overlays the config file at cfgPath (via confita's file backend)
// when cfgPath is non-empty.
//
// With strict == false, a cfgPath that cannot be stat'ed is skipped and the
// defaults are used. With strict == true the file is always handed to the
// loader, so presumably a missing file surfaces as the loader's error.
func Load(cfgPath string, strict bool) (*Config, error) {
	enrollService := DefaultEnrollmentService
	if v, ok := os.LookupEnv("SKOTTY_ENROLL_SERVICE"); ok {
		enrollService = v
	}

	agentLog, err := paths.AgentLog()
	if err != nil {
		// The per-user location is unavailable, so log directly into the temp
		// dir. The value returned alongside an error is not trusted as a path:
		// keep only its base name, and use a fixed name when there is none.
		// Joining an empty value would point the log at the temp dir itself.
		const fallbackAgentLogName = "skotty-agent.log"
		name := filepath.Base(agentLog)
		if name == "." || name == string(filepath.Separator) {
			name = fallbackAgentLogName
		}
		agentLog = filepath.Join(os.TempDir(), name)
	}

	cfg := &Config{
		LogLevel:          log.InfoString,
		AgentLogPath:      agentLog,
		EnrollmentService: enrollService,
		SSHAuthSock:       sshclient.BestClient().SocketName(socket.NameDefault),
		Startup: Startup{
			ExportAuthSock:  true,
			ReplaceAuthSock: runtime.GOOS == "darwin",
		},
		Confirm: Confirm{
			Kind: confirm.KindSSHAskPass,
		},
	}

	if cfgPath == "" {
		return cfg.Build()
	}

	// Non-strict mode treats an inaccessible config file as absent.
	if _, err := os.Stat(cfgPath); !strict && err != nil {
		return cfg.Build()
	}

	loader := confita.NewLoader(file.NewBackend(cfgPath))
	if err := loader.Load(context.Background(), cfg); err != nil {
		return nil, err
	}

	return cfg.Build()
}

// Save validates cfg through Build and writes it as YAML to dest. dest is
// created with mode 0600 if it does not exist.
//
// EnrollmentService is omitted from the output when it equals
// DefaultEnrollmentService, so that Load's default (or SKOTTY_ENROLL_SERVICE)
// presumably applies again on reload. An existing dest is first copied to
// dest+".old" on a best-effort basis; a failed backup does not stop the save.
//
// cfg itself is not modified.
func Save(cfg *Config, dest string) error {
	// Build has a pointer receiver and mutates its receiver, and
	// `(*cfg).Build()` would still call it on cfg itself. Work on a shallow
	// copy so neither Build nor the EnrollmentService blanking below leaks
	// into the caller's live config. Build assigns a fresh Sockets slice to
	// the copy, leaving cfg.Sockets untouched.
	snapshot := *cfg
	out, err := snapshot.Build()
	if err != nil {
		return fmt.Errorf("invalid config: %w", err)
	}

	if out.EnrollmentService == DefaultEnrollmentService {
		out.EnrollmentService = ""
	}

	outBytes, err := yaml.Marshal(out)
	if err != nil {
		return err
	}

	// Best effort: the backup fails e.g. when dest does not exist yet.
	_ = backupFile(dest)
	return os.WriteFile(dest, outBytes, 0o600)
}

// backupFile copies filename to filename+".old", overwriting any previous
// backup, and gives the copy the source's mode. It returns the first error
// encountered, including the error from stat'ing a missing source. If
// copying, closing or chmod'ing the backup fails, the partial backup is
// removed.
func backupFile(filename string) error {
	srcInfo, err := os.Stat(filename)
	if err != nil {
		return err
	}

	src, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer func() { _ = src.Close() }()

	backupName := filename + ".old"
	dst, err := os.OpenFile(backupName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, srcInfo.Mode())
	if err != nil {
		return err
	}

	// Close before removing: Windows cannot delete a file that is still open,
	// which would leave a truncated backup behind.
	if _, err := io.Copy(dst, src); err != nil {
		_ = dst.Close()
		_ = os.Remove(backupName)
		return err
	}

	// A failed Close on a written file can mean the data never made it to
	// disk, so it is reported rather than ignored.
	if err := dst.Close(); err != nil {
		_ = os.Remove(backupName)
		return err
	}

	// The OpenFile mode only applies to newly created files (and is masked by
	// umask), so re-apply the source mode explicitly.
	if err := os.Chmod(backupName, srcInfo.Mode()); err != nil {
		_ = os.Remove(backupName)
		return err
	}

	return nil
}
