//go:build linux
// +build linux

package secretservice

import (
	"encoding/json"
	"errors"
	"fmt"
	"math/big"
	"sync"

	"github.com/godbus/dbus/v5"

	"a.yandex-team.ru/security/skotty/skotty/internal/keyring"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring/keychain/internal/chaintypes"
)

// D-Bus names used to talk to the freedesktop Secret Service.
//
// dbusInterfacePath is passed as the destination and dbusServicePath as the
// object path when the service object itself is addressed. The dbusCall*
// values are fully-qualified method names handed to Call, and
// dbusSignalPromptCompleted is the signal awaited after a prompt is shown.
const (
	dbusServicePath           = "/org/freedesktop/secrets"
	dbusInterfacePath         = "org.freedesktop.secrets"
	dbusCollectionInterface   = "org.freedesktop.Secret.Collection"
	dbusCollectionsInterface  = "org.freedesktop.Secret.Service.Collections"
	dbusPromptInterface       = "org.freedesktop.Secret.Prompt"
	dbusItemInterface         = "org.freedesktop.Secret.Item"
	dbusCallUnlock            = "org.freedesktop.Secret.Service.Unlock"
	dbusCallOpenSession       = "org.freedesktop.Secret.Service.OpenSession"
	dbusCallClose             = "org.freedesktop.Secret.Session.Close"
	dbusCallCreateCollection  = "org.freedesktop.Secret.Service.CreateCollection"
	dbusCallCreateItem        = "org.freedesktop.Secret.Collection.CreateItem"
	dbusCallSearchItems       = "org.freedesktop.Secret.Collection.SearchItems"
	dbusCallPromptPrompt      = "org.freedesktop.Secret.Prompt.Prompt"
	dbusCallGetSecret         = "org.freedesktop.Secret.Item.GetSecret"
	dbusSignalPromptCompleted = "org.freedesktop.Secret.Prompt.Completed"

	// Session encryption algorithm names for OpenSession. The code visible in
	// this file only ever requests dbusAuthenticationDHAES.
	dbusAuthenticationInsecurePlain = "plain"
	dbusAuthenticationDHAES         = "dh-ietf1024-sha256-aes128-cbc-pkcs7"
)

// Compile-time checks that Service and Session implement the chaintypes interfaces.
var _ chaintypes.Service = (*Service)(nil)
var _ chaintypes.Session = (*Session)(nil)
// errCallDismissed is returned by handlePrompt when the prompt's Completed
// signal reports that the prompt was dismissed.
var errCallDismissed = errors.New("dismissed")

// Service is a handle to the Secret Service over D-Bus, bound to a single
// collection. Sessions are opened from it with Session.
type Service struct {
	// dbusConn is the bus connection used for every call; closed by Close.
	dbusConn *dbus.Conn
	// sd delivers Prompt.Completed signals to sessions waiting on a prompt.
	sd *SignalDispatcher
	// collection is the object path of the collection sessions operate on.
	collection dbus.ObjectPath
	// mu is held for the whole duration of Session().
	mu sync.Mutex
}

// Session is an open Secret Service session together with the key material
// negotiated for it. It shares the connection and signal dispatcher of the
// Service that created it.
type Session struct {
	dbusConn *dbus.Conn
	sd       *SignalDispatcher
	// collectionPath is the collection that items are created and searched in.
	collectionPath dbus.ObjectPath
	// sessionPath is the session object returned by OpenSession.
	sessionPath dbus.ObjectPath
	// public and private are our half of the key exchange done in Service.Session.
	public  *big.Int
	private *big.Int
	// aesKey encrypts secrets sent to, and decrypts secrets received from, the service.
	aesKey []byte
}

