from typing import Dict

from cryptography.fernet import Fernet, InvalidToken


class CrypterBackend:
    """Interface for a reversible, byte-level encryption backend.

    Subclasses must override both methods; the base implementations only
    raise ``NotImplementedError``. Tokens are ``bytes`` on both sides:
    ``CryptogramCrypter`` calls ``.decode('utf-8')`` on whatever ``encrypt``
    returns and passes ``bytes`` into ``decrypt``.
    """

    def encrypt(self, data: bytes) -> bytes:
        """Encrypt ``data`` and return the resulting token.

        :param data: plaintext bytes.
        :returns: opaque token bytes that ``decrypt`` can invert.
        :raises NotImplementedError: always, unless overridden.
        """
        raise NotImplementedError

    def decrypt(self, token: bytes) -> bytes:
        """Recover the plaintext bytes from a token produced by ``encrypt``.

        :param token: token bytes previously returned by ``encrypt``.
        :returns: the original plaintext bytes.
        :raises NotImplementedError: always, unless overridden.
        """
        raise NotImplementedError


class FernetBackend(CrypterBackend):
    """Fernet-based backend supporting multiple versioned keys.

    Tokens have the form ``b'<kid>.' + <fernet token>``, where ``kid`` is the
    integer key version used for encryption. ``encrypt`` always uses the
    highest key version; ``decrypt`` selects the key named by the token's
    prefix, so tokens made with older (still-registered) keys stay readable.
    """

    def __init__(self, key_versions: Dict[int, str]):
        """
        :param key_versions: mapping of integer key version to a Fernet key
            (passed unchanged to ``Fernet(...)``).
        :raises ValueError: if ``key_versions`` is empty.
        """
        # An explicit raise rather than ``assert``: asserts are stripped under
        # ``python -O``, and this configuration check must always run.
        if not key_versions:
            raise ValueError('Cryptogram FernetBackend misconfigured')
        self.key_versions = key_versions
        self.last_version = max(key_versions.keys())

    def encrypt(self, data: bytes) -> bytes:
        """Encrypt ``data`` with the newest key.

        :param data: plaintext bytes.
        :returns: ``b'<kid>.'`` followed by the Fernet token.
        """
        kid = self.last_version
        fernet = Fernet(self.key_versions[kid])
        ciphertext = fernet.encrypt(data)
        # Prefix the key version so decrypt() can pick the matching key.
        return bytes(str(kid), 'ascii') + b'.' + ciphertext

    def decrypt(self, token: bytes) -> bytes:
        """Decrypt a token produced by ``encrypt``.

        :param token: ``b'<kid>.<fernet token>'``.
        :returns: the plaintext bytes.
        :raises InvalidToken: if the token has no ``.`` separator, the key
            version is not an integer or is unknown, or Fernet rejects the
            ciphertext.
        """
        try:
            kid, ciphertext = token.split(b'.', maxsplit=1)
            key = self.key_versions[int(kid)]
        except (KeyError, ValueError) as err:
            # Normalize malformed/unknown-version tokens to the same error
            # Fernet itself raises, so callers handle one exception type.
            raise InvalidToken from err
        fernet = Fernet(key)
        return fernet.decrypt(ciphertext)


class CryptogramCrypter:
    """Encrypt strings into self-describing cryptograms via pluggable backends.

    A cryptogram has the form ``'<backend_id>.<backend token>'``. New data is
    always encrypted with the backend that has the highest id; decryption
    dispatches on the id prefix, so data encrypted by older backends remains
    readable as long as those backends stay registered.
    """

    def __init__(self, backends: Dict[int, 'CrypterBackend']):
        """
        :param backends: mapping of integer backend id to backend instance.
        :raises ValueError: if ``backends`` is empty.
        """
        # Same exception type max() would raise on an empty mapping, but with
        # a message that points at the misconfiguration.
        if not backends:
            raise ValueError('CryptogramCrypter requires at least one backend')
        self.backends = backends
        self.last_backend_id = max(backends.keys())

    def encrypt(self, data: str) -> str:
        """Encrypt ``data`` with the newest backend.

        :param data: plaintext string (UTF-8 encoded before encryption).
        :returns: ``'<backend_id>.<token>'``.
        """
        backend_id = self.last_backend_id
        token_b = self.backends[backend_id].encrypt(data.encode('utf-8'))
        # Backends return bytes; assumes the token is valid UTF-8.
        token = token_b.decode('utf-8')
        return f'{backend_id}.{token}'

    def decrypt(self, cryptogram: str) -> str:
        """Decrypt a cryptogram produced by ``encrypt``.

        :param cryptogram: ``'<backend_id>.<token>'``.
        :returns: the plaintext string.
        :raises ValueError: if there is no ``.`` separator or the backend id
            is not an integer.
        :raises KeyError: if no backend is registered under that id.
        """
        # Only the first '.' separates the id; the token itself may contain dots.
        backend_id, sep, token = cryptogram.partition('.')
        if not sep:
            raise ValueError('Malformed cryptogram: missing backend id separator')
        backend = self.backends[int(backend_id)]
        data = backend.decrypt(token.encode('utf-8'))
        return data.decode('utf-8')

    def rotate(self, data: str) -> str:
        """Re-encrypt an existing cryptogram with the newest backend.

        :param data: a cryptogram produced by any registered backend.
        :returns: a fresh cryptogram from ``encrypt``.
        """
        return self.encrypt(self.decrypt(data))
