import hmac
import hashlib
import base64
import datetime as dt
from typing import AnyStr, Union

from Crypto.Cipher import AES

from intranet.domenator.src.settings import config
from .types import PDDVersion


class BaseToken:
    """Common base for PDD API token handlers.

    Subclasses implement ``generate_token``. The default
    ``serialize_token`` / ``unserialize_token`` hand the token back unchanged.
    """

    def _generate_raw_token(self, sha, *args):
        """Return an ``hmac.HMAC`` object computed over ``args`` and the current time.

        ``sha`` is the digest constructor passed to ``hmac.new`` (this module
        passes ``hashlib.sha224`` or ``hashlib.sha256``). Because
        ``datetime.now()`` is appended to the message, the result is not
        reproducible from ``args`` alone.
        """
        # NOTE(review): every fragment embeds the string form of the whole
        # ``args`` tuple, not just ``arg``. That looks unintended, but it is
        # kept as-is so the HMAC input stays exactly the same as before.
        fragments = [
            str(arg) + '{} |{}|'.format(args, position)
            for position, arg in enumerate(args, 1)
        ]
        fragments.append(str(dt.datetime.now()))
        message = ''.join(fragments).encode('utf-8')
        # hmac.new() requires a bytes-like key; presumably the configured
        # secret is bytes - confirm.
        return hmac.new(config.registrar_get_token_secret, message, sha)

    def serialize_token(self, token):
        """Return the serialized (stored) form of ``token``; identity by default."""
        return token

    def unserialize_token(self, token):
        """Return the token recovered from its serialized form; identity by default."""
        return token

    def generate_token(self, *args):
        """Create a new token from ``args``; subclasses must override this."""
        raise NotImplementedError


class PlainToken(BaseToken):
    """Plain (unencrypted) tokens.

    These are the domain tokens for version 1 of the PDD API. Serialization
    is inherited from the base class, i.e. the token is stored as-is.
    """

    def generate_token(self, *args):
        """Return a new token: the HMAC-SHA224 hex digest as a ``str``."""
        mac = self._generate_raw_token(hashlib.sha224, *args)
        return mac.hexdigest()


class CryptToken(BaseToken):
    """Encrypted tokens.

    These are the registrar and domain tokens for version 2 of the PDD API.
    A token is the unpadded base32 encoding of a 32-byte HMAC-SHA256 digest
    (52 characters). It is serialized by AES-encrypting the decoded bytes.
    """

    @staticmethod
    def _make_cipher():
        # A fresh AES cipher in ECB mode, keyed with the storage secret. ECB
        # encrypts each 16-byte block independently and deterministically, so
        # the same token always serializes to the same value (presumably
        # relied upon for lookups - confirm).
        return AES.new(config.registrar_store_token_secret, AES.MODE_ECB)

    def generate_token(self, *args):
        """Return a new token as upper-case base32 ``bytes`` with padding stripped."""
        mac = self._generate_raw_token(hashlib.sha256, *args)
        encoded = base64.b32encode(mac.digest())
        return encoded.rstrip(b'=')

    def serialize_token(self, token):
        """Encrypt ``token`` and return it as padded base32 ``str``.

        ``token`` may be ``str`` or ``bytes`` and is upper-cased first; missing
        base32 padding is restored before decoding.
        """
        cipher = self._make_cipher()
        pad_char = '=' if isinstance(token, str) else b'='
        # 56 is the padded base32 length of a 32-byte (SHA-256) digest, so this
        # restores the padding that generate_token stripped.
        normalized = token.upper().ljust(56, pad_char)
        encrypted = cipher.encrypt(base64.b32decode(normalized))
        return base64.b32encode(encrypted).decode('utf-8')

    def unserialize_token(self, token):
        """Decrypt a serialized token and return it in ``generate_token`` form.

        ``token`` must be correctly padded, upper-case base32 (``b32decode`` is
        called without ``casefold``). The result is base32 ``bytes`` with the
        padding stripped.
        """
        cipher = self._make_cipher()
        decrypted = cipher.decrypt(base64.b32decode(token))
        return base64.b32encode(decrypted).rstrip(b'=')


def _get_token_class_by_version(pdd_version: PDDVersion) -> BaseToken:
    """Return a token handler for ``pdd_version``.

    ``PDDVersion.new`` gets a ``CryptToken``; every other version falls back to
    ``PlainToken``. Despite the name, an instance is returned, not a class.
    """
    handler_cls = CryptToken if pdd_version == PDDVersion.new else PlainToken
    return handler_cls()


def serialize_token(pdd_version: PDDVersion, token: AnyStr) -> Union[str, bytes]:
    """Convert ``token`` to its serialized (stored) form for ``pdd_version``.

    For ``PDDVersion.new`` the token is AES-encrypted and returned as padded
    base32 ``str``. For any other version it is returned unchanged, so the
    result is ``bytes`` whenever ``token`` is ``bytes``. The previous ``-> str``
    annotation did not cover that case.
    """
    token_class = _get_token_class_by_version(pdd_version)
    return token_class.serialize_token(token)


def unserialize_token(pdd_version: PDDVersion, token: AnyStr) -> Union[str, bytes]:
    """Recover a token from its serialized form for ``pdd_version``.

    For ``PDDVersion.new`` the token is AES-decrypted and returned as unpadded
    base32 ``bytes``. For any other version it is returned unchanged, so the
    result is ``str`` whenever ``token`` is ``str``. The previous ``-> bytes``
    annotation did not cover that case.
    """
    token_class = _get_token_class_by_version(pdd_version)
    return token_class.unserialize_token(token)


def generate_token(pdd_version: PDDVersion, *args) -> Union[str, bytes]:
    """Create a new token for ``pdd_version`` from ``args``.

    For ``PDDVersion.new`` the result is an unpadded base32 HMAC-SHA256 digest
    as ``bytes``. For any other version it is an HMAC-SHA224 hex digest as
    ``str``. The previous ``-> bytes`` annotation was wrong for that case.
    """
    token_class = _get_token_class_by_version(pdd_version)
    return token_class.generate_token(*args)
