import base64
import datetime
from enum import IntEnum
import logging

from cryptography.exceptions import (
    InvalidKey,
    InvalidSignature,
    InvalidTag,
)
import lz4.frame
from passport.backend.core.crypto import aes_gcm
from passport.backend.core.lazy_loader import (
    lazy_loadable,
    LazyLoader,
)
from passport.backend.vault.api.errors import DecryptionError


logger = logging.getLogger(__name__)  # Module logger; KeyManager warns here when it falls back to the default key.


class ValueType(IntEnum):
    """Tag identifying how an encoded value was produced.

    The integer ``.value`` of a member (not the member itself) is what the
    encoders return and what ``ValueManager.decode`` uses to pick a decoder.
    NOTE(review): these numbers are presumably persisted next to the
    ciphertext -- if so, they must never be renumbered.
    """
    AES = 0  # AES-GCM over the raw plaintext
    COMPRESSED_AES = 1  # AES-GCM over an LZ4-frame-compressed plaintext


class KeyManager(object):
    """Hold indexed encryption keys and select one by index.

    ``keys`` is an iterable of ``(index, key)`` pairs. The index of the first
    pair is the default: it is used whenever a requested index has no key.
    """

    def __init__(self, keys):
        # Materialize once: the pairs are read twice below, which would fail
        # (or silently see nothing) for a one-shot iterable such as a generator.
        keys = list(keys)
        if not keys:
            raise ValueError('At least one (index, key) pair is required')
        self._keys = dict(keys)
        self.default_index = keys[0][0]

    def _resolve_index(self, index):
        """Return ``index`` if a key is configured for it, else the default index."""
        if index in self._keys:
            return index
        logger.warning('Not found key with index %s, defaulted to %s', index, self.default_index)
        return self.default_index

    def get_key_for(self, index):
        """Return the key for ``index``, falling back to the default key with a warning."""
        return self._keys[self._resolve_index(index)]

    def get_current_keypair(self):
        """Return ``(index, key)`` for the current local year and month (``YYYYMM``).

        If no key is configured for the current month, the default key is used
        and the *default* index is returned with it. The returned index
        therefore always names the key that actually encrypts the value.
        Returning the month index instead would make such values
        undecryptable once a key for that month is added.
        """
        now = datetime.datetime.now()
        index = self._resolve_index('%d%02d' % (now.year, now.month))
        return index, self._keys[index]


class AES(KeyManager):
    """AES-GCM value codec using keys selected by index.

    The encoded form is four ``.``-separated URL-safe base64 fields,
    ``iv.ciphertext.associated_data.auth_tag``. :meth:`encode` always writes
    empty associated data.
    """

    def decode(self, index, value):
        """Decrypt an encoded ``value`` with the key registered under ``index``.

        ``value`` may be text (UTF-8 encoded first) or bytes. Returns whatever
        ``aes_gcm.decrypt`` produces.

        Raises:
            DecryptionError: if ``value`` does not split into exactly four
                fields, a field cannot be base64-decoded, or decryption or
                authentication fails.
        """
        try:
            # Work in bytes so splitting and base64 handling behave the same
            # on Python 2 (where bytes is str) and Python 3.
            raw = value if isinstance(value, bytes) else value.encode('utf-8')
            iv, ciphertext, associated_data, auth_tag = raw.split(b'.')
        except ValueError:
            raise DecryptionError()

        try:
            iv, ciphertext, associated_data, auth_tag = [
                base64.urlsafe_b64decode(field)
                for field in (iv, ciphertext, associated_data, auth_tag)
            ]
        except (TypeError, ValueError) as ex:
            # Python 2 reports bad base64 as TypeError; Python 3 raises
            # binascii.Error (a ValueError subclass). str(ex) replaces the
            # deprecated ex.message, which no longer exists on Python 3.
            raise DecryptionError(message=str(ex))

        try:
            key = self.get_key_for(index)
            return aes_gcm.decrypt(key, associated_data, iv, ciphertext, auth_tag)
        except (InvalidKey, InvalidTag, InvalidSignature) as ex:
            raise DecryptionError(message=str(ex))

    def encode(self, value):
        """Encrypt ``value`` with the key from :meth:`get_current_keypair`.

        Returns:
            ``(index, ValueType.AES.value, encoded)``, where ``encoded`` is the
            four-field form accepted by :meth:`decode`.
        """
        index, key = self.get_current_keypair()
        # b'' / b'.' equal '' / '.' on Python 2 and keep everything bytes on Python 3.
        iv, ciphertext, auth_tag = aes_gcm.encrypt(key, value, b'')
        encoded = b'.'.join(
            base64.urlsafe_b64encode(field) for field in (iv, ciphertext, b'', auth_tag)
        )
        return index, ValueType.AES.value, encoded


