package encrypt

import (
	"bytes"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"strings"

	"github.com/google/tink/go/aead/subtle"
	"github.com/google/tink/go/keyset"
	"github.com/google/tink/go/streamingaead"
	"golang.org/x/crypto/chacha20poly1305"
)

// NewKey generates a new streaming-AEAD keyset from Tink's
// AES256GCMHKDF4KBKeyTemplate together with a fresh XChaCha20-Poly1305 master
// key. It encrypts the keyset (binary format) under that master key and returns
// both as one string:
// base64(encrypted keyset) + ":" + base64(master key). Pass the string to
// ReadKey to recover the keyset handle.
//
// NOTE(review): the master key travels in the same string as the keyset it
// encrypts, so anyone holding the string can decrypt the keyset; treat the
// whole value as secret.
func NewKey() (string, error) {
	kek, err := subtle.NewXChaCha20Poly1305(GetRandomBytes(chacha20poly1305.KeySize))
	if err != nil {
		return "", fmt.Errorf("can't create master key: %w", err)
	}

	handle, err := keyset.NewHandle(streamingaead.AES256GCMHKDF4KBKeyTemplate())
	if err != nil {
		return "", fmt.Errorf("can't create keyset: %w", err)
	}

	var sealed bytes.Buffer
	writer := keyset.NewBinaryWriter(&sealed)
	// Handle.Write takes the master AEAD and writes the keyset encrypted with it.
	if err = handle.Write(writer, kek); err != nil {
		return "", fmt.Errorf("failed to write keyset: %w", err)
	}

	return mergeKey(sealed.Bytes(), kek.Key), nil
}

// ReadKey parses a key string in the "base64(encrypted keyset):base64(master
// key)" form produced by NewKey. It rebuilds the XChaCha20-Poly1305 master key
// and uses it to decrypt the keyset, returning the resulting handle.
//
// Errors from parsing the string are returned unchanged. Errors from building
// the master key or decrypting the keyset are wrapped with %w, so callers can
// inspect the cause with errors.Is / errors.As.
func ReadKey(key string) (*keyset.Handle, error) {
	rawKs, rawMk, err := splitKey(key)
	if err != nil {
		return nil, err
	}

	masterKey, err := subtle.NewXChaCha20Poly1305(rawMk)
	if err != nil {
		return nil, fmt.Errorf("can't create master key: %w", err)
	}

	ks, err := keyset.Read(keyset.NewBinaryReader(bytes.NewReader(rawKs)), masterKey)
	if err != nil {
		// Use %w (not %v) so the underlying cause stays in the error chain,
		// matching the other wrapped errors in this package.
		return nil, fmt.Errorf("failed to decrypt keyset: %w", err)
	}

	return ks, nil
}

// GetRandomBytes returns a newly allocated slice of n bytes filled from
// crypto/rand. It panics if the random source reports an error.
func GetRandomBytes(n uint32) []byte {
	out := make([]byte, n)
	if _, readErr := rand.Read(out); readErr != nil {
		// The system randomness source failing is treated as unrecoverable.
		panic(readErr)
	}
	return out
}

// mergeKey encodes the encrypted keyset and the raw master key with standard
// base64 and joins them with a single ':' separator.
func mergeKey(ks, masterKey []byte) string {
	enc := base64.StdEncoding
	return strings.Join([]string{enc.EncodeToString(ks), enc.EncodeToString(masterKey)}, ":")
}

// splitKey is the inverse of mergeKey. The input must contain exactly one ':'.
// The text before it is decoded as the encrypted keyset and the text after it
// as the master key, both using standard base64. On any failure both slices
// are nil.
func splitKey(merged string) ([]byte, []byte, error) {
	if strings.Count(merged, ":") != 1 {
		return nil, nil, errors.New("invalid key format")
	}

	sep := strings.IndexByte(merged, ':')
	keysetPart, masterPart := merged[:sep], merged[sep+1:]

	keysetBytes, err := base64.StdEncoding.DecodeString(keysetPart)
	if err != nil {
		return nil, nil, fmt.Errorf("can't decode keyset: %w", err)
	}

	masterBytes, err := base64.StdEncoding.DecodeString(masterPart)
	if err != nil {
		return nil, nil, fmt.Errorf("can't decode master key: %w", err)
	}

	return keysetBytes, masterBytes, nil
}
