import sys
import time

from ya.skynet.util.functional import singleton
from ya.skynet.library.auth.key import Certificate
from ya.skynet.library.auth.verify import VerifyManager
from ya.skynet.library.portoshell import hash_token

from .slots.exceptions import AuthError

import paramiko


# Python 2/3 compatibility shims:
#   bhex(data) -> hexadecimal text for a byte string
#   b(value)   -> byte-string form of an arbitrary value
if sys.version_info.major >= 3:
    bhex = bytes.hex

    def b(s):
        """Return *s* unchanged if it is bytes, else the UTF-8 bytes of str(s)."""
        if isinstance(s, bytes):
            return s
        return bytes(str(s), 'utf-8')
else:
    def bhex(b):
        """Return the hex encoding of the byte string *b*."""
        return b.encode('hex')

    def b(s):
        """Return *s* UTF-8 encoded if it is unicode, else str(s)."""
        if isinstance(s, unicode):  # noqa
            return s.encode('utf-8')
        return str(s)


@singleton
def vm():
    """Create the VerifyManager handed out by this accessor.

    NOTE(review): the ``singleton`` decorator presumably caches the result so
    that all callers share one manager -- confirm against
    ``ya.skynet.util.functional``.
    """
    manager = VerifyManager()
    return manager


def check_certificate(log, ca_storage, user, slot_info, cert):
    """Authorize *user* on the slot by means of the certificate *cert*.

    The certificate is accepted only when all of the following hold:

    * ``ca_storage.cert_valid`` approves it against the CA sources allowed
      by *slot_info*;
    * ``cert.is_user_certificate()`` is true;
    * ``b(user)`` is ``root``, ``nobody`` or one of the principals present
      both in the certificate and in the slot's allowed principals for
      *user*;
    * that shared list of principals is not empty.

    Returns:
        *cert*, after *user* has been added to ``cert.userNames``.

    Raises:
        AuthError: when any of the checks above fails.
    """
    slot_principals = slot_info.get_allowed_principals_for(user)
    shared_principals = list(set(slot_principals) & set(cert.principals))
    # every cert is authorized for root and nobody for now
    login_whitelist = [b'root', b'nobody'] + shared_principals
    log.debug("allowed principals: %s", slot_principals)

    if not ca_storage.cert_valid(cert, allowed_cas=slot_info.get_allowed_ca_sources()):
        log.debug("certificate is either revoked or signed by not allowed authority")
        raise AuthError("Authentication as %s@%s failed" % (user, slot_info.identifier()))

    rejected = (
        not cert.is_user_certificate()
        or b(user) not in login_whitelist
        # a certificate sharing no principal with the slot is refused even
        # for root/nobody
        or not shared_principals
    )
    if rejected:
        log.debug(
            'either %r not in allowed login users %r or principals %r are not allowed',
            b(user),
            login_whitelist,
            cert.principals,
        )
        raise AuthError("Authentication as %s@%s failed" % (user, slot_info.identifier()))

    cert.userNames.add(user)
    return cert


def check_sign(key, token, sign):
    """Check that *sign* is a signature of *token* made with *key*.

    A paramiko ``Message`` is assembled from ``key.get_name()`` (added as a
    string) followed by *sign* written directly into the message's packet
    buffer, rewound, and handed to ``key.verify_ssh_sig``.

    Returns:
        *key* when verification succeeds, otherwise ``None``.
    """
    envelope = paramiko.Message()
    envelope.add_string(key.get_name())
    envelope.packet.write(sign)
    envelope.rewind()

    verified = key.verify_ssh_sig(token, envelope)
    if not verified:
        return None
    return key


