package ytc

import (
	"crypto/aes"
	"crypto/cipher"
	"fmt"

	"a.yandex-team.ru/library/go/core/xerrors"
	"a.yandex-team.ru/passport/infra/daemons/historydb_api2/internal/errs"
	"a.yandex-team.ru/passport/infra/libs/go/compressor"
	"a.yandex-team.ru/passport/infra/libs/go/keys"
)

// EncryptedData is an AES-GCM encrypted, optionally compressed payload as read
// from YT; the yson tags give the stored field names.
type EncryptedData struct {
	// Version of the envelope format; only version 1 can be decrypted.
	Version uint64 `yson:"v"`

	// KeyID selects the decryption key via keys.KeyMap.GetKeyNum.
	KeyID uint64 `yson:"keyid"`
	// Iv is the GCM nonce.
	Iv    []byte `yson:"iv"`
	// Text is the ciphertext without the authentication tag.
	Text  []byte `yson:"text"`
	// Tag is the GCM authentication tag; it is appended to Text before opening.
	Tag   []byte `yson:"tag"`

	// Codec, when non-nil, names the compression codec of the decrypted
	// plaintext (one of the keys of codecFromString); nil means uncompressed.
	Codec *string `yson:"codec"`
	// Size is required whenever Codec is set and is handed to
	// compressor.Decompress; presumably the uncompressed length — TODO confirm.
	Size  *uint64 `yson:"size"`
}

// codecFromString translates the codec name stored in EncryptedData.Codec into
// the compressor codec used to decompress the decrypted payload. Any name that
// is missing here is treated as an unsupported codec.
var codecFromString = map[string]compressor.CompressionCodecType{
	"brotli": compressor.Brotli,
	"gzip":   compressor.GZip,
	"zstd":   compressor.ZStd,
}

// Decrypt returns the plaintext of data, using the key that keyMap holds for
// data.KeyID and decompressing the result when data.Codec is set.
//
// Every failure (unsupported version, bad key, failed authentication,
// unknown codec, decompression error) is reported as an *errs.TemporaryError
// with Type errs.DecryptError whose Message embeds the underlying cause.
func (data *EncryptedData) Decrypt(keyMap *keys.KeyMap) ([]byte, error) {
	plain, cause := data.decryptImpl(keyMap)
	if cause == nil {
		return plain, nil
	}

	return nil, &errs.TemporaryError{
		Type:    errs.DecryptError,
		Message: fmt.Sprintf("failed to decrypt encrypted data from yt: %v", cause),
	}
}

// decryptImpl does the actual work behind Decrypt and returns raw errors;
// Decrypt is responsible for converting them into errs.TemporaryError.
//
// Steps: check the envelope version, open the AES-GCM payload with the key
// for data.KeyID, then, if data.Codec is set, decompress with that codec
// (data.Size is mandatory in that case).
func (data *EncryptedData) decryptImpl(keyMap *keys.KeyMap) ([]byte, error) {
	if v := data.Version; v != 1 {
		return nil, xerrors.Errorf("unsupported version: %d", v)
	}

	key := keyMap.GetKeyNum(data.KeyID)
	plain, err := decrypt(key, data.Iv, data.Text, data.Tag)
	if err != nil {
		return nil, err
	}

	// No codec recorded: the plaintext was stored uncompressed.
	if data.Codec == nil {
		return plain, nil
	}

	if data.Size == nil {
		return nil, xerrors.New("missing 'size' for compressed data")
	}

	codecName := *data.Codec
	codec, known := codecFromString[codecName]
	if !known {
		return nil, xerrors.Errorf("unsupported compression codec: %s", codecName)
	}

	return compressor.Decompress(*data.Size, plain, codec)
}

// decrypt authenticates and decrypts an AES-GCM message.
//
// key must be a valid AES key (16, 24 or 32 bytes). iv is the GCM nonce and
// must be exactly the standard GCM nonce size. text is the ciphertext without
// the authentication tag, and tag is the tag that crypto/cipher expects to
// find appended to the ciphertext.
//
// It returns the plaintext. It returns an error when the key is invalid, the
// nonce has the wrong length, or authentication fails. Underlying errors are
// wrapped so callers can inspect them.
func decrypt(key, iv, text, tag []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, xerrors.Errorf("failed to create cipher block: %w", err)
	}
	aesgcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, xerrors.Errorf("failed to create cipher GCM: %w", err)
	}

	// cipher.AEAD.Open panics on a nonce of the wrong length. The IV comes
	// from stored data, so a malformed record must become an error rather
	// than crash the process.
	if len(iv) != aesgcm.NonceSize() {
		return nil, xerrors.Errorf("invalid iv length: got %d, expected %d", len(iv), aesgcm.NonceSize())
	}

	// crypto/cipher expects ciphertext||tag as a single slice.
	sealed := make([]byte, 0, len(text)+len(tag))
	sealed = append(sealed, text...)
	sealed = append(sealed, tag...)

	// sealed is a private copy, so its storage may be reused for the plaintext.
	decrypted, err := aesgcm.Open(sealed[:0], iv, sealed, nil)
	if err != nil {
		return nil, xerrors.Errorf("failed to decrypt blob: %w", err)
	}

	return decrypted, nil
}