// Secret is the value passed to CreateItem and filled in from GetSecret.
//
// NOTE(review): the field order presumably mirrors the Secret Service wire
// struct (session, parameters, value, content type) and is therefore
// significant for D-Bus marshalling — confirm before reordering.
type Secret struct {
	// Session is the session the secret was encoded for.
	Session dbus.ObjectPath
	// Parameters carries the IV produced/consumed by the AES-CBC helpers.
	Parameters []byte
	// Value is the encrypted payload.
	Value []byte
	// ContentType is the content type of the payload.
	ContentType string
}

// NewService connects to the session bus, subscribes to Prompt.Completed
// signals and returns a Service bound to the collection object path given in
// collection.
//
// The caller owns the returned Service and must call Close on it when done.
func NewService(collection string) (*Service, error) {
	// Use a private connection: Service.Close and the error path below close
	// it, and godbus forbids closing the process-wide shared connection that
	// dbus.SessionBus returns (doing so would break every other user of it).
	conn, err := dbus.ConnectSessionBus()
	if err != nil {
		return nil, fmt.Errorf("failed to open dbus connection: %w", err)
	}

	err = conn.AddMatchSignal(
		dbus.WithMatchInterface(dbusPromptInterface),
		dbus.WithMatchMember("Completed"),
	)
	if err != nil {
		// Best-effort cleanup; the subscription error is the one worth reporting.
		_ = conn.Close()
		return nil, fmt.Errorf("failed to add dbus events watcher: %w", err)
	}

	return &Service{
		dbusConn:   conn,
		sd:         NewSignalDispatcher(conn),
		collection: dbus.ObjectPath(collection),
	}, nil
}

// Session opens a new encrypted Secret Service session and unlocks the
// service's collection.
//
// It generates a fresh keypair, sends the public half to OpenSession using
// the dh-ietf1024-sha256-aes128-cbc-pkcs7 algorithm, derives the AES key from
// the peer's public value and finally calls Unlock. The service mutex is held
// for the whole call. If any step after OpenSession fails, the remote session
// is closed before the error is returned.
func (s *Service) Session() (chaintypes.Session, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	ourGroup := rfc2409SecondOakleyGroup()
	ourPrivate, ourPublic, err := ourGroup.NewKeypair()
	if err != nil {
		return nil, fmt.Errorf("failed to generate dbus keypair: %w", err)
	}

	var (
		algorithmOutput dbus.Variant
		sessionPath     dbus.ObjectPath
	)

	err = s.dbusConn.Object(dbusInterfacePath, dbusServicePath).
		Call(dbusCallOpenSession, 0, dbusAuthenticationDHAES, dbus.MakeVariant(ourPublic.Bytes())).
		Store(&algorithmOutput, &sessionPath)
	if err != nil {
		return nil, fmt.Errorf("failed to open new secret service session: %w", err)
	}

	// From here on the remote session exists, so every failure path must
	// close it. Build the Session early so Close can be used for that.
	sess := &Session{
		dbusConn:       s.dbusConn,
		sd:             s.sd,
		sessionPath:    sessionPath,
		collectionPath: s.collection,
		public:         ourPublic,
		private:        ourPrivate,
	}

	theirPublicBigEndian, ok := algorithmOutput.Value().([]byte)
	if !ok {
		sess.Close()
		return nil, fmt.Errorf("failed to coerce algorithm output value: invalid response type %T", algorithmOutput.Value())
	}

	theirGroup := rfc2409SecondOakleyGroup()
	theirPublic := new(big.Int).SetBytes(theirPublicBigEndian)
	sess.aesKey, err = theirGroup.keygenHKDFSHA256AES128(theirPublic, ourPrivate)
	if err != nil {
		sess.Close()
		return nil, fmt.Errorf("failed to derive session key: %w", err)
	}

	if err := sess.Unlock(); err != nil {
		sess.Close()
		return nil, fmt.Errorf("can't unlock database: %w", err)
	}

	return sess, nil
}