def authenticate(
    log,
    user,
    slot_info,
    token=None,
    signs=None,
    fingerprint=None,
    keys_storage=None,
    ca_storage=None,
    maybe_certificate=None,
):
    """Authenticate *user* for the slot and return the matching key/certificate.

    The attempts are made in this order:

    1. If *ca_storage* is given and *maybe_certificate* carries a public
       blob that parses as a certificate, the outcome of
       ``check_certificate`` is final (returned or raised).
       NOTE(review): this path does not call ``check_sign``; presumably the
       caller has already proven possession of the key -- confirm.
    2. If *fingerprint* is given, the first key from
       ``slot_info.get_auth_keys`` with an equal fingerprint is returned.
    3. Otherwise *token* is hashed with ``hash_token`` and *signs* is
       scanned. Records whose first item is a tuple/list are treated as
       certificates (only when *ca_storage* is given). The remaining records
       are treated as ``fingerprint -> signature`` pairs for the user's keys.
       The first certificate or key whose signature verifies is returned.

    A supplied *token* must have a ``ctime`` no older than 600 seconds and no
    more than 60 seconds ahead of the local clock, whichever path is taken.

    Raises:
        AuthError: when no credentials are supplied, the token is outside
            the accepted time window, or no credential authenticates *user*.
    """
    if fingerprint is None and (token is None or signs is None):
        raise AuthError("No authentication info provided")

    if token is not None:
        stale = time.time() - 600 > token['ctime']
        if stale or time.time() + 60 < token['ctime']:
            raise AuthError("Access security token is not valid or host has datetime misconfigured")

    cert_candidate_present = (
        ca_storage is not None
        and maybe_certificate is not None
        and maybe_certificate.public_blob is not None
    )
    if cert_candidate_present:
        blob = maybe_certificate.public_blob.key_blob
        log.debug('checking if %r is a valid certificate', blob)
        try:
            parsed_cert = Certificate.fromNetworkRepresentation(blob, log=log)
        except Exception as err:
            # not a certificate, check it as plain key
            log.exception('failed to load key as certificate: %s', err)
        else:
            return check_certificate(log, ca_storage, user, slot_info, parsed_cert)

    known_keys = slot_info.get_auth_keys(user, keys_storage=keys_storage)
    log.debug('loaded keys: %s', [str(item) for item in known_keys])

    if fingerprint is not None:
        matched = next(
            (item for item in known_keys if fingerprint == item.fingerprint()),
            False,
        )
        if not matched:
            raise AuthError("Authentication as %s@%s failed" % (user, slot_info.identifier()))
        return matched

    digest = hash_token(token)

    # First pass over signs: certificate records, honoured only with a CA storage.
    if ca_storage is not None:
        cert_records = [entry for entry in signs if isinstance(entry[0], (tuple, list))]
    else:
        cert_records = []

    for entry in cert_records:
        try:
            signed_cert = next(Certificate.loads(entry[0][1], log=log), None)
            if signed_cert is None:
                continue

            signature = entry[1]

            check_certificate(log, ca_storage, user, slot_info, signed_cert)
            verified = check_sign(signed_cert, digest, signature)
            if verified is not None:
                return verified
        except AuthError:
            # this certificate is not acceptable; try the next record
            continue

    # Second pass over signs: plain "fingerprint -> signature" records.
    plain_signs = {}
    for entry in signs:
        if not isinstance(entry[0], (tuple, list)):
            plain_signs[entry[0]] = entry[1]

    for candidate in known_keys:
        signature = plain_signs.get(candidate.fingerprint())
        if signature is None:
            continue

        verified = check_sign(candidate, digest, signature)
        if verified is not None:
            return verified

    raise AuthError("Authentication as %s@%s failed" % (user, slot_info.identifier()))


def format_key(key):
    """Render a short human-readable description of *key*.

    For a ``Certificate`` the result contains the class name, key type,
    SHA256 fingerprint, key id, serial and the type and SHA256 fingerprint
    of the signing (CA) key. For any other key it contains only the class
    name and the key's fingerprint in hex.
    """
    def hexfp(k, hash_function=None):
        if not hasattr(k, 'hexFingerprint'):
            # This branch must be reached only in tests
            assert hash_function is None
            return bhex(k.get_fingerprint())
        return k.hexFingerprint(hash_function=hash_function)

    kind = type(key).__name__

    if not isinstance(key, Certificate):
        return '%s [%s]' % (kind, hexfp(key))

    algo = 'sha256'
    algo_label = algo.upper()
    details = (
        kind,
        key.type,
        algo_label,
        hexfp(key, algo),
        key.key_id,
        key.serial,
        key.signing_key.type,
        algo_label,
        hexfp(key.signing_key, algo),
    )
    return '%s [%s %s:%s key_id %r serial %s CA %s %s:%s]' % details
