import binascii
import hashlib
import random
import secrets

from Crypto.Cipher import AES
from Crypto.Util import Padding


class Encrypter:
    """
    Password-based symmetric encryption mirroring the Java ``Encrypter``:
    https://a.yandex-team.ru/arc_vcs/direct/libs/utils/src/main/java/ru/yandex/direct/utils/crypt/Encrypter.java?rev=r7782876#L32

    Binary payload layout (hex-encoded by :meth:`encrypt`)::

        b"Salted__" || salt (SALT_LEN bytes) || AES-256-CBC(PKCS#7-padded plaintext)

    The AES key and IV are derived from ``secret`` and ``salt`` by chaining
    MD5 digests (see :meth:`get_key_and_iv`), the scheme OpenSSL calls
    ``EVP_BytesToKey`` with MD5 and a single iteration.

    NOTE(review): a single-round MD5 KDF is weak against brute force of a
    low-entropy ``secret``; it is presumably kept for compatibility with the
    Java counterpart — confirm before changing.
    """
    SALTED_PREFIX = b"Salted__"
    SALT_LEN = 8
    IV_LEN = 16
    KEY_LEN = 32

    def __init__(self, secret: bytes):
        """
        :param secret: shared passphrase used to derive the key and IV.
        :raises TypeError: if ``secret`` is not ``bytes``.
        """
        if not isinstance(secret, bytes):
            raise TypeError("secret must be bytes")

        self.secret = secret

    def encrypt(self, text: str) -> bytes:
        """
        Encrypt ``text`` (UTF-8 encoded) under a freshly generated salt.

        :param text: plaintext string.
        :returns: the hex-encoded ``Salted__`` payload as ASCII ``bytes``.
        """
        msg = text.encode("utf-8")
        # The salt is not secret, but it must be unique and unpredictable:
        # a repeated salt reuses the same key/IV pair. ``random`` is a
        # predictable Mersenne Twister PRNG; ``secrets`` uses the OS CSPRNG.
        salt = secrets.token_bytes(self.SALT_LEN)
        return binascii.hexlify(self.encrypt_bytes(msg, self.secret, salt))

    def decrypt(self, encrypted) -> str:
        """
        Decrypt a hex-encoded payload produced by :meth:`encrypt`.

        :param encrypted: hex-encoded payload (``str`` or ``bytes``).
        :returns: the decrypted text, decoded as UTF-8.
        :raises ValueError: if ``encrypted`` is not valid hex, lacks the
            ``Salted__`` prefix, is too short to contain a salt, or cannot be
            decrypted/unpadded/decoded (typically a wrong secret).
        """
        encrypted_bytes = binascii.unhexlify(encrypted)

        if not encrypted_bytes.startswith(self.SALTED_PREFIX):
            raise ValueError("encrypted message must start with '{}'".format(
                self.SALTED_PREFIX.decode("ascii")))

        prefix_and_salt_len = len(self.SALTED_PREFIX) + self.SALT_LEN
        # Reject truncated payloads here with a clear message instead of
        # letting a short salt fail later inside get_key_and_iv.
        if len(encrypted_bytes) < prefix_and_salt_len:
            raise ValueError("encrypted message is too short to contain a salt")

        salt = encrypted_bytes[len(self.SALTED_PREFIX):prefix_and_salt_len]
        encrypted_msg = encrypted_bytes[prefix_and_salt_len:]
        return self.decrypt_bytes(encrypted_msg, self.secret, salt).decode("utf-8")

    @classmethod
    def get_key_and_iv(cls, secret: bytes, salt: bytes):
        """
        Derive the AES key and IV from ``secret`` and ``salt``.

        Computes ``D_i = MD5(D_{i-1} || secret || salt)`` with ``D_0 = b""``
        until ``KEY_LEN + IV_LEN`` bytes are available. The first ``KEY_LEN``
        bytes are the key and the next ``IV_LEN`` bytes are the IV.

        :returns: ``(key, iv)``.
        :raises ValueError: if ``salt`` is not exactly ``SALT_LEN`` bytes.
        """
        if len(salt) != cls.SALT_LEN:
            raise ValueError("salt must be of size '{}'".format(cls.SALT_LEN))

        desired_len = cls.KEY_LEN + cls.IV_LEN
        data = b""
        chunk = b""
        while len(data) < desired_len:
            chunk = cls._md5_hash(chunk + secret + salt)
            data += chunk

        return data[:cls.KEY_LEN], data[cls.KEY_LEN:desired_len]

    @staticmethod
    def _md5_hash(data: bytes) -> bytes:
        """Return the raw 16-byte MD5 digest of ``data``."""
        return hashlib.md5(data).digest()

    @classmethod
    def encrypt_bytes(cls, msg: bytes, secret: bytes, salt: bytes) -> bytes:
        """
        Encrypt raw ``msg`` with AES-CBC under the key/IV derived from
        ``secret`` and ``salt``. The message is PKCS#7-padded first.

        :returns: ``SALTED_PREFIX + salt + ciphertext``.
        """
        key, iv = cls.get_key_and_iv(secret, salt)
        cipher = AES.new(key, AES.MODE_CBC, iv)
        return cls.SALTED_PREFIX + salt + cipher.encrypt(Padding.pad(msg, AES.block_size))

    @classmethod
    def decrypt_bytes(cls, encrypted_msg: bytes, secret: bytes, salt: bytes) -> bytes:
        """
        Decrypt raw AES-CBC ``encrypted_msg``, which must not include the
        prefix or the salt, and strip its PKCS#7 padding.

        :returns: the plaintext bytes.
        :raises ValueError: from pycryptodome if the ciphertext length is not
            a multiple of the block size or the padding is invalid.
        """
        key, iv = cls.get_key_and_iv(secret, salt)
        decipher = AES.new(key, AES.MODE_CBC, iv)
        return Padding.unpad(decipher.decrypt(encrypted_msg), AES.block_size)
