import struct

import six

from cryptography.hazmat.primitives.asymmetric import dsa, rsa, padding
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, encode_dss_signature
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.backends import default_backend
from cryptography.exceptions import InvalidSignature

from kernel.util.streams.streambase import StreamBase
from kernel.util.functional import memoized

from ._key_base import AuthKey, CryptoKey, bytesio, tobytes

from . import log as root


# Python 3 merged ``long`` into ``int``.  ``six.integer_types`` is
# ``(int, long)`` on Python 2 and ``(int,)`` on Python 3, so its last entry is
# the arbitrary-precision integer type on either version; long_to_bytes()
# relies on this alias.
long = six.integer_types[-1]


class RSAKey(CryptoKey):
    """SSH ``ssh-rsa`` key backed by a ``cryptography`` RSA key object.

    ``self._key`` holds either an ``rsa.RSAPrivateKey`` or an
    ``rsa.RSAPublicKey``; ``has_private()`` reports which one.
    """

    keyType = b'ssh-rsa'
    _sshPublicKeyPrefixes = (keyType,)

    defaultSize = 2048

    def __init__(self, key):
        AuthKey.__init__(self)
        self._key = key

    @classmethod
    def generate(cls, bits=None):
        """Generate a new private key of *bits* bits (``defaultSize`` if None)."""
        if bits is None:
            bits = cls.defaultSize

        return cls(rsa.generate_private_key(
            public_exponent=65537,
            key_size=bits,
            backend=default_backend(),
        ))

    @classmethod
    def construct(cls, *args, **kwargs):
        """Build a key from its integer components.

        Positional arguments are ``(n, e)`` for a public key, or
        ``(n, e, d, iqmp, q, p)`` for a private key.  An optional ``log``
        keyword overrides the logger used for warnings.

        If one of ``cls.KeyLoadErrors`` is raised, a warning is logged and
        ``None`` is returned.
        """
        log = kwargs.pop('log', None) or root.getChild('key')
        try:
            n, e = args[:2]
            public_numbers = rsa.RSAPublicNumbers(n=n, e=e)
            if len(args) == 2:
                return cls(public_numbers.public_key(default_backend()))

            d, iqmp, q, p = args[2:]
            private_numbers = rsa.RSAPrivateNumbers(
                d=d,
                iqmp=iqmp,
                q=q,
                p=p,
                # CRT exponents are the private exponent reduced modulo
                # (p - 1) and (q - 1) respectively.
                dmp1=d % (p - 1),
                dmq1=d % (q - 1),
                public_numbers=public_numbers,
            )

            return cls(private_numbers.private_key(default_backend()))
        except cls.KeyLoadErrors:
            # str() each component: str.join() rejects ints.  Only the
            # public components (n, e) are logged so that private key
            # material never ends up in log files.
            log.warning('Malformed key `{0}`'.format(
                ', '.join(str(arg) for arg in args[:2])))

    def size(self):
        """Return the key size in bits as reported by ``cryptography``."""
        return self._key.key_size

    def has_private(self):
        """Return True when this wrapper holds a private key."""
        return isinstance(self._key, rsa.RSAPrivateKey)

    hasPrivate = has_private

    def publicKey(self):
        """Return a wrapper around the public half of this key.

        Returns ``self`` when the key is already public-only.
        """
        if self.has_private():
            return self.__class__(self._key.public_key())
        else:
            return self

    def sign(self, hash):
        """Sign *hash* using PKCS#1 v1.5 padding with SHA-1.

        ``cryptography`` hashes *hash* with SHA-1 before signing.  The
        signature is returned as an integer.
        """
        sign = self._key.sign(
            data=hash,
            padding=padding.PKCS1v15(),
            algorithm=hashes.SHA1(),
        )
        # NOTE: due to py23 support we cannot use int.from_bytes(sig, 'big')
        return bytes_to_long(sign)

    def verify(self, hash, sign):
        """Verify the integer signature *sign* over *hash*.

        Checks a PKCS#1 v1.5 / SHA-1 signature against the public half of
        this key.  Returns True if the signature verifies, False otherwise.
        """
        # Re-encode the integer signature as big-endian bytes, left-padded
        # to ceil(bits / 8) bytes.  NOTE(review): get_bits() is not defined
        # in this class; it is assumed to be the key size in bits.
        sig_len = (self.get_bits() + 7) // 8
        sign = long_to_bytes(sign, sig_len)

        key = self.publicKey()._key
        try:
            key.verify(
                signature=sign,
                data=hash,
                padding=padding.PKCS1v15(),
                algorithm=hashes.SHA1(),
            )
        except InvalidSignature:
            return False
        else:
            return True

    def encodeSign(self, sign):
        """Serialize an integer signature with ``StreamBase.writeBELInt``.

        A tuple is also accepted; only its first element is encoded.
        """
        if isinstance(sign, tuple):
            sign = sign[0]
        io = bytesio()
        StreamBase(io).writeBELInt(sign)
        return io.getvalue()

    def decodeSign(self, sign):
        """Inverse of encodeSign: read one integer via ``readBELInt``."""
        return StreamBase(bytesio(sign)).readBELInt()

    @memoized
    def exportKey(self, keyformat='PEM'):
        """Serialize the key in *keyformat* (``'PEM'``, ``'DER'``, ``'OpenSSH'``).

        PEM/DER export the private key (TraditionalOpenSSL, unencrypted) when
        one is held, otherwise the public key as SubjectPublicKeyInfo.
        OpenSSH always exports the public key.  Any other valid
        ``serialization.Encoding`` raises NotImplementedError.
        """
        encoding = serialization.Encoding(keyformat)
        if encoding in (serialization.Encoding.PEM, serialization.Encoding.DER):
            if self.has_private():
                data = self._key.private_bytes(
                    encoding=encoding,
                    format=serialization.PrivateFormat.TraditionalOpenSSL,
                    encryption_algorithm=serialization.NoEncryption(),
                )
            else:
                data = self._key.public_bytes(
                    encoding=encoding,
                    format=serialization.PublicFormat.SubjectPublicKeyInfo,
                )
        elif encoding == serialization.Encoding.OpenSSH:
            key = self._key
            if self.has_private():
                key = key.public_key()
            data = key.public_bytes(
                encoding=encoding,
                format=serialization.PublicFormat.OpenSSH,
            )
        else:
            raise NotImplementedError("Export as {!r} not supported".format(keyformat))

        return data

    def dumpToNetworkStream(self, stream):
        """Write the public exponent then the modulus (e, n) to *stream*."""
        if self.has_private():
            public_numbers = self._key.private_numbers().public_numbers
        else:
            public_numbers = self._key.public_numbers()
        stream.writeBELIntNetCap(public_numbers.e)
        stream.writeBELIntNetCap(public_numbers.n)

    def correctedHash(self, hash):
        """Return *hash* unchanged; RSA needs no hash adjustment here."""
        return hash

    @classmethod
    def _doLoads(cls, data, comment='', options=None, log=None):
        """Parse a PEM (private or public) or ``ssh-rsa ...`` public key.

        A comment found on an OpenSSH public-key line wins over *comment*.
        Raises ValueError('key is not in RSA key format') when *data* is not
        a parseable RSA key.
        """
        bdata = tobytes(data).strip()
        try:
            if bdata.startswith(b'-----'):
                if b' PRIVATE KEY' in bdata:
                    key = serialization.load_pem_private_key(bdata, None, default_backend())
                else:
                    key = serialization.load_pem_public_key(bdata, default_backend())
                result = cls(key)
            elif bdata.startswith(cls.keyType + b' '):
                # "<type> <base64> [comment]" -- the comment may contain spaces.
                parts = bdata.split(b' ', 2)
                key = serialization.load_ssh_public_key(bdata, default_backend())
                result = cls(key)
                if len(parts) > 2 and parts[2]:
                    result.comment = parts[2]
            else:
                raise NotImplementedError()

            # PEM input may hold a non-RSA key; reject it.
            if not isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
                raise ValueError()
        except (ValueError, IndexError, NotImplementedError):
            raise ValueError('key is not in RSA key format')

        if not result.comment:
            result.comment = comment
        result.options = options or {}
        return result

    @classmethod
    def fromNetworkStream(cls, stream, log=None):
        """This method assumes that key type has already been read
        from the stream and there's only key data left"""
        e = stream.readBELInt()
        n = stream.readBELInt()
        key = rsa.RSAPublicNumbers(e=e, n=n).public_key(default_backend())
        return cls(key)


