import struct
import random
import logging
import binascii

from kernel.util.streams.streambase import StreamBase
from library.auth.key import Certificate, RSAKey, ECDSAKey
from library.auth.tempkey import TempKey

import six
import pytest


if six.PY2:
    # Python 2: ``str`` is already a byte string.
    bytesio = six.moves.cStringIO

    def b(s):
        """Return *s* unchanged (on Python 2 it is already a byte string)."""
        return s
else:
    bytesio = six.BytesIO

    def b(s):
        """Return *s* as bytes: text is Latin-1 encoded, anything else passes through."""
        if isinstance(s, str):
            return s.encode('latin-1')
        return s

# Big-endian *unsigned* 64-bit codec ('>Q'), despite the name; used for the
# certificate serial and validity timestamps.
int64 = struct.Struct('>Q')


def make_principals(principals):
    """Serialize a sequence of principal names into one byte string.

    Each name is converted to bytes with ``b()`` and written with
    ``StreamBase.writeBEStr`` (presumably a big-endian length-prefixed
    string). The concatenated output is returned as raw bytes, suitable for
    embedding as a single field.
    """
    buf = bytesio()
    writer = StreamBase(buf)
    for name in principals:
        encoded_name = b(name)
        writer.writeBEStr(encoded_name)
    return buf.getvalue()


def sign_data(key, data):
    """Sign *data* with *key* and return the serialized signature blob.

    The blob is the key type name (written with ``writeBEStr``) followed by
    the raw bytes returned from ``key.encodeSign``. This is the layout that
    ``verify_sig`` reads back.

    :param key: private key providing ``type``, ``correctedHash``, ``sign``
        and ``encodeSign``.
    :param data: bytes to sign.
    :returns: the signature blob as bytes.
    """
    encoded_sig = key.encodeSign(key.sign(key.correctedHash(data)))
    io = bytesio()
    stream = StreamBase(io)
    # Normalize the type name to bytes with b(), the same way verify_sig
    # compares it. On Python 3 ``key.type`` may be a str. b() is a no-op
    # for bytes.
    stream.writeBEStr(b(key.type))
    stream.write(encoded_sig)
    return io.getvalue()


def verify_sig(key, data, sig_msg):
    """Verify a signature blob in the layout produced by ``sign_data``.

    The leading type name in *sig_msg* must equal ``b(key.type)``; this is
    checked with ``assert``. The remaining bytes are decoded with
    ``key.decodeSign`` and checked against ``key.correctedHash(data)``.

    :returns: whatever ``key.verify`` returns.
    """
    reader = StreamBase(bytesio(sig_msg))
    embedded_type = reader.readBEStr()
    assert embedded_type == b(key.type)

    digest = key.correctedHash(data)
    decoded = key.decodeSign(reader.read())
    return key.verify(digest, decoded)


@pytest.fixture(scope='function', autouse=True)
def setup_logging():
    """Route INFO-level log records from the root logger to a stream.

    This is an autouse, function-scoped fixture, so it runs for every test in
    the module. The root level is (re)set to INFO each time. A plain
    ``StreamHandler`` is attached only when no ``StreamHandler`` (or subclass)
    is already registered, so handlers do not pile up across tests.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    already_streaming = any(
        isinstance(existing, logging.StreamHandler)
        for existing in root_logger.handlers
    )
    if already_streaming:
        return
    root_logger.addHandler(logging.StreamHandler())


# Key classes to parametrize over. ECDSAKey may be None (presumably when ECDSA
# support is unavailable); in that case only RSA keys are exercised.
key_types = [RSAKey, ECDSAKey] if ECDSAKey is not None else [RSAKey]


@pytest.mark.parametrize('ca_key_type', key_types)
@pytest.mark.parametrize('user_key_type', key_types)
def test_certificate_check(ca_key_type, user_key_type):
    """Hand-build an OpenSSH user certificate and check that it round-trips.

    For every combination of classes in ``key_types``, the test does the
    following:

    1. Create a fresh user key and a fresh CA key.
    2. Write the certificate fields, sign them with the CA key and append
       the signature.
    3. Parse the blob with ``Certificate.fromNetworkRepresentation`` and
       compare each parsed attribute with the value written.
    4. Use the certificate as a public key to verify a signature made with
       the user's private key.
    """
    # random.randint is inclusive at both ends, and int64 is the unsigned
    # '>Q' codec, so the upper bound must be 2**64 - 1. With 1 << 64, the
    # serial could come out as 2**64 and int64.pack would raise struct.error.
    max_uint64 = (1 << 64) - 1
    nonce = str(random.randint(0, max_uint64)).encode()
    serial = random.randint(0, max_uint64)
    valid_after = 0
    valid_before = max_uint64

    principals = [b'torkve']
    principals_string = make_principals(principals)

    with TempKey(keyClass=user_key_type) as user_key:
        with TempKey(keyClass=ca_key_type) as ca_key:
            key_type = b(user_key.type) + b'-cert-v01@openssh.com'
            io = bytesio()
            stream = StreamBase(io)
            # Fields covered by the CA signature, in wire order.
            stream.writeBEStr(key_type)
            stream.writeBEStr(nonce)  # already bytes (see .encode() above)
            user_key.publicKey().dumpToNetworkStream(stream)
            stream.write(int64.pack(serial))
            stream.writeBEInt(1)  # SSH_CERT_TYPE_USER
            stream.writeBEStr(b"vasya's key")  # key id
            stream.writeBEStr(principals_string)
            stream.write(int64.pack(valid_after))
            stream.write(int64.pack(valid_before))
            stream.writeBEStr(b"")  # critical_options
            stream.writeBEStr(b"")  # extensions
            stream.writeBEStr(b"")  # reserved
            stream.writeBEStr(ca_key.publicKey().asbytes())  # signature key

            signed_block = io.getvalue()
            signature = sign_data(ca_key, signed_block)
            stream.writeBEStr(signature)

            cert_blob = io.getvalue()
            # "<type> <base64>" text form. Here it is only used as the payload
            # for the user-key signature check at the end.
            cert_data = key_type + b' ' + binascii.b2a_base64(cert_blob)

            cert = Certificate.fromNetworkRepresentation(cert_blob)

            assert isinstance(cert, Certificate)
            assert cert.fingerprint() == user_key.fingerprint()
            assert cert.signing_key.fingerprint() == ca_key.fingerprint()
            assert cert.serial == serial
            assert cert.is_user_certificate()
            assert cert.principals == principals
            assert cert.valid_range == (valid_after, valid_before)
            assert not cert.critical_options
            assert not cert.extensions
            assert cert.check_validity()

            # The certificate must verify signatures made by the certified key.
            sig = sign_data(user_key, cert_data)
            assert verify_sig(cert, cert_data, sig), '\n'.join((
                'Private signing key: %r' % (user_key.asbytes(),),
                'Public verifying key: %r' % (cert._key.asbytes(),),
                'Data: %r' % (cert_data,),
                'Signature: %r' % (sig,),
            ))
