package passutil

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"strings"
	"sync"

	"gopkg.in/yaml.v2"

	"a.yandex-team.ru/security/libs/go/machineid"
)

const (
	// appID is the HMAC message combined with the machine ID in AESKey to
	// derive the encryption key.
	appID = "skotty"

	// encPrefix marks a serialized SecretVal as encrypted; values without it
	// are treated as plaintext by FromAny.
	encPrefix = "secret:"
)

// Lazily derived key material. AESKey fills these exactly once via
// aesKeyOnce; aesKeyErr holds the derivation error, if any.
var (
	aesKeyOnce sync.Once
	aesKey     []byte
	aesKeyHash string
	aesKeyErr  error
)

// Compile-time guarantees that SecretVal plugs into yaml.v2's custom
// (un)marshalling hooks, so secrets are encrypted on write and decrypted on
// read.
var (
	_ yaml.Unmarshaler = (*SecretVal)(nil)
	_ yaml.Marshaler   = (*SecretVal)(nil)
)

// SecretVal is a string that is written to YAML in encrypted form (see
// MarshalYAML/ToEncrypted) and accepts either plaintext or the encrypted
// "secret:"-prefixed form when read back (see UnmarshalYAML/FromAny).
type SecretVal string

// String returns the plaintext value.
//
// NOTE: because SecretVal implements fmt.Stringer, formatting it with
// %v/%s prints the secret in clear text; avoid logging it directly.
func (s SecretVal) String() string {
	plain := string(s)
	return plain
}

// MarshalYAML implements yaml.Marshaler by emitting the encrypted form of s
// produced by ToEncrypted; an empty SecretVal is emitted as "".
func (s SecretVal) MarshalYAML() (interface{}, error) {
	encrypted, err := s.ToEncrypted()
	return encrypted, err
}

// UnmarshalYAML implements yaml.Unmarshaler: it decodes the YAML node as a
// string and hands it to FromAny, which accepts plaintext or the encrypted
// "secret:" form.
func (s *SecretVal) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var raw string
	err := unmarshal(&raw)
	if err != nil {
		return err
	}
	return s.FromAny(raw)
}

// ToEncrypted returns the serialized, encrypted form of s:
//
//	secret:<hex(nonce || AES-GCM ciphertext)>:<hex(sha256(key))>
//
// The trailing key fingerprint lets FromAny detect that a value was produced
// with a different key (e.g. after a machine-id change) before attempting
// decryption. An empty SecretVal serializes to "" without encryption.
func (s SecretVal) ToEncrypted() (string, error) {
	if len(s) == 0 {
		return "", nil
	}

	key, keyHash, err := AESKey()
	if err != nil {
		return "", fmt.Errorf("failed to generate secret key: %w", err)
	}

	data, err := Encrypt([]byte(s), key)
	if err != nil {
		return "", fmt.Errorf("failed to encrypt secret value: %w", err)
	}

	return encPrefix + hex.EncodeToString(data) + ":" + keyHash, nil
}

// FromAny sets s from v, which is either a plaintext value or the encrypted
// form produced by ToEncrypted ("secret:<hex ciphertext>:<key hash>").
//
// Values without the "secret:" prefix are stored verbatim. Encrypted values
// are rejected when the envelope is malformed, when the embedded key hash
// does not match the current key (possible machine-id change), when the
// ciphertext is not valid hex, or when decryption fails. s is modified only
// on success.
func (s *SecretVal) FromAny(v string) error {
	if !strings.HasPrefix(v, encPrefix) {
		*s = SecretVal(v)
		return nil
	}
	payload := v[len(encPrefix):]

	// Check the envelope shape first; it needs no key material. An empty
	// ciphertext part (separator at index 0) is also invalid.
	sep := strings.Index(payload, ":")
	if sep <= 0 {
		return errors.New("invalid secret val format")
	}
	cryptedVal, expectedKey := payload[:sep], payload[sep+1:]

	key, keyHash, err := AESKey()
	if err != nil {
		return fmt.Errorf("failed to generate secret key: %w", err)
	}

	if expectedKey != keyHash {
		return fmt.Errorf("possible machine-id change: %s != %s", expectedKey, keyHash)
	}

	ciphertext, err := hex.DecodeString(cryptedVal)
	if err != nil {
		return fmt.Errorf("failed to decode secret value: %w", err)
	}

	val, err := Decrypt(ciphertext, key)
	if err != nil {
		return fmt.Errorf("failed to decrypt secret value: %w", err)
	}

	*s = SecretVal(val)
	return nil
}