class DSSKey(CryptoKey):
    """SSH ``ssh-dss`` key backed by a ``cryptography`` DSA key object.

    ``self._key`` holds either a ``dsa.DSAPrivateKey`` or a
    ``dsa.DSAPublicKey``; ``has_private()`` reports which one.
    """

    keyType = b'ssh-dss'
    _sshPublicKeyPrefixes = (keyType,)

    defaultSize = 1024

    def __init__(self, key):
        AuthKey.__init__(self)
        self._key = key

    @classmethod
    def generate(cls, bits=None):
        """Generate a new private key of *bits* bits (``defaultSize`` if None)."""
        if bits is None:
            bits = cls.defaultSize

        return cls(dsa.generate_private_key(
            key_size=bits,
            backend=default_backend(),
        ))

    @classmethod
    def construct(cls, *args, **kwargs):
        """Build a key from its integer components.

        Positional arguments are ``(p, q, g, y)`` for a public key, or
        ``(p, q, g, y, x)`` for a private key.  An optional ``log`` keyword
        overrides the logger used for warnings.

        If one of ``cls.KeyLoadErrors`` is raised, a warning is logged and
        ``None`` is returned.
        """
        log = kwargs.pop('log', None) or root.getChild('key')
        try:
            p, q, g, y = args[:4]
            public_numbers = dsa.DSAPublicNumbers(
                y=y,
                parameter_numbers=dsa.DSAParameterNumbers(p=p, q=q, g=g),
            )

            if len(args) == 4:
                return cls(public_numbers.public_key(default_backend()))

            # The private exponent is the fifth component, right after y.
            x = args[4]
            private_numbers = dsa.DSAPrivateNumbers(
                x=x,
                public_numbers=public_numbers,
            )

            return cls(private_numbers.private_key(default_backend()))
        except cls.KeyLoadErrors:
            # str() each component: str.join() rejects ints.  Only the
            # public components (p, q, g, y) are logged so that the private
            # exponent never ends up in log files.
            log.warning('Malformed key `{0}`'.format(
                ', '.join(str(arg) for arg in args[:4])))

    def size(self):
        """Return the key size in bits as reported by ``cryptography``."""
        return self._key.key_size

    def has_private(self):
        """Return True when this wrapper holds a private key."""
        return isinstance(self._key, dsa.DSAPrivateKey)

    hasPrivate = has_private

    def publicKey(self):
        """Return a wrapper around the public half of this key.

        Returns ``self`` when the key is already public-only.
        """
        if self.has_private():
            return self.__class__(self._key.public_key())
        else:
            return self

    def sign(self, hash):
        """Sign *hash* with SHA-1 and return the signature as an ``(r, s)`` tuple.

        ``cryptography`` hashes *hash* with SHA-1 before signing.
        """
        sign = self._key.sign(
            data=hash,
            algorithm=hashes.SHA1(),
        )
        return decode_dss_signature(sign)

    def verify(self, hash, sign):
        """Verify the ``(r, s)`` signature *sign* over *hash* (SHA-1).

        Returns True if the signature verifies, False otherwise.
        """
        key = self.publicKey()._key
        r, s = sign
        # cryptography expects a DER-encoded signature, not a tuple.
        sign = encode_dss_signature(r, s)

        try:
            key.verify(
                signature=sign,
                data=hash,
                algorithm=hashes.SHA1(),
            )
        except InvalidSignature:
            return False
        else:
            return True

    def encodeSign(self, sign):
        """Serialize ``(r, s)`` as a 40-byte ``r || s`` blob via ``writeBEStr``.

        Each of r and s is left-padded with zero bytes to a multiple of
        20 bytes.
        """
        _sign = long_to_bytes(sign[0], 20) + long_to_bytes(sign[1], 20)
        io = bytesio()
        StreamBase(io).writeBEStr(_sign)
        return io.getvalue()

    def decodeSign(self, sign):
        """Inverse of encodeSign: split the blob into ``(r, s)`` at byte 20."""
        _sign = StreamBase(bytesio(sign)).readBEStr()
        return bytes_to_long(_sign[:20]), bytes_to_long(_sign[20:])

    @memoized
    def exportKey(self, keyformat='PEM'):
        """Serialize the key in *keyformat* (``'PEM'``, ``'DER'``, ``'OpenSSH'``).

        PEM/DER export the private key (TraditionalOpenSSL, unencrypted) when
        one is held, otherwise the public key as SubjectPublicKeyInfo.
        OpenSSH always exports the public key.  Any other valid
        ``serialization.Encoding`` raises NotImplementedError.
        """
        encoding = serialization.Encoding(keyformat)
        if encoding in (serialization.Encoding.PEM, serialization.Encoding.DER):
            if self.has_private():
                data = self._key.private_bytes(
                    encoding=encoding,
                    format=serialization.PrivateFormat.TraditionalOpenSSL,
                    encryption_algorithm=serialization.NoEncryption(),
                )
            else:
                data = self._key.public_bytes(
                    encoding=encoding,
                    format=serialization.PublicFormat.SubjectPublicKeyInfo,
                )
        elif encoding == serialization.Encoding.OpenSSH:
            key = self._key
            if self.has_private():
                key = key.public_key()
            data = key.public_bytes(
                encoding=encoding,
                format=serialization.PublicFormat.OpenSSH,
            )
        else:
            raise NotImplementedError("Export as {!r} not supported".format(keyformat))

        return data

    def dumpToNetworkStream(self, stream):
        """Write the public components p, q, g, y to *stream*, in that order."""
        if self.has_private():
            public_numbers = self._key.private_numbers().public_numbers
        else:
            public_numbers = self._key.public_numbers()
        stream.writeBELIntNetCap(public_numbers.parameter_numbers.p)
        stream.writeBELIntNetCap(public_numbers.parameter_numbers.q)
        stream.writeBELIntNetCap(public_numbers.parameter_numbers.g)
        stream.writeBELIntNetCap(public_numbers.y)

    def correctedHash(self, hash):
        """Return *hash* unchanged; DSS needs no hash adjustment here."""
        return hash

    @classmethod
    def _doLoads(cls, data, comment='', options=None, log=None):
        """Parse a PEM (private or public) or ``ssh-dss ...`` public key.

        A comment found on an OpenSSH public-key line wins over *comment*.
        Raises ValueError('key is not in DSA key format') when *data* is not
        a parseable DSA key.
        """
        bdata = tobytes(data).strip()
        try:
            if bdata.startswith(b'-----'):
                if b' PRIVATE KEY' in bdata:
                    key = serialization.load_pem_private_key(bdata, None, default_backend())
                else:
                    key = serialization.load_pem_public_key(bdata, default_backend())
                result = cls(key)
            elif bdata.startswith(cls.keyType + b' '):
                # "<type> <base64> [comment]" -- the comment may contain spaces.
                parts = bdata.split(b' ', 2)
                key = serialization.load_ssh_public_key(bdata, default_backend())
                result = cls(key)
                if len(parts) > 2 and parts[2]:
                    result.comment = parts[2]
            else:
                raise NotImplementedError()

            # PEM input may hold a non-DSA key; reject it.
            if not isinstance(key, (dsa.DSAPrivateKey, dsa.DSAPublicKey)):
                raise ValueError()
        except (ValueError, IndexError, NotImplementedError):
            raise ValueError('key is not in DSA key format')

        if not result.comment:
            result.comment = comment
        result.options = options or {}
        return result

    @classmethod
    def fromNetworkStream(cls, stream, log=None):
        """This method assumes that key type has already been read
        from the stream and there's only key data left"""
        p = stream.readBELInt()
        q = stream.readBELInt()
        g = stream.readBELInt()
        y = stream.readBELInt()
        return cls.construct(p, q, g, y, log=log)


