import logging
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, ContextManager, List, Optional

import simplejson as json
from cryptography.fernet import Fernet
from cryptography.fernet import InvalidToken as FernetInvalidToken
from cryptography.fernet import MultiFernet

from mail.payments.payments.conf import settings

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Attach a stream handler directly to this logger at import time
# (``logging.StreamHandler()`` without arguments writes to ``sys.stderr``).
# NOTE(review): if the application also configures a handler on the root
# logger, records from this module may be emitted twice - confirm intended.
log.addHandler(logging.StreamHandler())


class DecryptionError(RuntimeError):
    """A token was decrypted but its content could not be accepted.

    Raised in this module when the plaintext is not valid JSON, when the
    embedded protocol version or URL kind does not match what was expected,
    or when no protocol version is able to handle the message.
    """


class InvalidToken(RuntimeError):
    """The Fernet layer rejected the token.

    This is the module's own replacement for
    ``cryptography.fernet.InvalidToken``, so callers do not need to import
    the cryptography package to handle the failure.
    """


class V1Crypto(MultiFernet):
    """Version 1 of the token protocol, built on top of ``MultiFernet``.

    A token is the Fernet encryption of a JSON object with the keys ``uid``,
    ``order_id``, ``url_kind`` and ``version``.
    """

    ORDER_URL_KIND = 'order'
    PAYMENT_URL_KIND = 'payment'
    # Maximum accepted token age in seconds, per URL kind.
    MSG_TTL = {
        ORDER_URL_KIND: 30 * 24 * 3600,  # 30 days
        PAYMENT_URL_KIND: 24 * 3600,  # 1 day
    }

    def __init__(self, keys: List[bytes], proto_version: int):
        """Build the cipher from ``keys`` and remember the protocol version.

        :param keys: Fernet keys, one ``Fernet`` instance is created per key.
        :param proto_version: version number embedded into every token and
            required to match on decryption.
        """
        super().__init__([Fernet(key) for key in keys])
        self.version = proto_version

    def encrypt_order(self, uid: int, order_id: int) -> str:
        """Return a token of kind ``order`` for the given uid and order."""
        return self._encrypt_entity(uid, order_id, self.ORDER_URL_KIND)

    def encrypt_payment(self, uid: int, order_id: int) -> str:
        """Return a token of kind ``payment`` for the given uid and order."""
        return self._encrypt_entity(uid, order_id, self.PAYMENT_URL_KIND)

    def _encrypt_entity(self, uid: int, order_id: int, url_kind: str) -> str:
        """Build the token payload shared by all URL kinds and encrypt it."""
        return self._encrypt(
            dict(uid=uid, order_id=order_id, url_kind=url_kind, version=self.version)
        )

    def _encrypt(self, data: dict) -> str:
        """Serialize ``data`` to JSON, encrypt it and return the token as text."""
        return self.encrypt(json.dumps(data).encode('utf-8')).decode()

    def decrypt_order(self, msg: str) -> dict:
        """Decrypt ``msg``, requiring kind ``order`` and the order TTL."""
        return self._decrypt(msg, self.ORDER_URL_KIND)

    def decrypt_payment(self, msg: str) -> dict:
        """Decrypt ``msg``, requiring kind ``payment`` and the payment TTL."""
        return self._decrypt(msg, self.PAYMENT_URL_KIND)

    def decrypt_msg(self, msg: str) -> dict:
        """Decrypt ``msg`` without checking its kind and without a TTL."""
        return self._decrypt(msg)

    def _decrypt(self, msg: str, url_kind: Optional[str] = None) -> dict:
        """Decrypt and validate a token.

        :param msg: token text as produced by ``_encrypt``.
        :param url_kind: expected ``url_kind`` of the payload; also selects
            the TTL from ``MSG_TTL``. ``None`` disables both checks.
        :return: the decoded payload dictionary.
        :raises InvalidToken: the Fernet layer rejected the token.
        :raises DecryptionError: the plaintext is not a JSON object, or its
            ``version`` / ``url_kind`` does not match.
        """
        ttl = self.MSG_TTL[url_kind] if url_kind else None
        try:
            plaintext = self.decrypt(msg.encode('utf-8'), ttl=ttl)
        except FernetInvalidToken as e:
            raise InvalidToken from e

        try:
            data = json.loads(plaintext.decode())
        except (UnicodeDecodeError, json.JSONDecodeError) as e:
            log.exception('Cannot JSON-decode cryptotext')
            raise DecryptionError('Cannot JSON-decode cryptotext: %s' % msg) from e

        # Valid JSON that is not an object (list, number, ...) has no ``get``;
        # report it as a decryption problem instead of an AttributeError.
        if not isinstance(data, dict):
            raise DecryptionError('Wrong data in cryptotext: %s' % (data,))

        actual_version = data.get('version')
        if actual_version != self.version:
            raise DecryptionError('Actual version is %s, expected: %s' % (actual_version, self.version))
        if url_kind is not None and data.get('url_kind') != url_kind:
            raise DecryptionError('Wrong data in cryptotext: %s' % data)
        return data