class CompressedAES(AES):
    """AES-GCM codec whose plaintext is LZ4-frame compressed before encryption."""

    def decode(self, index, value):
        """Decrypt ``value`` through :class:`AES`, then LZ4-decompress the result."""
        compressed_plaintext = super(CompressedAES, self).decode(index, value)
        return lz4.frame.decompress(compressed_plaintext)

    def encode(self, value):
        """LZ4-compress ``value``, encrypt it through :class:`AES`, and tag it COMPRESSED_AES."""
        packed = lz4.frame.compress(value)
        encrypted = super(CompressedAES, self).encode(packed)
        key_index, _plain_type, payload = encrypted
        return key_index, ValueType.COMPRESSED_AES.value, payload


class OptimalAES(AES):
    """Encoder that compresses with LZ4 only when compression saves enough space.

    The compressed form is used when it is shorter than
    ``compressing_threshold`` times the original length. Otherwise the raw
    value is encrypted. This class only encodes; ``ValueManager`` decodes with
    :class:`AES` or :class:`CompressedAES` according to the returned type.
    """

    compressing_threshold = 0.7

    def decode(self, index, value):
        """Unsupported; decode with :class:`AES` or :class:`CompressedAES` instead."""
        raise RuntimeError('Not implemented')  # pragma: no cover

    def encode(self, value):
        """Encrypt ``value``, compressing it first if that pays off.

        Returns:
            ``(index, value_type, encoded)``, where ``value_type`` is
            ``ValueType.COMPRESSED_AES.value`` or ``ValueType.AES.value``.
        """
        packed = lz4.frame.compress(value)
        limit = self.compressing_threshold * len(value)
        if len(packed) >= limit:
            payload, kind = value, ValueType.AES
        else:
            payload, kind = packed, ValueType.COMPRESSED_AES
        key_index, _base_type, encoded = super(OptimalAES, self).encode(payload)
        return key_index, kind.value, encoded


@lazy_loadable()
class ValueManager(object):
    """Encode values with :class:`OptimalAES` and decode them by their stored value type.

    ``keys`` is an iterable of ``(index, key)`` pairs. When it is falsy, the
    pairs are read from ``config['application']['keys']``.
    """

    def __init__(self, config, keys=None):
        if not keys:
            keys = config['application']['keys']
        # The same pairs feed three KeyManager instances; materialize them once
        # so a one-shot iterable is not exhausted by the first constructor.
        keys = list(keys)
        self.encoding_algorithm = OptimalAES(keys)
        self.decoding_algorithms = {
            ValueType.AES.value: AES(keys),
            ValueType.COMPRESSED_AES.value: CompressedAES(keys),
        }

    def encode(self, value):
        """Return ``(index, value_type, encoded)`` from :class:`OptimalAES`."""
        return self.encoding_algorithm.encode(value)

    def decode(self, index, algorithm, value):
        """Decode ``value`` with the decoder registered for ``algorithm``.

        Raises:
            DecryptionError: if ``algorithm`` is not a known ``ValueType``
                value, or if the selected decoder fails.
        """
        try:
            decoder = self.decoding_algorithms[algorithm]
        except KeyError:
            # Surface corrupt or unknown type tags as the module's decoding
            # error rather than a bare KeyError.
            raise DecryptionError(message='Unknown value type: %s' % (algorithm,))
        return decoder.decode(index, value)


def get_value_manager(config):
    """Return the ``ValueManager`` obtained from ``LazyLoader`` for ``config``.

    ``ValueManager`` is registered with ``@lazy_loadable()``. Whether the
    instance is reused across calls is up to ``LazyLoader``, which is not
    visible here.
    """
    manager = LazyLoader.get_instance('ValueManager', config=config)
    return manager
