import struct
import hashlib
import binascii

import six

from kernel.util.streams.streambase import StreamBase
from kernel.util.functional import memoized

from ._key_base import AuthKey, CryptoKey, bytesio, b, tobytes

try:
    from cEcdsa import Key as ECDSA
    _ecdsa_kind = 'torkve'
except ImportError:
    try:
        import paramiko
        from paramiko.ecdsakey import ECDSAKey as ECDSA, ec, default_backend as ECDSA_db
        from paramiko.ecdsakey import serialization as ECDSA_ser, hashes
        from cryptography.exceptions import InvalidSignature
        from cryptography.hazmat.primitives.asymmetric.utils import Prehashed
        from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, encode_dss_signature
        _ecdsa_kind = 'paramiko'
    except ImportError:
        ECDSA = None
        _ecdsa_kind = None


if _ecdsa_kind == 'torkve':
    class ECDSAKey(CryptoKey):
        """Wrapper around ECDSA keys backed by the ``cEcdsa`` extension.

        Adapts ``cEcdsa.Key`` to the ``CryptoKey`` interface: raw network
        blobs, OpenSSH and PEM (de)serialization, plus signing and
        verification delegated to the wrapped key.
        """
        cryptoKeyClass = ECDSA
        defaultSize = 521
        # SSH key-type names recognised by this class.
        keyTypes = [
            b'ecdsa-sha2-nistp256',
            b'ecdsa-sha2-nistp384',
            b'ecdsa-sha2-nistp521',
        ]
        _sshPublicKeyPrefixes = keyTypes

        # Digest paired with each curve (matches RFC 5656, section 6.2.1).
        ecdsaHashes = {
            b'ecdsa-sha2-nistp256': hashlib.sha256,
            b'ecdsa-sha2-nistp384': hashlib.sha384,
            b'ecdsa-sha2-nistp521': hashlib.sha512,
        }

        @memoized
        def networkRepresentation(self):
            """Return the public key as the raw blob produced by ``to_raw()``."""
            return self._key.public_key().to_raw()

        def dumpToNetworkStream(self, stream):
            """Write the public key to *stream* as ``[curve][point]``.

            The raw blob is read back as three ``readBEStr`` fields
            (key type / nid, curve, point); the first one is dropped.
            """
            raw_key_stream = StreamBase(bytesio(self._key.public_key().to_raw()))
            raw_key_stream.readBEStr()  # skip nid
            stream.writeBEStr(raw_key_stream.readBEStr())  # curve
            stream.writeBEStr(raw_key_stream.readBEStr())  # point

        @classmethod
        def fromNetworkStream(cls, stream, log=None):
            """Read ``[curve][point]`` from *stream* and build a public key.

            Inverse of :meth:`dumpToNetworkStream`: the key type
            ``ecdsa-sha2-<curve>`` is re-created in front of the curve and
            point before the blob is passed to ``from_raw``. *log* is unused.
            """
            io = bytesio()
            out_stream = StreamBase(io)
            curve = stream.readBEStr()
            out_stream.writeBEStr(b'ecdsa-sha2-' + curve)
            out_stream.writeBEStr(curve)
            point = stream.readBEStr()
            out_stream.writeBEStr(point)
            return cls(cls.cryptoKeyClass.from_raw(io.getvalue()))

        @classmethod
        def fromNetworkRepresentation(cls, data, log=None):
            """Build a key from a blob as returned by :meth:`networkRepresentation`.

            *log* is unused.
            """
            return cls(cls.cryptoKeyClass.from_raw(data))

        def size(self):
            """Return the key size as reported by ``cEcdsa``'s ``bits()``."""
            return self._key.bits()

        def correctedHash(self, hash):
            """Hash *hash* with this curve's digest and return the raw digest."""
            return self.ecdsaHashes[b(self.type)](hash).digest()

        def sign(self, hash):
            """Sign *hash* with the wrapped key (delegates to ``cEcdsa``)."""
            return self._key.sign(hash)

        def verify(self, hash, sign):
            """Return the wrapped key's verdict on signature *sign* over *hash*."""
            return self._key.verify(hash, sign)

        def encodeSign(self, sign):
            """[key-type][[r][s]] -> [[r][s]]

            Drop the leading key-type string from a signature blob and return
            the remaining ``[r][s]`` field, re-written as a single BE string.
            """
            stream = StreamBase(bytesio(sign))
            stream.readBEStr()  # drop key type
            sig = stream.readBEStr()

            io = bytesio()
            StreamBase(io).writeBEStr(sig)
            return io.getvalue()

        def decodeSign(self, sign):
            """[[r][s]] -> [key-type][[r][s]]

            Inverse of :meth:`encodeSign`: prefix the ``[r][s]`` field with
            this key's type name.
            """
            sign = StreamBase(bytesio(sign)).readBEStr()
            io = bytesio()
            stream = StreamBase(io)
            # Normalise the type name to bytes, as correctedHash does.
            stream.writeBEStr(b(self.type))
            stream.writeBEStr(sign)
            return io.getvalue()

        @property
        def type(self):
            """SSH key-type name as reported by ``cEcdsa``'s ``nid_name()``."""
            return self._key.nid_name()

        def fingerprint(self, hash_function=None):
            """Return the key fingerprint as raw bytes.

            :param hash_function: optional hashlib constructor or its name
                (e.g. ``'sha256'``). When given, the digest of
                :meth:`networkRepresentation` is returned; otherwise the
                wrapped key's own ``fingerprint()`` is used.
            """
            if hash_function is not None:
                if isinstance(hash_function, six.string_types):
                    hash_function = getattr(hashlib, hash_function)
                return hash_function(self.networkRepresentation()).digest()

            return self._key.fingerprint()

        def hexFingerprint(self, hash_function=None):
            """Hex-encoded form of :meth:`fingerprint`."""
            return bhex(self.fingerprint(hash_function=hash_function))

        @memoized
        def exportKey(self, keyformat='PEM'):
            """Serialize the key.

            :param keyformat: ``'OpenSSH'`` for a ``<type> <to_ssh()> <comment>``
                line, or ``'PEM'`` (default) for the ``to_pem()`` output.
            :raises NotImplementedError: for any other format.
            """
            if keyformat == 'OpenSSH':
                # The parts can be text or bytes (the default comment is '',
                # a str), and bytes %-formatting rejects str on Python 3, so
                # every part is normalised with b() first.
                return b"%s %s %s" % (
                    b(self._key.nid_name()),
                    b(self._key.to_ssh()),
                    b(self.comment),
                )
            elif keyformat == 'PEM':
                return b(self._key.to_pem())

            raise NotImplementedError

        @memoized
        def publicKey(self):
            """Return a new wrapper around the public half of this key."""
            return self.__class__(self._key.public_key())

        @classmethod
        def _doLoads(cls, data, comment='', options=None, log=None):
            """Parse an OpenSSH public-key line or a PEM block into a key.

            :param data: key text or bytes (converted with ``tobytes``).
            :param comment: used for PEM input, and for OpenSSH lines that
                carry no comment of their own.
            :param options: stored on the key; defaults to an empty dict.
            :param log: unused.
            :raises ValueError: if *data* is in neither format.
            """
            key = tobytes(data).strip()
            for keyType in cls._sshPublicKeyPrefixes:
                if key.startswith(keyType + b' '):
                    # "<type> <key-data> [comment]"
                    parts = key.split(b' ', 2)
                    key = cls(cls.cryptoKeyClass.from_ssh(parts[1]))
                    if len(parts) > 2 and parts[2]:
                        key.comment = parts[2]
                    else:
                        key.comment = comment
                    key.options = options or {}
                    return key
            if key.startswith(b'-----'):
                key = cls(cls.cryptoKeyClass.from_pem(key))
                key.comment = comment
                key.options = options or {}
                return key
            raise ValueError('ECDSA key format is not supported')

