import logging

from .config import ENV_CONFIG
from kmsclient.client import KmsRoundRobinClient

logger = logging.getLogger(__name__)


def _create_kms_client(token):
    """Build a KMS client targeting the configured ``KMS_HOST:KMS_PORT``.

    Args:
        token: Not used when constructing the client. The encrypt/decrypt
            helpers in this module pass the token on each KMS call instead.
            It is kept here so existing callers keep working.

    Returns:
        A new ``KmsRoundRobinClient`` whose only address is
        ``"<KMS_HOST>:<KMS_PORT>"`` taken from ``ENV_CONFIG``. Callers in
        this module use it as a context manager.
    """
    # Build the address once so the logged value and the connected value
    # can never drift apart.
    address = f"{ENV_CONFIG.KMS_HOST}:{ENV_CONFIG.KMS_PORT}"
    # Lazy %-style args: the message is only formatted when DEBUG is enabled.
    logger.debug("KMS with %s", address)
    return KmsRoundRobinClient(addrs=[address])


def decrypt(ciphertext: bytes, token: str, aad: str, key_id: str) -> str:
    """Decrypt *ciphertext* via KMS and return the plaintext as text.

    Args:
        ciphertext: The encrypted payload to decrypt.
        token: Token forwarded to the KMS ``decrypt`` call.
        aad: Additional authenticated data, UTF-8 encoded before it is sent.
            Presumably this must equal the ``aad`` used when encrypting
            (confirm against the KMS contract).
        key_id: Identifier of the KMS key to decrypt with.

    Returns:
        The decrypted bytes decoded as UTF-8.

    Raises:
        Exception: When the KMS ``decrypt`` call or the UTF-8 decode fails.
            The underlying error is chained as ``__cause__``, so its
            traceback is preserved.
    """
    # Keep the str parameter intact and use a separate local for the bytes form.
    aad_bytes = bytes(aad, encoding='utf-8')
    with _create_kms_client(token) as kms:
        try:
            return kms.decrypt(key_id, aad_bytes, ciphertext, token).decode('utf-8')
        except Exception as e:
            # Chain explicitly so the root cause is reported as the cause,
            # not just as incidental context.
            raise Exception(f"Couldn't decrypt {ciphertext}: {e}") from e


def encrypt(plaintext: str, token: str, aad: str, key_id: str) -> bytes:
    """Encrypt *plaintext* via KMS and return the resulting ciphertext.

    Both *plaintext* and *aad* are UTF-8 encoded, in that order, before the
    KMS ``encrypt`` call. Any error raised by the KMS client propagates
    unchanged.

    Args:
        plaintext: Text to encrypt.
        token: Token forwarded to the KMS ``encrypt`` call.
        aad: Additional authenticated data, UTF-8 encoded before it is sent.
        key_id: Identifier of the KMS key to encrypt with.

    Returns:
        Whatever the KMS client's ``encrypt`` returns (annotated as bytes).
    """
    payload = bytes(plaintext, encoding='utf-8')
    associated_data = bytes(aad, encoding='utf-8')
    with _create_kms_client(token) as client:
        return client.encrypt(key_id, associated_data, payload, token)
