# -*- coding: utf-8 -*-

from hashlib import sha256

from passport.backend.utils.common import to_base64_url


# Canonical spellings of the supported PKCE code challenge methods.
PKCE_METHOD_S256 = 'S256'
PKCE_METHOD_PLAIN = 'plain'

# All supported methods, in canonical spelling.
_ALL_METHODS = [PKCE_METHOD_S256, PKCE_METHOD_PLAIN]

# Upper-cased method name -> canonical spelling; lets callers look a method
# up case-insensitively by upper-casing it first.
_NORMAL_PKCE_METHOD_TABLE = dict(
    (method_name.upper(), method_name) for method_name in _ALL_METHODS
)


def is_valid_pkce_method(method):
    """Tell whether ``method`` names a supported PKCE challenge method.

    The comparison ignores case (``'s256'`` and ``'PLAIN'`` are valid).

    :param method: method name to check; may be ``None`` or empty.
    :returns: ``True`` if the method is supported, ``False`` otherwise.
    """
    if not method:
        # A missing method is simply not a valid one; without this guard
        # ``None.upper()`` would raise AttributeError.
        return False
    return method.upper() in _NORMAL_PKCE_METHOD_TABLE


def fix_pkce_method_case(method):
    """Return the canonical spelling of ``method`` (e.g. ``'s256'`` -> ``'S256'``).

    :raises KeyError: if ``method`` is not a supported method; callers are
        expected to validate it with ``is_valid_pkce_method`` first.
    """
    upper_name = method.upper()
    return _NORMAL_PKCE_METHOD_TABLE[upper_name]


def check_pkce(verifier, challenge, method, allow_empty=False):
    """Check a PKCE code verifier against a previously stored code challenge.

    :param verifier: code verifier presented by the client.
    :param challenge: code challenge stored earlier for this flow.
    :param method: challenge method name. It is matched case-insensitively,
        the same way ``build_code_challenge`` matches it.
    :param allow_empty: when true and neither ``challenge`` nor ``method`` is
        set, PKCE is treated as not in use. The check then passes only if no
        ``verifier`` was presented either.
    :raises PkceInvalidCodeVerifierError: if the verifier is missing, is
        presented without a challenge, or does not match the challenge.
    :raises NotImplementedError: if ``method`` is not a supported method.
    """
    if not challenge and not method and allow_empty:
        if not verifier:
            return
        # A verifier without a stored challenge cannot be validated.
        raise PkceInvalidCodeVerifierError()

    if method:
        # Normalize case so that e.g. 's256' is handled like 'S256', matching
        # the normalization done in build_code_challenge.
        method = _NORMAL_PKCE_METHOD_TABLE.get(method.upper(), method)

    if method == PKCE_METHOD_S256:
        # Drop trailing '=' padding from the stored challenge before comparing;
        # the computed challenge is presumably unpadded (see to_base64_url).
        challenge = challenge.rstrip('=')
    elif method != PKCE_METHOD_PLAIN:
        raise NotImplementedError()  # pragma: no cover

    if not verifier:
        # A challenge is in play but no verifier was presented. Reject it
        # explicitly: for S256, hashing None would otherwise fail with a
        # TypeError, and for 'plain' an empty verifier must not match an
        # empty challenge.
        raise PkceInvalidCodeVerifierError()

    expected_challenge = build_code_challenge(verifier, method)
    if expected_challenge != challenge:
        raise PkceInvalidCodeVerifierError()


def build_code_challenge(verifier, method):
    """Derive the code challenge for ``verifier`` using ``method``.

    ``method`` is matched case-insensitively against the supported methods.
    ``S256`` transforms the verifier with ``s256``; ``plain`` returns it as is.

    :raises NotImplementedError: if ``method`` is not a supported method.
    """
    upper_name = method.upper()
    canonical = _NORMAL_PKCE_METHOD_TABLE.get(upper_name, upper_name)
    if canonical == PKCE_METHOD_PLAIN:
        return verifier
    if canonical == PKCE_METHOD_S256:
        return s256(verifier)
    raise NotImplementedError()  # pragma: no cover


def s256(s):
    """Return the ``S256`` transformation of ``s``: base64url of its SHA-256 digest.

    ``hashlib`` only hashes bytes-like objects, so text input is encoded to
    UTF-8 first. A Python 3 ``str`` would otherwise raise ``TypeError``, and a
    non-ASCII Python 2 ``unicode`` would raise ``UnicodeEncodeError``. For
    ASCII input this produces the same bytes as before. Byte strings are
    hashed unchanged.
    """
    # type(u'') is ``unicode`` on Python 2 and ``str`` on Python 3.
    if isinstance(s, type(u'')):
        s = s.encode('utf-8')
    return to_base64_url(sha256(s).digest())


class PkceInvalidCodeVerifierError(Exception):
    """Raised by ``check_pkce`` when the presented code verifier is rejected."""