elif _ecdsa_kind == 'paramiko':
    class ECDSAKey(CryptoKey):
        """Wrapper around ECDSA keys backed by paramiko and cryptography.

        Wraps a ``paramiko.ecdsakey.ECDSAKey`` and exposes the ``CryptoKey``
        interface (raw network blobs, OpenSSH and PEM (de)serialization,
        prehashed signing and verification) alongside paramiko-style key
        methods that are delegated to the wrapped key.
        """
        defaultSize = 521
        # SSH key-type names recognised by this class.
        keyTypes = [
            b'ecdsa-sha2-nistp256',
            b'ecdsa-sha2-nistp384',
            b'ecdsa-sha2-nistp521',
        ]
        _sshPublicKeyPrefixes = keyTypes

        # cryptography hash algorithm paired with each curve
        # (matches RFC 5656, section 6.2.1).
        ecdsaHashes = {
            b'ecdsa-sha2-nistp256': hashes.SHA256,
            b'ecdsa-sha2-nistp384': hashes.SHA384,
            b'ecdsa-sha2-nistp521': hashes.SHA512,
        }

        def __init__(self, *args):
            """Wrap an existing paramiko key, or generate a new one.

            :param args: either nothing (a key is generated with
                ``ECDSA.generate()``) or a single object that has a
                ``sign_ssh_data`` attribute, which is wrapped as-is.
            :raises TypeError: for any other arguments; use :meth:`construct`
                to build a key from numbers.
            """
            # NOTE(review): calls AuthKey.__init__ directly rather than
            # CryptoKey.__init__ — presumably intentional; confirm.
            AuthKey.__init__(self)
            if args and hasattr(args[0], 'sign_ssh_data'):
                self._key = args[0]
            elif args:
                raise TypeError("You mustn't instantiate ECDSAkey object by numbers yourself, use ECDSAKey.construct")
            else:
                self._key = ECDSA.generate()

        @classmethod
        def construct(cls, curve, x, y, z=None, log=None):
            """Build a key from raw curve numbers.

            :param curve: key-format identifier resolved with paramiko's
                ``get_by_key_format_identifier`` (an SSH key-type name).
            :param x: x coordinate of the public point.
            :param y: y coordinate of the public point.
            :param z: private value; when ``None`` only a public key is built.
            :param log: unused.
            :raises TypeError: if *curve* is not recognised.
            """
            ecdsa_curve = ECDSA._ECDSA_CURVES.get_by_key_format_identifier(curve)
            if ecdsa_curve is None:
                raise TypeError("Invalid curve name")

            signing_key = None
            pub_numbers = ec.EllipticCurvePublicNumbers(x, y, ecdsa_curve.curve_class())
            verifying_key = pub_numbers.public_key(backend=ECDSA_db())

            if z is not None:
                priv_numbers = ec.EllipticCurvePrivateNumbers(z, pub_numbers)
                signing_key = priv_numbers.private_key(backend=ECDSA_db())

            key = ECDSA(vals=(signing_key, verifying_key))
            return cls(key)

        @classmethod
        def generate(cls, bits=None):
            """Generate a new key.

            :param int bits: key bits; defaults to :attr:`defaultSize`
            :return: new generated key
            """
            if bits is None:
                bits = cls.defaultSize
            return cls(ECDSA.generate(bits=bits))

        # --- paramiko key API, delegated to the wrapped key ---------------

        def asbytes(self):
            return self._key.asbytes()

        def can_sign(self):
            return self._key.can_sign()

        def get_base64(self):
            return self._key.get_base64()

        def get_bits(self):
            return self._key.get_bits()

        def get_fingerprint(self):
            return self._key.get_fingerprint()

        def get_name(self):
            return self._key.get_name()

        def sign_ssh_data(self, data):
            return self._key.sign_ssh_data(data)

        def verify_ssh_sig(self, data, msg):
            return self._key.verify_ssh_sig(data, msg)

        def write_private_key(self, file_obj, password=None):
            return self._key.write_private_key(file_obj, password=password)

        def write_private_key_file(self, filename, password=None):
            return self._key.write_private_key_file(filename, password=password)

        has_private = hasPrivate = can_sign

        # --- CryptoKey interface --------------------------------------------

        def fingerprint(self, hash_function=None):
            """Return the digest of the public-key blob (``asbytes()``).

            :param hash_function: hashlib constructor or its name (e.g.
                ``'sha256'``). When ``None``, ``self.hash`` is used
                (presumably provided by the base class — confirm).
            """
            if hash_function is None:
                hash_function = self.hash
            elif isinstance(hash_function, six.string_types):
                hash_function = getattr(hashlib, hash_function)
            return hash_function(self._key.asbytes()).digest()

        def hexFingerprint(self, hash_function=None):
            """Hex-encoded form of :meth:`fingerprint`."""
            return bhex(self.fingerprint(hash_function=hash_function))

        @memoized
        def networkRepresentation(self):
            """Return the public key blob (``asbytes()``)."""
            return self.asbytes()

        def dumpToNetworkStream(self, stream):
            """Write the public key to *stream* as ``[curve][point]``.

            The ``asbytes()`` blob is read back as three ``readBEStr`` fields
            (key type, curve, point); the first one is dropped.
            """
            raw_key_stream = StreamBase(bytesio(self.asbytes()))
            raw_key_stream.readBEStr()  # skip nid
            stream.writeBEStr(raw_key_stream.readBEStr())  # curve
            stream.writeBEStr(raw_key_stream.readBEStr())  # point

        @classmethod
        def fromNetworkStream(cls, stream, log=None):
            """Read ``[curve][point]`` from *stream* and build a public key.

            Inverse of :meth:`dumpToNetworkStream`: the full
            ``[ecdsa-sha2-<curve>][curve][point]`` message is rebuilt for
            paramiko. *log* is unused.
            """
            msg = paramiko.Message()
            curve = stream.readBEStr()
            msg.add_string(b'ecdsa-sha2-' + curve)
            msg.add_string(curve)
            point = stream.readBEStr()
            msg.add_string(point)
            msg.rewind()
            return cls(ECDSA(msg=msg))

        @classmethod
        def fromNetworkRepresentation(cls, data, log=None):
            """Build a key from a blob as returned by :meth:`networkRepresentation`.

            *log* is unused.
            """
            return cls(ECDSA(data=data))

        def correctedHash(self, hash):
            """Hash *hash* with this curve's digest and return the raw digest."""
            # The backend is passed explicitly, as elsewhere in this class:
            # cryptography releases before 3.1 require it.
            hasher = hashes.Hash(self.ecdsaHashes[b(self.type)](), backend=ECDSA_db())
            hasher.update(hash)
            return hasher.finalize()

        def sign(self, hash):
            """Sign the already-hashed *hash* with the private key.

            Returns the DER signature produced by cryptography; see
            :meth:`encodeSign` for the conversion to the agent format.
            """
            return self._key.signing_key.sign(
                hash,
                ec.ECDSA(Prehashed(self.ecdsaHashes[b(self.type)]())),
            )

        def verify(self, hash, sign):
            """Return True if DER signature *sign* is valid for prehashed *hash*."""
            key = self.publicKey()._key.verifying_key
            try:
                key.verify(
                    sign,
                    hash,
                    ec.ECDSA(Prehashed(self.ecdsaHashes[b(self.type)]())),
                )
            except InvalidSignature:
                return False
            else:
                return True

        def encodeSign(self, sign):
            """DerSequence(DerInt[r], DerInt[s]) -> [[r][s]]"""
            # cryptography returns signature in format different from
            # ssh-agent, so we should convert it here to make them match
            r, s = decode_dss_signature(sign)
            io = bytesio()
            stream = StreamBase(io)
            stream.writeBELIntNetCap(r)
            stream.writeBELIntNetCap(s)
            sign = io.getvalue()
            io = bytesio()
            StreamBase(io).writeBEStr(sign)
            return io.getvalue()

        def decodeSign(self, sign):
            """[[r][s]] -> DerSequence(DerInt[r], DerInt[s])"""
            # convert ssh-agent format to cryptography-compatible format
            rs = StreamBase(bytesio(sign)).readBEStr()
            rs_stream = StreamBase(bytesio(rs))
            r = rs_stream.readBELInt()
            s = rs_stream.readBELInt()
            sign = encode_dss_signature(r, s)
            return sign

        @memoized
        def exportKey(self, keyformat='PEM'):
            """Serialize the key.

            :param keyformat: ``'OpenSSH'`` for a ``<type> <base64> <comment>``
                line, or ``'PEM'`` (default): an unencrypted
                TraditionalOpenSSL private key when the key can sign,
                otherwise a SubjectPublicKeyInfo public key.
            :raises NotImplementedError: for any other format.
            """
            if keyformat == 'OpenSSH':
                return b"%s %s %s" % (b(self.type), b(self._key.get_base64()), b(self.comment))
            elif keyformat == 'PEM':
                if self._key.can_sign():
                    key = self._key.signing_key
                    key_data = key.private_bytes(ECDSA_ser.Encoding.PEM,
                                                 ECDSA_ser.PrivateFormat.TraditionalOpenSSL,
                                                 ECDSA_ser.NoEncryption())
                else:
                    key = self._key.verifying_key
                    key_data = key.public_bytes(ECDSA_ser.Encoding.PEM,
                                                ECDSA_ser.PublicFormat.SubjectPublicKeyInfo)
                return b(key_data)

            raise NotImplementedError

        @memoized
        def publicKey(self):
            """Return a new wrapper holding only the public half of this key."""
            # The verifying key is passed in both slots and signing_key is
            # cleared afterwards — presumably because paramiko's constructor
            # needs a key object in the first slot; confirm.
            key = ECDSA(vals=(self._key.verifying_key, self._key.verifying_key))
            key.signing_key = None
            return self.__class__(key)

        size = get_bits

        @property
        def type(self):
            """SSH key-type name as reported by paramiko's ``get_name()``."""
            return self.get_name()

        @classmethod
        def _doLoads(cls, data, comment='', options=None, log=None):
            """Parse an OpenSSH public-key line or a PEM block into a key.

            PEM input is tried as a private key first and, if paramiko
            rejects it, as a public key.

            :param data: key text or bytes (converted with ``tobytes``).
            :param comment: used for PEM input, and for OpenSSH lines that
                carry no comment of their own.
            :param options: stored on the key; defaults to an empty dict.
            :param log: unused.
            :raises ValueError: if *data* is in neither format or is a
                non-ECDSA public key.
            """
            # FIXME
            key = tobytes(data).strip()
            for keyType in cls._sshPublicKeyPrefixes:
                if key.startswith(keyType + b' '):
                    # "<type> <base64-blob> [comment]"
                    parts = key.split(b' ', 2)
                    msg = paramiko.Message(binascii.a2b_base64(b(parts[1])))
                    key = cls(ECDSA(msg=msg))
                    if len(parts) > 2 and parts[2]:
                        key.comment = parts[2]
                    else:
                        key.comment = comment
                    key.options = options or {}
                    return key
            if key.startswith(b'-----'):
                try:
                    key = cls(ECDSA.from_private_key(six.moves.StringIO(u(key))))
                except paramiko.SSHException:
                    # Not a private key: try a PEM public key. Use the public
                    # serialization API; the backend object's own
                    # load_pem_public_key is not part of cryptography's
                    # documented API and is gone from newer releases.
                    key = ECDSA_ser.load_pem_public_key(key, backend=ECDSA_db())
                    if not hasattr(key, 'curve'):
                        raise ValueError("key is not in ECDSA key format")
                    key = ECDSA(vals=(key, key))
                    key.signing_key = None
                    key = cls(key)
                key.comment = comment
                key.options = options or {}
                return key
            raise ValueError('ECDSA key format is not supported')

