//go:build windows
// +build windows

package wincreds

import (
	"encoding/json"
	"fmt"
	"reflect"
	"syscall"
	"unsafe"

	"golang.org/x/sys/windows"

	"a.yandex-team.ru/security/skotty/skotty/internal/keyring"
	"a.yandex-team.ru/security/skotty/skotty/internal/keyring/keychain/internal/chaintypes"
)

// Win32 Credential Manager constants from wincred.h.
const (
	// winCredTypeGeneric is CRED_TYPE_GENERIC: an application-defined credential.
	winCredTypeGeneric uint32 = 1
	// winCredPersistLocalMachine is CRED_PERSIST_LOCAL_MACHINE: the credential
	// is persisted on this computer across logon sessions of the same user.
	winCredPersistLocalMachine uint32 = 2
)

// Compile-time checks that this backend satisfies the keychain interfaces.
var (
	_ chaintypes.Service = (*Service)(nil)
	_ chaintypes.Session = (*Session)(nil)
)

// modadvapi32 is advapi32.dll, which exports the Credential Manager API.
// The DLL and its procedures are resolved lazily on first use.
var modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")

// Credential Manager entry points used by this package (wide-char variants
// where applicable).
var (
	procCredRead  = modadvapi32.NewProc("CredReadW")
	procCredWrite = modadvapi32.NewProc("CredWriteW")
	procCredFree  = modadvapi32.NewProc("CredFree")
)

// Service is the Windows Credential Manager keychain backend. Credentials it
// manages are namespaced by an optional collection name.
type Service struct {
	collection string // optional namespace; empty means no namespace
}

// Session reads and writes key pairs in the Windows Credential Manager. Its
// only state is the collection name inherited from the Service that opened it.
type Session struct {
	collection string // copied from the owning Service
}

// CredentialAttribute is one application-defined keyword/value pair attached
// to a Credential.
type CredentialAttribute struct {
	// Keyword names the attribute.
	Keyword string
	// Value holds the attribute's raw bytes.
	Value []byte
}

// Credential is the Go-side view of a stored credential. TargetName
// identifies it, CredentialBlob carries the secret, and Attributes holds any
// extra application-defined keyword/value pairs.
type Credential struct {
	// TargetName identifies the credential.
	TargetName string
	// CredentialBlob is the secret payload.
	CredentialBlob []byte
	// Attributes are optional application-defined attributes.
	Attributes []CredentialAttribute
}

// https://docs.microsoft.com/en-us/windows/win32/api/wincred/ns-wincred-credentialw
//
// sysCredential mirrors the Win32 CREDENTIALW structure described at the link
// above. Field order and types must match the C layout exactly, because values
// of this type are exchanged with the Win32 credential functions by address.
// NOTE(review): CredentialBlob and Attributes are raw addresses (uintptr), which
// the Go garbage collector does not trace; they are only safe to use for memory
// that Windows owns, or for Go memory that is kept alive by other means.
type sysCredential struct {
	Flags              uint32           // CRED_FLAGS_* bits
	Type               uint32           // CRED_TYPE_* value, e.g. winCredTypeGeneric
	TargetName         *uint16          // NUL-terminated UTF-16 target name
	Comment            *uint16          // optional UTF-16 comment; not used here
	LastWritten        syscall.Filetime // time of the last modification
	CredentialBlobSize uint32           // length of CredentialBlob in bytes
	CredentialBlob     uintptr          // address of the secret bytes
	Persist            uint32           // CRED_PERSIST_* value, e.g. winCredPersistLocalMachine
	AttributeCount     uint32           // number of entries at Attributes
	Attributes         uintptr          // address of a sysCredentialAttribute array
	TargetAlias        *uint16          // not used here
	UserName           *uint16          // not used here
}

// https://docs.microsoft.com/en-us/windows/win32/api/wincred/ns-wincred-credential_attributew
//
// sysCredentialAttribute mirrors the Win32 CREDENTIAL_ATTRIBUTEW structure;
// field order and types must match the C layout exactly.
type sysCredentialAttribute struct {
	Keyword   *uint16 // NUL-terminated UTF-16 attribute name
	Flags     uint32  // reserved per the Win32 docs
	ValueSize uint32  // length of Value in bytes
	Value     uintptr // address of the value bytes (not traced by the GC)
}

// NewService returns a Credential Manager backend whose credentials are
// namespaced by collection (which may be empty). The error is always nil.
func NewService(collection string) (*Service, error) {
	svc := &Service{collection: collection}
	return svc, nil
}

