import os
from base64 import b64decode, b64encode
from dataclasses import dataclass
from typing import Dict, Optional

from .base import BlockDecryptor, BlockEncryptor


@dataclass
class KeyVersion:
    """A decoded key together with the version identifier it is stored under."""

    key: bytes  # raw key bytes (base64-decoded from VersionedKeyStorage.versions)
    version: str  # version identifier; the lookup key in VersionedKeyStorage.versions


class VersionedKeyStorage:
    """Collection of base64-encoded keys indexed by version identifier."""

    def __init__(self, versions: Dict[str, str]):
        # Maps version identifier -> base64 (ASCII) encoded key.
        self.versions = versions

    def version(self, version: str) -> KeyVersion:
        """Return the decoded key stored under *version*.

        Raises:
            KeyError: *version* is not present in the storage.
            binascii.Error: the stored value is not valid base64.
        """
        key = b64decode(self.versions[version].encode('ascii'))
        return KeyVersion(key=key, version=version)

    def latest(self) -> KeyVersion:
        """Return the newest key version.

        When every version identifier is a decimal number, versions are
        compared numerically (so ``'10'`` is newer than ``'9'``); otherwise
        they are compared as plain strings.

        Raises:
            ValueError: the storage holds no versions.
        """
        if not self.versions:
            raise ValueError('VersionedKeyStorage has no key versions')
        names = list(self.versions)
        if all(name.isdecimal() for name in names):
            # A plain string max() would rank '9' above '10'.
            newest = max(names, key=int)
        else:
            newest = max(names)
        return self.version(newest)

    @staticmethod
    def generate() -> str:
        """Return 32 random bytes from ``os.urandom``, base64-encoded as ASCII text."""
        return b64encode(os.urandom(32)).decode('ascii')


class VersionedManagedBlockEncryptor:
    """Streaming encryptor whose output starts with a ``<version>:<base64 iv>:`` header.

    The latest key from *key_storage* is selected on first use. The header is
    emitted exactly once, in front of the first bytes returned by either
    ``update()`` or ``finalize()``.
    """

    def __init__(self, key_storage: VersionedKeyStorage) -> None:
        self._storage = key_storage
        self._encryptor: Optional[BlockEncryptor] = None

    def _begin(self) -> bytes:
        """Create the underlying encryptor if needed; return the header, or b'' if already started.

        Raises:
            ValueError: the latest key version contains the ':' separator.
            UnicodeEncodeError: the latest key version is not ASCII.
        """
        if self._encryptor is not None:
            return b''
        key_version = self._storage.latest()
        # Encode and validate the version before committing any state, so a
        # failure here cannot leave a started encryptor with no header emitted.
        version = str(key_version.version).encode('ascii')
        if b':' in version:
            raise ValueError(
                "key version {!r} must not contain ':'".format(key_version.version)
            )
        encryptor = BlockEncryptor(key_version.key)
        # The IV must be base64-encoded so that the ':' token cannot occur in it.
        iv = b64encode(encryptor.iv)
        self._encryptor = encryptor
        return version + b':' + iv + b':'

    def update(self, data: bytes) -> bytes:
        """Encrypt *data*; the first output chunk is prefixed with the stream header."""
        header = self._begin()
        assert self._encryptor is not None  # set by _begin(); narrows Optional for type checkers
        return header + self._encryptor.update(data)

    def finalize(self) -> bytes:
        """Flush the remaining ciphertext.

        If ``update()`` was never called, the header is emitted here, so an
        empty plaintext still produces a well-formed stream.
        """
        header = self._begin()
        assert self._encryptor is not None  # set by _begin(); narrows Optional for type checkers
        return header + self._encryptor.finalize()


class VersionedManagedBlockDecryptor:
    """Streaming decryptor for streams of the form ``<version>:<base64 iv>:<ciphertext>``.

    Input is buffered until both header separators have arrived. The key for
    the embedded version is then looked up in *key_storage*, and every
    subsequent byte is passed to a ``BlockDecryptor``.
    """

    def __init__(self, key_storage: VersionedKeyStorage) -> None:
        self._storage = key_storage
        self._decryptor: Optional[BlockDecryptor] = None
        # Bytes received while the header is still incomplete; emptied once parsed.
        self._buf: bytes = b''

    def update(self, data: bytes) -> bytes:
        """Feed the next chunk of the stream and return any decrypted plaintext.

        Returns ``b''`` while the header is still incomplete.

        Raises:
            KeyError: the header names a version absent from the key storage.
            UnicodeDecodeError: the version field is not ASCII.
            binascii.Error: the IV field is not valid base64.
        """
        if self._decryptor is None:
            self._buf += data
            # The version and the base64 IV cannot contain ':', so the first
            # two colons in the stream are always the header separators; any
            # later colon belongs to the ciphertext.
            if self._buf.count(b':') < 2:
                return b''
            version, encoded_iv, body = self._buf.split(b':', 2)
            iv = b64decode(encoded_iv)
            key_version = self._storage.version(version.decode('ascii'))
            self._decryptor = BlockDecryptor(key=key_version.key, iv=iv)
            # Header parsed successfully; stop holding on to the buffered bytes.
            self._buf = b''
            data = body

        return self._decryptor.update(data)

    def finalize(self) -> bytes:
        """Flush any remaining plaintext.

        Raises:
            ValueError: the stream ended before a complete header was received.
        """
        if self._decryptor is None:
            raise ValueError(
                'cannot finalize: stream ended before the version/IV header was complete'
            )
        return self._decryptor.finalize()