else:
    # Neither cEcdsa nor the paramiko/cryptography stack could be imported
    # (see the import fallbacks at the top), so ECDSA support is unavailable
    # and the module-level ECDSAKey name is bound to None.
    ECDSAKey = None


if not six.PY2:
    # Python 3: text and bytes are distinct types, and ``long`` is gone.
    def u(s):
        """Decode *s* to ``str`` if it is ``bytes``; return anything else as-is."""
        if isinstance(s, bytes):
            return s.decode()
        return s

    bhex = bytes.hex
    long = int
else:
    # Python 2: ``str`` is already a byte string, so nothing needs decoding.
    def u(s):
        """Identity on Python 2."""
        return s

    def bhex(data):
        """Return the hex encoding of byte string *data*."""
        return data.encode('hex')


def long_to_bytes(n, blocksize=0, pack=struct.pack):
    """Convert an integer to a big-endian byte string.

    Leading zero bytes are stripped, except that zero (and any value that
    is not positive) becomes a single ``b'\\x00'``.

    :param n: integer to convert (coerced with ``long``).
    :param blocksize: when positive, the result is left-padded with zero
        bytes to a multiple of this length.
    :param pack: ``struct.pack``-compatible callable used for each 32-bit
        chunk.
    """
    value = long(n)
    # Emit 32-bit big-endian words from least to most significant.
    words = []
    while value > 0:
        words.append(pack('>I', value & 0xffffffff))
        value >>= 32
    words.reverse()

    # The most significant word is never zero, so stripping only removes the
    # padding inside it; an empty result means the value was zero.
    result = b''.join(words).lstrip(b'\000') or b'\000'

    shortfall = len(result) % blocksize if blocksize > 0 else 0
    if shortfall:
        result = b'\000' * (blocksize - shortfall) + result
    return result