def bytes_to_long(bytestring, unpack=struct.unpack):
    """Decode *bytestring* as an unsigned big-endian integer.

    The input is left-padded with zero bytes to a multiple of four and then
    consumed one 32-bit word at a time.  An empty input decodes to 0.
    """
    # Number of zero bytes needed to reach a multiple of 4 (0 if aligned).
    shortfall = -len(bytestring) % 4
    if shortfall:
        bytestring = b'\000' * shortfall + bytestring

    total = len(bytestring)
    value = 0
    offset = 0
    while offset < total:
        (word,) = unpack('>I', bytestring[offset:offset + 4])
        value = (value << 32) + word
        offset += 4
    return value


def long_to_bytes(n, blocksize=0, pack=struct.pack):
    """Encode integer *n* as a minimal big-endian byte string.

    Leading zero bytes are dropped.  When the loop produces no bytes (n is
    zero or negative), the result is a single zero byte.  If *blocksize* is
    positive, the result is left-padded with zero bytes to a multiple of
    *blocksize*.
    """
    value = long(n)
    # Emit 32-bit words least-significant first, then reverse them.
    words = []
    while value > 0:
        words.append(pack('>I', value & 0xffffffff))
        value >>= 32
    encoded = b''.join(reversed(words))

    # Strip zero bytes from the most-significant word; an empty result
    # collapses to one zero byte.
    encoded = encoded.lstrip(b'\000') or b'\000'

    remainder = len(encoded) % blocksize if blocksize > 0 else 0
    if remainder:
        encoded = b'\000' * (blocksize - remainder) + encoded
    return encoded