// AESKey returns the 32-byte AES key used for SecretVal encryption together
// with the hex-encoded SHA-256 fingerprint of that key.
//
// The key is HMAC-SHA256 keyed by the machine ID (as reported by
// machineid.ID) over appID, so it changes whenever the machine ID changes.
// It is derived once per process via sync.Once; the outcome, including any
// error from machineid.ID, is cached, so a failure is not retried.
//
// Each call returns a fresh copy of the key. Callers may therefore wipe or
// modify it without corrupting the cached key used by later calls. On error
// the returned key is nil.
func AESKey() ([]byte, string, error) {
	aesKeyOnce.Do(func() {
		machineID, err := machineid.ID()
		if err != nil {
			aesKeyErr = err
			return
		}

		mac := hmac.New(sha256.New, []byte(machineID))
		// hash.Hash.Write is documented to never return an error.
		_, _ = mac.Write([]byte(appID))
		aesKey = mac.Sum(nil)
		keySum := sha256.Sum256(aesKey)
		aesKeyHash = hex.EncodeToString(keySum[:])
	})

	// Copy so callers never alias the package-level cached key; when
	// derivation failed aesKey is nil and the copy is nil as well.
	return append([]byte(nil), aesKey...), aesKeyHash, aesKeyErr
}

// Encrypt seals plaintext with AES-GCM under key. Keys that are not exactly
// 32 bytes are first hashed to 32 bytes by prepareKey. A fresh random nonce is
// drawn for every call, and the result is laid out as nonce followed by the
// sealed ciphertext. This is exactly the layout Decrypt expects.
func Encrypt(plaintext []byte, key []byte) ([]byte, error) {
	block, err := aes.NewCipher(prepareKey(key))
	if err != nil {
		return nil, err
	}

	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}

	nonce := make([]byte, aead.NonceSize())
	if _, readErr := io.ReadFull(rand.Reader, nonce); readErr != nil {
		return nil, readErr
	}

	// Seal appends its output to the first argument, so the nonce ends up
	// as the prefix of the returned slice.
	sealed := aead.Seal(nonce, nonce, plaintext, nil)
	return sealed, nil
}

// Decrypt reverses Encrypt. It splits the leading nonce off ciphertext and
// opens the remainder with AES-GCM under key, which is normalized by
// prepareKey. It fails if ciphertext is shorter than a nonce, or if GCM
// authentication of the remainder fails.
func Decrypt(ciphertext []byte, key []byte) ([]byte, error) {
	block, err := aes.NewCipher(prepareKey(key))
	if err != nil {
		return nil, err
	}

	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}

	n := aead.NonceSize()
	if len(ciphertext) < n {
		return nil, errors.New("ciphertext too short")
	}

	nonce, sealed := ciphertext[:n], ciphertext[n:]
	return aead.Open(nil, nonce, sealed, nil)
}

// prepareKey normalizes key to a valid AES-256 key length. A 32-byte key
// (sha256.Size) is returned as-is (same backing array). Any other length,
// including empty, is replaced by its SHA-256 digest.
func prepareKey(key []byte) []byte {
	if len(key) != sha256.Size {
		digest := sha256.Sum256(key)
		return digest[:]
	}
	return key
}
