import binascii
from hashlib import sha1

import six

from Crypto.PublicKey import RSA, DSA
from Crypto.Random.random import randint
from Crypto.Util import asn1
from Crypto.Util.number import bytes_to_long, long_to_bytes, bignum

from kernel.util.streams.streambase import StreamBase
from kernel.util.functional import memoized

from ._key_base import CryptoKey, bytesio, tobytes


# Extend the exception tuple shared by CryptoKey subclasses with PyCrypto's
# RSA/DSA errors. Presumably KeyLoadErrors lists the exceptions the base class
# treats as "could not load key" -- confirm in _key_base.
CryptoKey.KeyLoadErrors += (RSA.error, DSA.error)

# Python 3 merged ``long`` into ``int``; alias it so ``long(...)`` used below
# works on both major versions.
if six.PY3:
    long = int


class RSAKey(CryptoKey):
    """PublicKey.RSA wrapper.

    Adapts a PyCrypto RSA key object (``self._key``) to the CryptoKey
    interface: raw signing/verification, signature (de)serialisation,
    export, and SSH wire-format (de)serialisation.
    """

    keyType = b'ssh-rsa'
    _sshPublicKeyPrefixes = (keyType,)

    cryptoKeyClass = RSA
    defaultSize = 2048

    def sign(self, hash):
        """Sign the already-prepared message ``hash`` and return the signature.

        The underlying ``sign`` is called with K=None (RSA presumably ignores
        it) and returns a tuple; only its first element is returned.
        """
        return self._key.sign(hash, None)[0]

    def verify(self, hash, sign):
        """Return the underlying key's verdict for signature ``sign`` over ``hash``.

        ``sign`` is wrapped in a 1-tuple, mirroring the tuple ``sign`` unwraps.
        """
        return self._key.verify(hash, (sign,))

    def encodeSign(self, sign):
        """Serialise a signature as a big-endian integer field.

        ``sign`` may be the signature integer itself or a tuple whose first
        element is that integer.
        """
        if isinstance(sign, tuple):
            sign = sign[0]
        io = bytesio()
        StreamBase(io).writeBELInt(sign)
        return io.getvalue()

    def decodeSign(self, sign):
        """Inverse of encodeSign: read the big-endian integer back from ``sign``."""
        return long(StreamBase(bytesio(sign)).readBELInt())

    @memoized
    def exportKey(self, keyformat='PEM'):
        """Export the key via the underlying PyCrypto ``exportKey``.

        Results are cached by @memoized (presumably per instance and format).
        """
        return self._key.exportKey(format=keyformat)

    def dumpToNetworkStream(self, stream):
        """Write e then n -- the ssh-rsa public key blob order (RFC 4253 6.6)."""
        stream.writeBELIntNetCap(self._key.e)
        stream.writeBELIntNetCap(self._key.n)

    def correctedHash(self, hash):
        """Return ``hash`` as a PKCS#1 v1.5 padded SHA-1 DigestInfo block.

        The block is ``self.size() // 8`` bytes long and starts at the 0x01
        octet (no leading 0x00 is emitted). NOTE(review): that length is the
        correct EMSA-PKCS1-v1_5 length minus the leading zero only if
        ``size()`` returns the modulus bit length minus one, as PyCrypto's
        RSA objects do -- confirm.
        """
        return _PKCS1Digest(hash, self.size() // 8)

    @classmethod
    def _doLoads(cls, data, comment='', options=None, log=None):
        """Load a key from ``data`` in any format ``RSA.importKey`` accepts.

        For an OpenSSH public key line (``ssh-rsa <base64> [comment]``) the
        line's comment is kept; otherwise ``comment`` is used as a fallback.
        ``options`` defaults to an empty dict.
        """
        # Normalise to bytes up front (as DSSKey._doLoads does): the split
        # below uses a bytes separator and raises TypeError for text input
        # on Python 3, even though importKey itself accepts text.
        data = tobytes(data).strip()
        result = cls(cls.cryptoKeyClass.importKey(data))
        parts = data.split(b' ', 2)
        if len(parts) > 2 and parts[0] == cls.keyType and parts[2]:
            result.comment = parts[2]
        if not result.comment:
            result.comment = comment
        result.options = options or {}
        return result

    @classmethod
    def fromNetworkStream(cls, stream, log=None):
        """This method assumes that key type has already been read
        from the stream and there's only key data left"""
        # The blob stores e before n; construct() takes (n, e).
        e = bignum(stream.readBELInt())
        n = bignum(stream.readBELInt())
        return cls.construct((n, e), log=log)


class DSSKey(CryptoKey):
    """Wrapper around DSS key.

    Adapts a PyCrypto DSA key object (``self._key``) to the CryptoKey
    interface: signing/verification, SSH signature blobs, export to
    OpenSSH/PEM/DER, and loading from OpenSSH public key lines or PEM blocks.
    """

    keyType = b'ssh-dss'
    _sshPublicKeyPrefixes = (keyType,)

    cryptoKeyClass = DSA
    defaultSize = 1024

    def correctedHash(self, hash):
        """Return the SHA-1 digest of ``hash`` as an integer."""
        return bytes_to_long(sha1(hash).digest())

    def sign(self, hash, k=None):
        """Sign ``hash`` and return the underlying key's signature.

        The result is presumably the ``(r, s)`` pair that encodeSign consumes.
        ``k`` is the per-signature nonce; when omitted it is drawn with
        ``Crypto.Random.random.randint`` from [2, q - 1]. NOTE: a
        caller-supplied ``k`` must never be reused -- that leaks the key.
        """
        if k is None:
            k = randint(2, self._key.q - 1)
        return self._key.sign(hash, k)

    def verify(self, hash, sign):
        """Return the underlying key's verdict for ``sign`` over ``hash``."""
        return self._key.verify(hash, sign)

    def encodeSign(self, sign):
        """Serialise ``(r, s)`` as ``r || s``, each padded to 20 bytes.

        The concatenation is written with StreamBase.writeBEStr.
        """
        _sign = long_to_bytes(sign[0], 20) + long_to_bytes(sign[1], 20)
        io = bytesio()
        StreamBase(io).writeBEStr(_sign)
        return io.getvalue()

    def decodeSign(self, sign):
        """Inverse of encodeSign: first 20 bytes are r, the rest is s."""
        _sign = StreamBase(bytesio(sign)).readBEStr()
        return bytes_to_long(_sign[:20]), bytes_to_long(_sign[20:])

    @memoized
    def exportKey(self, keyformat='PEM'):
        """Serialise the key in ``keyformat`` ('OpenSSH', 'PEM' or 'DER').

        'OpenSSH' yields a single ``ssh-dss <base64>`` line (no trailing
        newline) for the public part. 'PEM' and 'DER' encode the ASN.1
        SEQUENCE (0, p, q, g, y[, x]); x is included only when the key has a
        private part. PEM base64 is wrapped at 64 characters per line.

        Raises:
            NotImplementedError: for any other ``keyformat``.
        """
        if keyformat == 'OpenSSH':
            io = bytesio()
            stream = StreamBase(io)
            stream.writeBEStr(self.keyType)
            self.dumpToNetworkStream(stream)
            # b2a_base64 appends '\n'; strip it so the result is one bare
            # line a comment can be appended to.
            blob = binascii.b2a_base64(io.getvalue()).rstrip(b'\n')
            return self._sshPublicKeyPrefixes[0] + b' ' + blob

        if keyformat not in ('PEM', 'DER'):
            raise NotImplementedError

        values = [0, self._key.p, self._key.q, self._key.g, self._key.y]
        if self.hasPrivate():
            values.append(self._key.x)
        der = asn1.DerSequence(values).encode()
        if keyformat == 'DER':
            return der

        header = b'PRIVATE KEY' if self.hasPrivate() else b'PUBLIC KEY'
        # RFC 7468: 48 input bytes -> 64 base64 characters per line; each
        # b2a_base64 chunk already ends with '\n'.
        body = b''.join(
            binascii.b2a_base64(der[i:i + 48]) for i in range(0, len(der), 48)
        )
        return (b'-----BEGIN DSA ' + header + b'-----\n' + body +
                b'-----END DSA ' + header + b'-----')

    def dumpToNetworkStream(self, stream):
        """Write p, q, g, y -- the ssh-dss public key blob order (RFC 4253 6.6)."""
        stream.writeBELIntNetCap(self._key.p)
        stream.writeBELIntNetCap(self._key.q)
        stream.writeBELIntNetCap(self._key.g)
        stream.writeBELIntNetCap(self._key.y)

    @classmethod
    def fromNetworkStream(cls, stream, log=None):
        """Read p, q, g, y from ``stream`` (key type already consumed)."""
        p, q, g, y = (
            stream.readBELInt(),
            stream.readBELInt(),
            stream.readBELInt(),
            stream.readBELInt(),
        )
        # construct() takes the parameters in (y, g, p, q) order.
        return cls.construct((y, g, p, q), log=log)

    @classmethod
    def _doLoads(cls, data, comment='', options=None, log=None):
        """Load a DSA key from an OpenSSH public key line or a PEM block.

        A comment on the OpenSSH line wins over ``comment``; PEM input always
        gets ``comment``. ``options`` defaults to an empty dict.

        Raises:
            DSA.error: if the format is unsupported or the content is not a
                DSA key of the expected shape.
        """
        text = tobytes(data).strip()
        if text.startswith(cls.keyType + b' '):
            # OpenSSH public key line: "ssh-dss <base64 blob> [comment]".
            parts = text.split(b' ', 2)
            stream = StreamBase(bytesio(binascii.a2b_base64(parts[1])))
            # The blob repeats the key type before the key parameters.
            if cls.keyType != stream.readBEStr():
                raise DSA.error('Corrupted key')
            result = cls.fromNetworkStream(stream, log=log)
            if len(parts) > 2 and parts[2]:
                result.comment = parts[2]
            else:
                result.comment = comment
            result.options = options or {}
            return result
        if text.startswith(b'-----'):
            # PEM: drop the BEGIN/END lines and base64-decode the body.
            der = binascii.a2b_base64(b''.join(text.split(b'\n')[1:-1]))
            seq = asn1.DerSequence()
            seq.decode(der, True)
            # seq[0] is the leading 0 that exportKey writes.
            fields = seq[1:]  # (p, q, g, y[, x])
            if not 4 <= len(fields) <= 5:  # Another key type
                raise DSA.error('Corrupted key')
            # Reorder to (y, g, p, q[, x]) for construct().
            fields[0:4] = (fields[3], fields[2], fields[0], fields[1])
            result = cls.construct(fields, log=log)
            result.comment = comment
            result.options = options or {}
            return result
        raise DSA.error('DSA key format is not supported')


def _PKCS1Pad(data, messageLength):
    """Pad ``data`` to ``messageLength`` bytes in PKCS#1 v1.5 block type 1 form.

    Returns ``b'\\x01' + b'\\xff' * n + b'\\x00' + data`` where
    ``n = messageLength - 2 - len(data)``. No leading 0x00 octet is emitted.

    Raises:
        ValueError: if fewer than 8 0xff padding octets fit, the minimum
            RFC 8017 (EMSA-PKCS1-v1_5) requires.
    """
    lenPad = messageLength - 2 - len(data)
    if lenPad < 8:
        # Without this check a negative lenPad silently gives b'' padding and
        # an undersized, non-conformant block.
        raise ValueError(
            'PKCS#1 message length %d too short for %d bytes of data'
            % (messageLength, len(data)))
    return b'\x01' + (b'\xff' * lenPad) + b'\x00' + data


def _PKCS1Digest(data, messageLength, idSha1=b'\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'):
    """Build the PKCS#1 v1.5 padded SHA-1 DigestInfo for ``data``.

    ``idSha1`` is the DER prefix of a SHA-1 DigestInfo (AlgorithmIdentifier
    plus the OCTET STRING header for a 20-byte digest). The SHA-1 digest of
    ``data`` is appended to it, and the result is padded to ``messageLength``
    bytes by ``_PKCS1Pad``.
    """
    digestInfo = idSha1 + sha1(data).digest()
    padded = _PKCS1Pad(digestInfo, messageLength)
    return padded