// Close shuts the service down: the D-Bus connection is closed first (its
// error is intentionally ignored), then the signal dispatcher.
func (s *Service) Close() {
	// Deferred so the dispatcher is closed after the connection.
	defer s.sd.Close()

	_ = s.dbusConn.Close()
}

// SaveKeyPair JSON-encodes keypair, encrypts it with the session key and
// stores it as an item in the session's collection.
//
// The item is labelled "<keyType>@skotty" and tagged with the attributes
// service=skotty and key-type=<keyType>. CreateItem is called with
// replace=true. If the service answers with a prompt, the prompt is shown
// and its outcome (including dismissal) is returned as the error.
func (t *Session) SaveKeyPair(keyType keyring.KeyPurpose, keypair chaintypes.KeyPair) error {
	secBytes, err := json.Marshal(keypair)
	if err != nil {
		return fmt.Errorf("failed to marshal certificate: %w", err)
	}

	secret, err := t.newSecret(secBytes)
	if err != nil {
		// Keep the cause in the chain; it was previously dropped.
		return fmt.Errorf("failed to create secret for private key: %w", err)
	}

	properties := map[string]dbus.Variant{
		dbusItemInterface + ".Label": dbus.MakeVariant(fmt.Sprintf("%s@skotty", keyType)),
		dbusItemInterface + ".Attributes": dbus.MakeVariant(map[string]string{
			"service":  "skotty",
			"key-type": keyType.String(),
		}),
	}

	var item, prompt dbus.ObjectPath
	err = t.dbusConn.Object(dbusInterfacePath, t.collectionPath).
		Call(dbusCallCreateItem, 0, properties, secret, true).
		Store(&item, &prompt)
	if err != nil {
		return fmt.Errorf("failed to save certificate in secret service: %w", err)
	}

	_, err = t.handlePrompt(prompt)
	return err
}

// FetchKeyPair looks up the item stored for keyType, reads its secret through
// this session, decrypts it with the session key and decodes the JSON payload
// into a chaintypes.KeyPair.
func (t *Session) FetchKeyPair(keyType keyring.KeyPurpose) (chaintypes.KeyPair, error) {
	var keypair chaintypes.KeyPair

	itemPath, err := t.findSecretPath(keyType)
	if err != nil {
		return keypair, fmt.Errorf("secret was not found: %w", err)
	}

	var encrypted Secret
	item := t.dbusConn.Object(dbusInterfacePath, itemPath)
	if err := item.Call(dbusCallGetSecret, 0, t.sessionPath).Store(&encrypted); err != nil {
		return keypair, fmt.Errorf("failed to get secret: %w", err)
	}

	// Parameters holds the IV, Value the ciphertext.
	plaintext, err := unauthenticatedAESCBCDecrypt(encrypted.Parameters, encrypted.Value, t.aesKey)
	if err != nil {
		return keypair, fmt.Errorf("failed to decrypt secret: %w", err)
	}

	if err := json.Unmarshal(plaintext, &keypair); err != nil {
		return keypair, fmt.Errorf("failed to unmarshal secret: %w", err)
	}

	return keypair, nil
}

// findSecretPath searches the session's collection for the item tagged
// service=skotty and key-type=<keyType> and returns its object path.
//
// It returns an error when the search call fails, when no item matches, or
// when more than one item matches.
func (t *Session) findSecretPath(keyType keyring.KeyPurpose) (dbus.ObjectPath, error) {
	var results []dbus.ObjectPath
	err := t.dbusConn.Object(dbusInterfacePath, t.collectionPath).
		Call(dbusCallSearchItems, 0, map[string]string{
			"service":  "skotty",
			"key-type": keyType.String(),
		}).
		Store(&results)
	if err != nil {
		return "", fmt.Errorf("failed to call .SearchItems on collection %s: %w", t.collectionPath, err)
	}

	switch len(results) {
	case 0:
		// An empty result is a normal answer from the service (nothing stored
		// yet); indexing results[0] here would panic.
		return "", fmt.Errorf("no secret found for key type %q", keyType.String())
	case 1:
		return results[0], nil
	default:
		return "", fmt.Errorf("found more than one secret: %d", len(results))
	}
}