// Session opens a session bound to the same collection as the service.
// The error is always nil.
func (s *Service) Session() (chaintypes.Session, error) {
	sess := &Session{collection: s.collection}
	return sess, nil
}

// Close is a no-op: a Service holds nothing but its collection name.
func (s *Service) Close() {
	// Nothing to release.
}

// FetchKeyPair loads the credential stored for keyType and decodes its blob,
// which SaveKeyPair writes as JSON, into a KeyPair. On any error the zero
// KeyPair is returned.
func (t *Session) FetchKeyPair(keyType keyring.KeyPurpose) (chaintypes.KeyPair, error) {
	cred, err := t.fetchCredential(keyType)
	if err != nil {
		return chaintypes.KeyPair{}, fmt.Errorf("failed to fetch credential: %w", err)
	}

	var keyPair chaintypes.KeyPair
	if err := json.Unmarshal(cred.CredentialBlob, &keyPair); err != nil {
		return chaintypes.KeyPair{}, fmt.Errorf("failed to unmarshal keypair: %w", err)
	}
	return keyPair, nil
}

// fetchCredential reads the generic credential stored for keyType via
// CredReadW and returns a Go-owned copy of it. The Windows-allocated buffer is
// released with CredFree before returning, and nothing in the result aliases
// it: strings and byte slices are copied out.
//
// It returns an error when the target name cannot be encoded as UTF-16
// (it contains a NUL) and, unwrapped, the error produced by CredReadW when the
// call fails (per the Win32 docs, ERROR_NOT_FOUND when no such credential
// exists).
func (t *Session) fetchCredential(keyType keyring.KeyPurpose) (*Credential, error) {
	target, err := windows.UTF16PtrFromString(t.targetNameForKey(keyType))
	if err != nil {
		return nil, fmt.Errorf("invalid credential target name: %w", err)
	}

	var sysCred *sysCredential
	ret, _, err := procCredRead.Call(
		uintptr(unsafe.Pointer(target)),
		uintptr(winCredTypeGeneric),
		0,
		uintptr(unsafe.Pointer(&sysCred)),
	)
	if ret == 0 {
		return nil, err
	}
	defer func() { _, _, _ = procCredFree.Call(uintptr(unsafe.Pointer(sysCred))) }()

	out := Credential{
		TargetName:     windows.UTF16PtrToString(sysCred.TargetName),
		CredentialBlob: goBytes(sysCred.CredentialBlob, sysCred.CredentialBlobSize),
		Attributes:     make([]CredentialAttribute, sysCred.AttributeCount),
	}

	if sysCred.AttributeCount > 0 && sysCred.Attributes != 0 {
		// Attributes is the raw address of a Windows-owned array (valid until
		// the deferred CredFree). Reinterpret the field in place as a typed
		// pointer rather than converting the uintptr value directly.
		first := *(**sysCredentialAttribute)(unsafe.Pointer(&sysCred.Attributes))
		sysAttrs := unsafe.Slice(first, sysCred.AttributeCount)
		for i := range sysAttrs {
			out.Attributes[i] = CredentialAttribute{
				Keyword: windows.UTF16PtrToString(sysAttrs[i].Keyword),
				Value:   goBytes(sysAttrs[i].Value, sysAttrs[i].ValueSize),
			}
		}
	}

	return &out, nil
}

// SaveKeyPair stores keyPair for keyType as a JSON-encoded generic credential
// tagged with "service" and "key-type" attributes. The credential's target
// name is derived from keyType and the session's collection.
func (t *Session) SaveKeyPair(keyType keyring.KeyPurpose, keyPair chaintypes.KeyPair) error {
	secBlob, err := json.Marshal(keyPair)
	if err != nil {
		// The value being encoded is the key pair (mirrors FetchKeyPair's
		// "unmarshal keypair" message), not a certificate.
		return fmt.Errorf("failed to marshal keypair: %w", err)
	}

	cred := &Credential{
		TargetName:     t.targetNameForKey(keyType),
		CredentialBlob: secBlob,
		Attributes: []CredentialAttribute{
			{Keyword: "service", Value: []byte("skotty")},
			{Keyword: "key-type", Value: []byte(keyType.String())},
		},
	}

	if err := t.saveCredential(cred); err != nil {
		return fmt.Errorf("failed to save credential: %w", err)
	}
	return nil
}