class Crypto:
    """Facade over the versioned token protocols.

    New tokens are always produced by the most recent protocol version;
    ``decrypt_order`` falls back to older versions when the newest one
    cannot handle a message.
    """

    def __init__(self, keys: List[bytes]):
        """Create every supported protocol version from the same ``keys``."""
        # Insertion order is oldest -> newest; the last entry is the current one.
        self.cryptos = OrderedDict([
            (1, V1Crypto(keys, proto_version=1)),
        ])

    @classmethod
    def from_file(cls, keys_path: str) -> 'Crypto':
        """Build an instance from a text file holding one key per line.

        Surrounding whitespace is stripped from each line. Blank lines are
        skipped so that a stray empty line is not turned into an empty key.
        """
        with open(keys_path) as keys_file:
            stripped_lines = (line.strip() for line in keys_file)
            return cls([key.encode('utf-8') for key in stripped_lines if key])

    @property
    def crypto(self):
        """The most recent protocol implementation (last one registered)."""
        return next(reversed(self.cryptos.values()))

    def encrypt_order(self, uid: int, order_id: int) -> str:
        """Return an order token produced by the current protocol version."""
        return self.crypto.encrypt_order(uid=uid, order_id=order_id)

    def encrypt_payment(self, uid: int, order_id: int) -> str:
        """Return a payment token produced by the current protocol version."""
        return self.crypto.encrypt_payment(uid=uid, order_id=order_id)

    def decrypt_order(self, msg: str) -> ContextManager[dict]:
        """Return a context manager yielding the decrypted order payload.

        Protocol versions are tried from newest to oldest; a version that
        raises ``DecryptionError`` is logged and skipped. Decryption happens
        when the context is entered.

        :raises DecryptionError: on entering the context, if no version could
            decrypt ``msg``.
        """
        @contextmanager
        def wrapper():
            last_error = None
            for version, crypto_impl in reversed(self.cryptos.items()):
                try:
                    data = crypto_impl.decrypt_order(msg)
                except DecryptionError as e:
                    last_error = e
                    log.warning(
                        'Decryption using version %d failed, details:', version,
                        exc_info=True,
                    )
                else:
                    break
            else:
                raise DecryptionError(
                    'No suitable decryption found for message <%s>' % msg
                ) from last_error
            # Yield outside of the try block: an exception raised by the
            # caller's ``with`` body must propagate unchanged instead of
            # being treated as a decryption failure.
            yield data

        return wrapper()

    def decrypt_payment(self, msg: str) -> ContextManager[dict]:
        """Return a context manager yielding the decrypted payment payload.

        Only the current protocol version is used.
        """
        @contextmanager
        def wrapper():
            yield self.crypto.decrypt_payment(msg)

        return wrapper()

    def decrypt(self, msg: str) -> ContextManager[dict]:
        """Return a context manager yielding the payload of any token kind.

        Only the current protocol version is used.
        """
        @contextmanager
        def wrapper():
            yield self.crypto.decrypt_msg(msg)

        return wrapper()

    def with_external_urls(self, entity: dict, **entity_ids: Any) -> dict:
        """Return a copy of ``entity`` extended with order/payment tokens.

        :param entity: base dictionary; its items are copied into the result.
        :param entity_ids: keyword arguments forwarded to ``encrypt_order``
            and ``encrypt_payment`` (``uid`` and ``order_id``).
        :return: new dict with ``order_url``, ``payment_url``, ``order_hash``
            and ``payment_hash`` added; ``entity`` itself is not modified.
        """
        # Encrypt once per kind and reuse the token for both the hash and the
        # (deprecated) URL field.
        order_hash = self.encrypt_order(**entity_ids)
        payment_hash = self.encrypt_payment(**entity_ids)
        return dict(
            # Urls are deprecated
            order_url=settings.CRYPTO_V1_F1_PREFIX + order_hash,
            payment_url=settings.CRYPTO_V1_F1_PREFIX + payment_hash,
            order_hash=order_hash,
            payment_hash=payment_hash,
            **entity
        )