// Close asks the service to close this session. The call result is
// deliberately discarded.
func (t *Session) Close() {
	_ = t.dbusConn.Object(dbusInterfacePath, t.sessionPath).Call(dbusCallClose, 0)
}

// createCollection asks the service to create a collection labelled name
// (with an empty alias) and returns the object path of the new collection.
//
// If the service requires a prompt, the prompt is shown and the path is taken
// from the prompt result; otherwise the path returned directly by
// CreateCollection is used. An error is returned when the call fails, when the
// prompt fails or is dismissed, or when no usable path was reported.
func (t *Session) createCollection(name string) (string, error) {
	properties := map[string]dbus.Variant{
		dbusCollectionInterface + ".Label": dbus.MakeVariant(name),
	}

	var collection, prompt dbus.ObjectPath
	err := t.dbusConn.Object(dbusInterfacePath, dbusServicePath).
		Call(dbusCallCreateCollection, 0, properties, "").
		Store(&collection, &prompt)
	if err != nil {
		return "", fmt.Errorf("failed to call .CreateCollection: %w", err)
	}

	v, err := t.handlePrompt(prompt)
	if err != nil {
		return "", err
	}

	// Inspect the variant's value rather than Variant.String(): the latter
	// yields the GVariant text form (quoted, possibly type-annotated), which
	// is never empty and is not a usable object path.
	var prompted string
	switch p := v.Value().(type) {
	case dbus.ObjectPath:
		prompted = string(p)
	case string:
		prompted = p
	}

	// "/" is the D-Bus placeholder for "no object".
	for _, path := range []string{prompted, string(collection)} {
		if path != "" && path != "/" {
			return path, nil
		}
	}

	return "", fmt.Errorf("secret service returned no path for collection %q", name)
}

// newSecret encrypts secretBytes with the session key and wraps the result
// in a Secret addressed to this session, with content type
// application/octet-stream.
func (t *Session) newSecret(secretBytes []byte) (Secret, error) {
	var sealed Secret

	iv, ciphertext, err := unauthenticatedAESCBCEncrypt(secretBytes, t.aesKey)
	if err != nil {
		return sealed, err
	}

	sealed.Session = t.sessionPath
	sealed.Parameters = iv
	sealed.Value = ciphertext
	sealed.ContentType = "application/octet-stream"

	return sealed, nil
}

// Unlock asks the service to unlock the session's collection, showing the
// returned prompt if there is one.
//
// It succeeds when the collection path appears either in the list returned
// directly by the Unlock call or in the list carried by the prompt result;
// otherwise an error is returned.
func (t *Session) Unlock() error {
	var (
		direct     []dbus.ObjectPath
		promptPath dbus.ObjectPath
	)

	service := t.dbusConn.Object(dbusInterfacePath, dbusServicePath)
	call := service.Call(dbusCallUnlock, 0, []dbus.ObjectPath{t.collectionPath})
	if err := call.Store(&direct, &promptPath); err != nil {
		return fmt.Errorf("failed to call .Unlock method: %w", err)
	}

	result, err := t.handlePrompt(promptPath)
	if err != nil {
		return err
	}

	// The prompt result only counts when it actually is a list of paths;
	// for any other payload viaPrompt stays nil.
	viaPrompt, _ := result.Value().([]dbus.ObjectPath)

	for _, paths := range [][]dbus.ObjectPath{direct, viaPrompt} {
		for _, p := range paths {
			if p == t.collectionPath {
				return nil
			}
		}
	}

	return fmt.Errorf("requested collection %q isn't unlocked", t.collectionPath)
}

