import cryptography.exceptions
from cryptography.hazmat.primitives.asymmetric import ec
import jwt.algorithms
from jwt.api_jwt import _jwt_global_obj
from passport.backend.utils.gost.cryptography import (
    GostR3411_2012_256,
    GostR3411_2012_512,
)


# JWS algorithm name -> hash class; each entry is wrapped in a GostAlgorithm
# and registered under that name by install_gost_to_jws().
GOST_ALGORITHMS = dict(
    GOST3410_2012_256=GostR3411_2012_256,
    GOST3410_2012_512=GostR3411_2012_512,
)


def install_gost_to_jws(_jwt=None):
    """Register every algorithm from GOST_ALGORITHMS on a PyJWT object.

    :param _jwt: object exposing ``get_algorithms()`` and
        ``register_algorithm(name, algorithm)``; when omitted, PyJWT's
        module-level ``_jwt_global_obj`` is used.

    Algorithm names that the target already reports via ``get_algorithms()``
    are skipped, so repeated calls do not try to register them twice.
    """
    target = _jwt_global_obj if _jwt is None else _jwt
    already_registered = set(target.get_algorithms())
    # Decide up front which GOST entries still need a handler.
    missing = [
        (name, digest)
        for name, digest in GOST_ALGORITHMS.items()
        if name not in already_registered
    ]
    for name, digest in missing:
        target.register_algorithm(name, GostAlgorithm(digest))


class GostAlgorithm(jwt.algorithms.ECAlgorithm):
    """JWS algorithm that signs/verifies with ECDSA over a GOST hash.

    Only ``sign`` and ``verify`` are overridden; everything else
    (construction, key preparation, ...) is inherited from
    ``jwt.algorithms.ECAlgorithm`` unchanged. ``self.hash_alg`` is the hash
    class supplied at construction; a fresh instance of it is created for
    every operation.
    """

    def sign(self, msg, key):
        """Sign *msg* with *key*; return exactly what ``key.sign`` returns."""
        signature_algorithm = ec.ECDSA(self.hash_alg())
        return key.sign(msg, signature_algorithm)

    def verify(self, msg, key, sig):
        """Check *sig* over *msg* with *key*.

        :returns: True when ``key.verify`` accepts the signature, False when
            it raises ``cryptography.exceptions.InvalidSignature``. Any other
            exception from ``key.verify`` propagates to the caller.
        """
        try:
            signature_algorithm = ec.ECDSA(self.hash_alg())
            key.verify(sig, msg, signature_algorithm)
        except cryptography.exceptions.InvalidSignature:
            return False
        else:
            return True
