package decryptor

import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"encoding/hex"
	"errors"
	"strconv"

	"a.yandex-team.ru/passport/infra/daemons/yasmsd/internal/logs"
)

/*
Logic for working with encrypted messages: detecting, parsing and decrypting them.
*/

// Decryptor turns encrypted message texts back into plaintext (see DecryptText).
type Decryptor struct {
	// keyring supplies the decryption keys, looked up by key id.
	keyring *Keyring

	// logs receives diagnostics about messages that fail to decrypt.
	logs *logs.Logs
}

// EncryptedData holds the fields of a parsed encrypted message.
type EncryptedData struct {
	// keyID selects the keyring entry the message must be decrypted with.
	keyID int

	// iv is the GCM nonce, body the ciphertext and tag the GCM authentication
	// tag; all three are already hex-decoded from the textual format.
	iv, body, tag []byte
}

// NewDecryptor returns a Decryptor that takes its keys from keyring and
// reports problems through logs.
func NewDecryptor(keyring *Keyring, logs *logs.Logs) *Decryptor {
	d := new(Decryptor)
	d.keyring = keyring
	d.logs = logs
	return d
}

// looksLikeEncrypted is a cheap shape check that tells encrypted payloads
// apart from plain-text messages. It reports true only when text starts with
// "1:", is at least 64 bytes long, contains exactly four ':' separators and
// otherwise consists solely of hex digits (either case).
func looksLikeEncrypted(text []byte) bool {
	if len(text) < 64 || !bytes.HasPrefix(text, []byte("1:")) {
		return false
	}

	separators := 0
	for _, b := range text {
		switch {
		case b == ':':
			separators++
		case '0' <= b && b <= '9', 'a' <= b && b <= 'f', 'A' <= b && b <= 'F':
			// hex digit: acceptable
		default:
			// Anything else, including any non-ASCII byte, rules it out.
			return false
		}
	}

	return separators == 4
}

// decodeHexField hex-decodes one field of an encrypted message.
//
// name identifies the field in the error message. Upper- and lower-case hex
// digits are both accepted. An empty field, an odd number of digits or a
// non-hex character yields the error "bad <name> format: <src>".
func decodeHexField(src []byte, name string) ([]byte, error) {
	dst := make([]byte, hex.DecodedLen(len(src)))
	n, err := hex.Decode(dst, src)
	// Test the number of bytes actually decoded, not the size of the scratch
	// buffer, so the "field must not be empty" rule does not silently rely on
	// hex.Decode filling dst completely on success.
	if err != nil || n == 0 {
		return nil, errors.New("bad " + name + " format: " + string(src))
	}

	return dst[:n], nil
}

// parseEncrypted splits text of the form
// "<version>:<keyID>:<iv hex>:<body hex>:<tag hex>" into an EncryptedData.
//
// Only version "1" is accepted. keyID must be a (possibly signed) base-10
// integer that fits in 32 bits. The iv, body and tag fields are hex-decoded
// with decodeHexField, in that order, and the first failure is returned.
func parseEncrypted(text []byte) (*EncryptedData, error) {
	fields := bytes.Split(text, []byte(":"))
	if len(fields) != 5 {
		return nil, errors.New("bad field count")
	}

	version, rawKeyID := string(fields[0]), string(fields[1])
	if version != "1" {
		return nil, errors.New("bad version: " + version)
	}

	keyID, err := strconv.ParseInt(rawKeyID, 10, 32)
	if err != nil {
		return nil, errors.New("bad key id: " + rawKeyID)
	}

	// Decode in wire order so the first malformed field is the one reported.
	names := [...]string{"iv", "body", "tag"}
	var decoded [3][]byte
	for i, name := range names {
		if decoded[i], err = decodeHexField(fields[2+i], name); err != nil {
			return nil, err
		}
	}

	result := &EncryptedData{
		keyID: int(keyID),
		iv:    decoded[0],
		body:  decoded[1],
		tag:   decoded[2],
	}
	return result, nil
}

// DecryptText returns the plaintext of an encrypted message text.
//
// text is expected in the form "1:<keyID>:<iv hex>:<body hex>:<tag hex>".
// For now, text that does not look encrypted (see looksLikeEncrypted) is
// returned unchanged with a nil error. Otherwise the fields are parsed, the
// key is fetched from the keyring by key id and the message is opened with
// AES-GCM, using phone as additional authenticated data: authentication
// fails unless phone matches the value used when the message was sealed.
// rowID is used only to tag log lines.
//
// Parse failures, unknown key ids and a wrong iv length are reported with
// generic errors (details go to the log). Cipher construction and
// authentication errors are returned as produced by crypto/aes and
// crypto/cipher.
func (d *Decryptor) DecryptText(text []byte, phone string, rowID uint64) ([]byte, error) {
	if !looksLikeEncrypted(text) {
		// TODO: for now we allow plain text messages, later will be error
		return text, nil
	}

	data, err := parseEncrypted(text)
	if err != nil {
		d.logs.General.WriteDebug(logs.ComponentDecryptor, "smsid=%d failed to parse encrypted text: %s", rowID, err)
		return nil, errors.New("failed to parse encrypted text")
	}

	key := d.keyring.getKey(data.keyID)
	if key == nil {
		d.logs.General.WriteWarning(logs.ComponentDecryptor, "smsid=%d key id %d not found in keyring", rowID, data.keyID)
		return nil, errors.New("couldn't decrypt message, key id not found")
	}

	block, err := aes.NewCipher(key)
	if err != nil {
		d.logs.General.WriteWarning(logs.ComponentDecryptor, "smsid=%d failed to create AES Cipher: %s", rowID, err)
		return nil, err
	}

	aesgcm, err := cipher.NewGCM(block)
	if err != nil {
		d.logs.General.WriteWarning(logs.ComponentDecryptor, "smsid=%d failed to create GCM Cipher: %s", rowID, err)
		return nil, err
	}

	// The GCM implementation of Open panics when the nonce length differs
	// from NonceSize(), so a malformed iv must be rejected here instead of
	// taking the whole process down.
	if len(data.iv) != aesgcm.NonceSize() {
		d.logs.General.WriteWarning(logs.ComponentDecryptor, "smsid=%d bad iv length %d, expected %d", rowID, len(data.iv), aesgcm.NonceSize())
		return nil, errors.New("couldn't decrypt message, bad iv length")
	}

	// Open expects the authentication tag appended to the ciphertext.
	ciphertext := make([]byte, 0, len(data.body)+len(data.tag))
	ciphertext = append(ciphertext, data.body...)
	ciphertext = append(ciphertext, data.tag...)

	plaintext, err := aesgcm.Open(nil, data.iv, ciphertext, []byte(phone))
	if err != nil {
		d.logs.General.WriteWarning(logs.ComponentDecryptor, "smsid=%d failed to open encrypted message: %s", rowID, err)
		return nil, err
	}

	return plaintext, nil
}