// handlePrompt shows the given prompt and waits for its Completed signal.
//
// A prompt path of "/" means no prompt is needed; an empty-string variant is
// returned immediately. Otherwise the signal's result variant is returned, or
// errCallDismissed if the signal reports the prompt as dismissed.
func (t *Session) handlePrompt(prompt dbus.ObjectPath) (dbus.Variant, error) {
	out := dbus.MakeVariant("")
	if prompt == "/" {
		return out, nil
	}

	// Begin waiting before the prompt is triggered.
	// NOTE(review): presumably so the Completed signal cannot be missed —
	// confirm against SignalDispatcher. If the Prompt call below fails, this
	// goroutine keeps running until WaitEvent returns on its own.
	completed := make(chan []interface{}, 1)
	go func() {
		completed <- t.sd.WaitEvent(dbusSignalPromptCompleted, prompt)
	}()

	call := t.dbusConn.Object(dbusInterfacePath, prompt).Call(dbusCallPromptPrompt, 0, "")
	if call.Err != nil {
		return out, fmt.Errorf("failed to call prompt for object %q: %w", prompt, call.Err)
	}

	body := <-completed

	var dismissed bool
	if err := dbus.Store(body, &dismissed, &out); err != nil {
		return out, fmt.Errorf("failed to unmarshal prompt result: %w", err)
	}

	if dismissed {
		return out, errCallDismissed
	}

	return out, nil
}

// Collections returns the object paths of all collections reported by the
// Secret Service's Collections property.
//
// It opens its own session bus connection for the duration of the call.
func Collections() ([]string, error) {
	// A private connection, because it is closed on return; godbus forbids
	// closing the shared connection that dbus.SessionBus hands out.
	conn, err := dbus.ConnectSessionBus()
	if err != nil {
		return nil, fmt.Errorf("failed to open dbus connection: %w", err)
	}
	defer func() { _ = conn.Close() }()

	val, err := conn.Object(dbusInterfacePath, dbusServicePath).GetProperty(dbusCollectionsInterface)
	if err != nil {
		return nil, err
	}

	// Checked assertion: an unexpected payload must surface as an error, not
	// as a panic in the caller's goroutine.
	paths, ok := val.Value().([]dbus.ObjectPath)
	if !ok {
		return nil, fmt.Errorf("failed to coerce collections property: invalid response type %T", val.Value())
	}

	out := make([]string, len(paths))
	for i, p := range paths {
		out[i] = string(p)
	}

	return out, nil
}

// CreateCollection creates a Secret Service collection labelled name and
// returns its object path.
//
// It opens its own session bus connection, subscribes to Prompt.Completed
// signals so that a creation prompt can be awaited, and tears both down
// before returning.
func CreateCollection(name string) (string, error) {
	// A private connection, because it is closed on return; godbus forbids
	// closing the shared connection that dbus.SessionBus hands out.
	conn, err := dbus.ConnectSessionBus()
	if err != nil {
		return "", fmt.Errorf("failed to open dbus connection: %w", err)
	}
	defer func() { _ = conn.Close() }()

	err = conn.AddMatchSignal(
		dbus.WithMatchInterface(dbusPromptInterface),
		dbus.WithMatchMember("Completed"),
	)
	if err != nil {
		return "", fmt.Errorf("failed to add dbus events watcher: %w", err)
	}

	// Deferred after conn.Close, so the dispatcher is closed first.
	sd := NewSignalDispatcher(conn)
	defer sd.Close()

	//TODO(buglloc): ugly
	// Only the connection and dispatcher are set: createCollection does not
	// touch the session or collection fields.
	tx := &Session{
		dbusConn: conn,
		sd:       sd,
	}
	return tx.createCollection(name)
}

// IsAvailable reports whether the Secret Service can be queried, i.e. whether
// Collections succeeds. The returned error is always nil; a failure is
// expressed solely through the boolean.
func IsAvailable() (bool, error) {
	if _, err := Collections(); err != nil {
		return false, nil
	}

	return true, nil
}