// sysWriteCredential mirrors the Win32 CREDENTIALW layout (same field order
// and sizes as sysCredential) but holds the blob and attribute array as typed
// Go pointers instead of uintptr. The garbage collector traces typed pointers
// and escape analysis sees them, so Go memory referenced from here stays
// allocated and in place while CredWriteW reads it.
type sysWriteCredential struct {
	Flags              uint32
	Type               uint32
	TargetName         *uint16
	Comment            *uint16
	LastWritten        syscall.Filetime
	CredentialBlobSize uint32
	CredentialBlob     *byte
	Persist            uint32
	AttributeCount     uint32
	Attributes         *sysWriteCredentialAttribute
	TargetAlias        *uint16
	UserName           *uint16
}

// sysWriteCredentialAttribute mirrors CREDENTIAL_ATTRIBUTEW with a typed
// Value pointer, for the same reason as sysWriteCredential.
type sysWriteCredentialAttribute struct {
	Keyword   *uint16
	Flags     uint32
	ValueSize uint32
	Value     *byte
}

// saveCredential writes cred via CredWriteW as a generic credential persisted
// with winCredPersistLocalMachine.
//
// The original version stored Go addresses in uintptr fields. Such fields do
// not keep their targets alive, and those targets could even be stack memory.
// Here the struct passed to CredWriteW is kept alive and on the heap for the
// duration of the call, because LazyProc.Call is //go:uintptrescapes.
// Everything it references through typed pointers comes along with it.
//
// It returns an error if the target name or an attribute keyword contains a
// NUL. If CredWriteW fails, its error is returned unwrapped.
func (t *Session) saveCredential(cred *Credential) error {
	targetName, err := windows.UTF16PtrFromString(cred.TargetName)
	if err != nil {
		return fmt.Errorf("invalid credential target name %q: %w", cred.TargetName, err)
	}

	sysCred := &sysWriteCredential{
		Type:               winCredTypeGeneric,
		TargetName:         targetName,
		Persist:            winCredPersistLocalMachine,
		CredentialBlobSize: uint32(len(cred.CredentialBlob)),
		AttributeCount:     uint32(len(cred.Attributes)),
	}
	if len(cred.CredentialBlob) > 0 {
		sysCred.CredentialBlob = &cred.CredentialBlob[0]
	}

	if len(cred.Attributes) > 0 {
		attributes := make([]sysWriteCredentialAttribute, len(cred.Attributes))
		for i := range cred.Attributes {
			attr := &cred.Attributes[i]
			keyword, err := windows.UTF16PtrFromString(attr.Keyword)
			if err != nil {
				return fmt.Errorf("invalid credential attribute keyword %q: %w", attr.Keyword, err)
			}
			attributes[i].Keyword = keyword
			attributes[i].ValueSize = uint32(len(attr.Value))
			if len(attr.Value) > 0 {
				attributes[i].Value = &attr.Value[0]
			}
		}
		sysCred.Attributes = &attributes[0]
	}

	ret, _, err := procCredWrite.Call(uintptr(unsafe.Pointer(sysCred)), 0)
	if ret == 0 {
		return err
	}
	return nil
}

// Close is a no-op: a Session holds nothing but its collection name.
func (t *Session) Close() {
	// Nothing to release.
}

// targetNameForKey builds the Credential Manager target name for keyType:
// "<keyType>@skotty", followed by "@<collection>" when the session has a
// non-empty collection.
func (t *Session) targetNameForKey(keyType keyring.KeyPurpose) string {
	name := fmt.Sprintf("%s@skotty", keyType)
	if t.collection != "" {
		name += "@" + t.collection
	}
	return name
}

// goBytes copies size bytes starting at the raw address src into a freshly
// allocated Go slice. The result never aliases the source. In this file it is
// used on buffers returned by CredReadW, which are freed afterwards. A zero src
// yields an empty, non-nil slice regardless of size.
func goBytes(src uintptr, size uint32) []byte {
	if src == 0 {
		return []byte{}
	}

	// Build a temporary view over the foreign buffer; it is only read by copy.
	// NOTE(review): reflect.SliceHeader is deprecated since Go 1.20;
	// unsafe.Slice is the modern way to build such a view.
	var view []byte
	hdr := (*reflect.SliceHeader)(unsafe.Pointer(&view))
	hdr.Data = src
	hdr.Len = int(size)
	hdr.Cap = int(size)

	out := make([]byte, size)
	copy(out, view)
	return out
}
