import base64
import os
import random
import StringIO
import struct

import Crypto.Cipher.AES


class CryptoException(Exception):
    """Common base class for every error raised by the `sandbox.common.crypto` module."""


class DecodingException(CryptoException):
    """Raised when the `base64` module fails to decode data passed in for decryption."""


class CipherException(CryptoException):
    """Raised when the `Crypto` cipher fails while decrypting data."""


class UnpackingException(CryptoException):
    """Raised when already-decrypted data cannot be unpacked (e.g. it is too short for the size header)."""


class AES(object):
    """
        Encapsulates methods for working with AES symmetric cipher

        Plaintext layout built by `encrypt` (integers are little-endian uint32)::

            size | [salt] | MAGIC | payload | NUL padding

        where ``size = len(MAGIC) + len(payload)`` (the optional salt is not
        counted) and the buffer is NUL-padded up to a multiple of BLOCK_SIZE.
    """
    # Used both as the length of generated keys (32 bytes, i.e. AES-256) and as
    # the padding granularity of the plaintext. The AES cipher block itself is
    # always 16 bytes; 32 is a multiple of it, so padding to 32 stays valid.
    BLOCK_SIZE = 32
    MAGIC = "DATA"  # marker placed before the payload; used to verify correctness of decryption

    def __init__(self, key=None):
        """
            Initialize AES

            :param key: AES key, if key is None then random key would be generated
        """
        self._key = self.generate_key() if key is None else key
        # NOTE(review): no mode argument is passed, so the library's default mode
        # is used (ECB in PyCrypto 2.x; pycryptodome requires an explicit mode) - confirm.
        self._cipher = Crypto.Cipher.AES.new(self._key)

    @classmethod
    def generate_key(cls):
        """
        Generate a random key of BLOCK_SIZE bytes.

        Uses `os.urandom` (the OS cryptographically secure RNG). The `random`
        module is a Mersenne Twister whose output is predictable, so it must
        not be used to produce key material.

        :return str: random key bytes of length BLOCK_SIZE
        """
        return os.urandom(cls.BLOCK_SIZE)

    @classmethod
    def _generate_salt(cls):
        """
        Return a packed 4-byte salt that is guaranteed to differ from MAGIC.

        `decrypt` detects a salt by checking whether MAGIC immediately follows
        the size field. A salt whose bytes equal MAGIC would therefore be
        mistaken for the marker and the returned payload would be corrupted.
        The salt only needs to vary, not to be secret, so `random` is sufficient.
        """
        while True:
            salt = struct.pack("<I", random.randint(0, 2 ** 31 - 1))
            if salt != cls.MAGIC:
                return salt

    @property
    def key(self):
        """The raw key this instance encrypts and decrypts with."""
        return self._key

    def encrypt(self, data, use_base64=True, use_salt=False):
        """
        Encrypt data with AES algorithm

        :param data: data to encrypt
        :param use_base64: if True, additionally encode result with base64
        :param use_salt: prefix input with salt before encrypting
        :return str: encrypted data
        """
        packed = StringIO.StringIO()

        # The size field covers MAGIC + payload only; the salt is not counted.
        packed.write(struct.pack("<I", len(self.MAGIC) + len(data)))
        if use_salt:
            packed.write(self._generate_salt())
        packed.write(self.MAGIC)
        packed.write(data)
        # Appends 1..BLOCK_SIZE NUL bytes (a full BLOCK_SIZE when already aligned).
        packed.write((self.BLOCK_SIZE - packed.tell() % self.BLOCK_SIZE) * "\0")

        encoded_data = self._cipher.encrypt(packed.getvalue())
        if use_base64:
            encoded_data = base64.b64encode(encoded_data)
        return encoded_data

    def decrypt(self, encrypted_data, use_base64=True):
        """
        Decode data with AES algorithm

        :param encrypted_data: data to decrypt
        :param use_base64: if True, preliminarily decode data with base64
        :return str: decrypted payload, or None when MAGIC is not found where
            expected (e.g. data encrypted with a different key, or corrupted)
        :raises DecodingException: base64 decoding failed
        :raises CipherException: the cipher raised ValueError/TypeError on the input
        :raises UnpackingException: decrypted data is too short for the size field
        """
        if use_base64:
            try:
                encrypted_data = base64.b64decode(encrypted_data)
            except TypeError as exc:
                raise DecodingException(exc)

        try:
            data = self._cipher.decrypt(encrypted_data)
        except (ValueError, TypeError) as ex:
            raise CipherException(ex)

        packer = struct.Struct("<I")
        offset = packer.size
        try:
            (size,) = packer.unpack(data[:offset])
        except struct.error as ex:
            raise UnpackingException(ex)
        # Unsalted data has MAGIC right after the size field; otherwise skip the 4-byte salt.
        if not data.startswith(self.MAGIC, offset):
            offset += packer.size
        if not data.startswith(self.MAGIC, offset):
            # NOTE(review): verification failed; callers get None rather than an exception.
            return None
        offset += len(self.MAGIC)
        # size = len(MAGIC) + len(payload), so the payload length is size - len(MAGIC).
        return data[offset:offset + size - len(self.MAGIC)]
