import time
import struct
import binascii

from ._key_crypto import RSAKey
# NOTE: the ECDSA implementation imported here is known to be torkve's.
from ._key_ecdsa import ECDSAKey
from ._key_base import CryptoKey, AuthKey

from kernel.util.streams.streambase import StreamBase

import six
import enum


# Precompiled codec for an SSH ``uint64``: 8 bytes, big-endian, unsigned.
int64 = struct.Struct('>Q')

# Largest value representable as a uint64; used as ``valid_before`` for
# certificates that never expire (see ``Certificate.generate``).
CERT_TIME_INFINITY = 2 ** 64 - 1

if not six.PY2:
    # Python 3: bytes and text are distinct types.
    bytesio = six.BytesIO

    def b(s):
        """Return *s* as ``bytes``; ``str`` input is encoded as latin-1."""
        if not isinstance(s, str):
            return s
        return s.encode('latin-1')

    def u(s):
        """Return *s* as ``str``; ``bytes`` input is decoded with the default codec."""
        if not isinstance(s, bytes):
            return s
        return s.decode()

    # Kept as the unbound method so the accepted argument types are unchanged.
    bhex = bytes.hex

    def tobytes(s):
        """Convert *s* to ``bytes``.

        ``str`` input is encoded as latin-1; anything else is passed to the
        ``bytes`` constructor.
        """
        if not isinstance(s, str):
            return bytes(s)
        return bytes(s, 'latin-1')

else:
    # Python 2: ``str`` already is a byte string.
    bytesio = six.moves.cStringIO

    def b(s):
        """Return *s* unchanged (Python 2 ``str`` is already bytes)."""
        return s

    def u(s):
        """Return *s* unchanged (no decoding is done on Python 2)."""
        return s

    def bhex(data):
        """Return the hex encoding of the byte string *data*."""
        return data.encode('hex')

    def tobytes(s):
        """Convert *s* to a byte string.

        ``unicode`` input is encoded as latin-1; anything else is joined with
        ``''.join``.
        """
        if not isinstance(s, six.text_type):
            return ''.join(s)
        return s.encode('latin-1')


class CertificateType(enum.IntEnum):
    """Numeric certificate ``type`` field read from the certificate blob."""

    # Certificate identifies a user.
    SSH_CERT_TYPE_USER = 1
    # Certificate identifies a host.
    SSH_CERT_TYPE_HOST = 2


def _get_remainder(stream):
    position = stream.slave.tell()
    remainder = stream.slave.read()
    stream.slave.seek(position)
    return remainder


def _get_so_far(stream):
    position = stream.slave.tell()
    stream.slave.seek(0)
    return stream.slave.read(position)


def _readBEInt64(stream):
    """Read one SSH ``uint64`` (8 bytes, big-endian, unsigned) from *stream*."""
    raw = stream.read(8)
    (value,) = int64.unpack(raw)
    return value


def parse_string_list(packed_string):
    """Decode a concatenation of SSH length-prefixed strings into a list.

    Used for the certificate's valid-principals field. Items are returned
    exactly as produced by ``readBEStr`` (no text decoding is applied).
    """
    stream = StreamBase(bytesio(packed_string))
    # Compare the read offset with the known input length instead of
    # re-reading (and copying) the entire remainder on every iteration,
    # which made the loop quadratic in the size of the field.
    end = len(packed_string)
    items = []
    while stream.slave.tell() < end:
        items.append(stream.readBEStr())

    return items


def parse_dict(packed_string):
    """Decode a certificate critical-options / extensions field into a dict.

    The field is a sequence of (name, data) length-prefixed string pairs.
    Each name is converted with ``u``. A non-empty data string is assumed to
    wrap a single nested string, which is also converted with ``u``. Empty
    data maps to ``''``. A repeated name keeps its last value.
    """
    stream = StreamBase(bytesio(packed_string))
    # Track progress by offset rather than re-reading the whole remainder on
    # each iteration (which was quadratic in the field size).
    end = len(packed_string)
    ret = {}
    while stream.slave.tell() < end:
        name = u(stream.readBEStr())
        # FIXME (torkve): the inner contents type may differ per option (see
        # the OpenSSH certificate protocol); for now the data is assumed to
        # hold exactly one nested string.
        raw = stream.readBEStr()
        ret[name] = u(StreamBase(bytesio(raw)).readBEStr()) if raw else ''

    return ret


def parse_signature(packed_string):
    """Split a signature blob into ``(algorithm_name, rest)``.

    The leading length-prefixed string is the algorithm name. Every byte
    after it is returned unparsed as the second element.
    """
    reader = StreamBase(bytesio(packed_string))
    algorithm = reader.readBEStr()
    remaining = reader.read()
    return algorithm, remaining


class Certificate(CryptoKey):
    """OpenSSH certificate (``<key type>-cert-v01@openssh.com``).

    Holds the certified public key (``self._key``) together with the
    certificate fields decoded by :meth:`fromNetworkRepresentation`.
    Fingerprinting and signature verification are delegated to the certified
    key. Signing and private-key export are not supported.
    """

    # Certificate key type -> parser for the embedded public key fields.
    key_parsers = {
        b'ssh-rsa-cert-v01@openssh.com': RSAKey.fromNetworkStream,
        b'ecdsa-sha2-nistp256-cert-v01@openssh.com': ECDSAKey.fromNetworkStream,
        b'ecdsa-sha2-nistp384-cert-v01@openssh.com': ECDSAKey.fromNetworkStream,
        b'ecdsa-sha2-nistp521-cert-v01@openssh.com': ECDSAKey.fromNetworkStream,
    }
    keyTypes = _sshPublicKeyPrefixes = list(key_parsers.keys())

    def __init__(
        self,
        public_key,
        serial,
        cert_type,
        key_id,
        principals,
        valid_range,
        critical_options,
        extensions,
        signing_key,
        blob,
    ):
        """Store already-decoded certificate fields.

        :param public_key: the certified key; key operations delegate to it.
        :param serial: certificate serial number.
        :param cert_type: a :class:`CertificateType` value.
        :param key_id: key identifier as read from the certificate.
        :param principals: list of valid principals.
        :param valid_range: ``(valid_after, valid_before)`` in seconds since
            the epoch (compared against ``time.time()``).
        :param critical_options: dict of critical options.
        :param extensions: dict of extensions.
        :param signing_key: the CA key that signed the certificate.
        :param blob: raw wire-format certificate bytes.

        The wire format's ``nonce`` and ``reserved`` fields are not retained.
        """
        AuthKey.__init__(self)
        self._key = public_key
        self.serial = serial
        self.cert_type = cert_type
        self.key_id = key_id
        self.principals = principals
        self.valid_range = valid_range
        self.critical_options = critical_options
        self.extensions = extensions
        self.signing_key = signing_key
        self.blob = blob

    @classmethod
    def generate(cls, bits=None, signing_key=None):
        """Build a placeholder user certificate around a fresh RSA key.

        :param bits: RSA key size; defaults to 2048.
        :param signing_key: CA key to record. Defaults to the new key itself.

        The result has serial 0, no principals, unlimited validity, no
        options and an empty ``blob``; nothing is actually signed.
        """
        if bits is None:
            bits = 2048

        public_key = RSAKey.generate(bits=bits)
        if signing_key is None:
            signing_key = public_key

        return cls(
            public_key=public_key,
            serial=0,
            cert_type=CertificateType.SSH_CERT_TYPE_USER,
            key_id='key',
            principals=[],
            valid_range=(0, CERT_TIME_INFINITY),
            critical_options={},
            extensions={},
            signing_key=signing_key,
            blob=b'',
        )

    def is_user_certificate(self):
        """Return True for a user certificate."""
        return self.cert_type == CertificateType.SSH_CERT_TYPE_USER

    def is_host_certificate(self):
        """Return True for a host certificate."""
        return self.cert_type == CertificateType.SSH_CERT_TYPE_HOST

    @classmethod
    def construct(cls, *args, **kwargs):
        """Not supported for certificates; always raises ``ValueError``."""
        raise ValueError("Not supported")

    def asbytes(self):
        """Return the raw wire-format certificate blob."""
        return self.blob

    def can_sign(self):
        """Certificates carry no private key, so this is always False."""
        return False

    def get_base64(self):
        raise NotImplementedError

    def get_bits(self):
        """Return the size of the certified key."""
        return self._key.get_bits()

    def get_fingerprint(self):
        """Return the certified key's fingerprint."""
        return self._key.get_fingerprint()

    def get_name(self):
        """Return the certified key's name."""
        return self._key.get_name()

    @property
    def type(self):
        """Alias for :meth:`get_name`."""
        return self.get_name()

    def sign_ssh_data(self, data):
        raise NotImplementedError

    def verify_ssh_sig(self, data, msg):
        """Delegate SSH signature verification to the certified key.

        Unlike :meth:`verify`, this does not check the validity window.
        """
        return self._key.verify_ssh_sig(data, msg)

    def write_private_key(self, file_obj, password=None):
        raise NotImplementedError

    def write_private_key_file(self, filename, password=None):
        raise NotImplementedError

    size = get_bits
    hasPrivate = has_private = can_sign

    def fingerprint(self, hash_function=None):
        """Return the certified key's fingerprint computed with *hash_function*."""
        return self._key.fingerprint(hash_function=hash_function)

    def hexFingerprint(self, hash_function=None):
        """Return :meth:`fingerprint` hex-encoded."""
        return bhex(self.fingerprint(hash_function=hash_function))

    def publicKey(self):
        """A certificate is its own public part."""
        return self

    def sign(self, hash):
        raise NotImplementedError("Sign not supported on public keys")

    def verify(self, hash, sign):
        """Verify *sign* over *hash* with the certified key.

        Returns False without checking the signature when the certificate is
        outside its validity window (see :meth:`check_validity`).
        """
        if not self.check_validity():
            return False

        return self._key.verify(hash, sign)

    def check_validity(self):
        """Return True when ``valid_after <= now < valid_before``.

        ``now`` is ``time.time()``. Negative bounds are treated as invalid.
        """
        now = time.time()
        after, before = self.valid_range
        if after < 0 or now < after:
            return False
        if before < 0 or now >= before:
            return False
        return True

    def encodeSign(self, sign):
        return self._key.encodeSign(sign)

    def decodeSign(self, sign):
        return self._key.decodeSign(sign)

    def exportKey(self, keyformat='PEM'):
        """Export as an OpenSSH one-line certificate; other formats raise.

        Note: ``binascii.b2a_base64`` appends a newline, so the returned
        line ends with ``b'\\n'``.
        """
        if keyformat == 'OpenSSH':
            return b'%s-cert-v01@openssh.com %s' % (self._key.type, binascii.b2a_base64(self.networkRepresentation()))
        raise NotImplementedError

    def dumpToNetworkStream(self, stream):
        raise NotImplementedError

    @classmethod
    def fromNetworkStream(cls, stream, log=None):
        raise NotImplementedError

    def correctedHash(self, hash):
        return self._key.correctedHash(hash)

    @classmethod
    def _doLoads(cls, data, comment='', options=None, log=None):
        """Parse an OpenSSH one-line certificate ``<type> <base64> [comment]``.

        Fields may be separated by any run of whitespace. The inline comment,
        if present, takes precedence over *comment*.

        :raises ValueError: when the leading type is not a supported
            certificate type or the base64 field is missing.
        """
        # Splitting on arbitrary whitespace (rather than a single b' ') keeps
        # repeated spaces or tabs from producing an empty base64 field.
        parts = tobytes(data).strip().split(None, 2)
        if len(parts) < 2 or parts[0] not in cls._sshPublicKeyPrefixes:
            raise ValueError("Certificate is not supported")

        key = cls.fromNetworkRepresentation(
            binascii.a2b_base64(b(parts[1]))
        )
        if len(parts) > 2 and parts[2]:
            key.comment = parts[2]
        else:
            key.comment = comment
        key.options = options or {}
        return key

    @classmethod
    def fromNetworkRepresentation(cls, data, log=None):
        """Decode a wire-format certificate blob and verify its CA signature.

        :raises ValueError: for an unsupported certificate or CA key type, or
            when the CA signature does not verify.
        """
        stream = StreamBase(bytesio(data))

        key_type = stream.readBEStr()
        if key_type not in cls.key_parsers:
            raise ValueError("Unsupported key type: {}".format(key_type))
        nonce = stream.readBEStr()  # noqa
        key = cls.key_parsers[key_type](stream, log=log)

        serial = _readBEInt64(stream)
        cert_type = CertificateType(stream.readBEInt())
        key_id = stream.readBEStr()
        valid_principals = parse_string_list(stream.readBEStr())
        valid_after = _readBEInt64(stream)
        valid_before = _readBEInt64(stream)
        critical_options = parse_dict(stream.readBEStr())
        extensions = parse_dict(stream.readBEStr())
        reserved = stream.readBEStr()  # noqa

        # CryptoKey.fromNetworkRepresentation signals an unknown CA key type
        # by returning None.
        signing_key = CryptoKey.fromNetworkRepresentation(stream.readBEStr())
        if signing_key is None:
            raise ValueError("Unsupported CA key type")

        # The signature covers every byte of the blob that precedes it.
        signed_block = _get_so_far(stream)
        sig_key_type, signature = parse_signature(stream.readBEStr())
        # NOTE(review): the signature algorithm must equal the CA key type
        # exactly; RSA CAs signing with rsa-sha2-256/512 are presumably
        # rejected here. Confirm whether that is intended.
        if sig_key_type != signing_key.type or not signing_key.verify(
            signing_key.correctedHash(signed_block),
            signing_key.decodeSign(signature),
        ):
            raise ValueError("Certificate signature is invalid")

        return cls(
            public_key=key,
            serial=serial,
            cert_type=cert_type,
            key_id=key_id,
            principals=valid_principals,
            valid_range=(valid_after, valid_before),
            critical_options=critical_options,
            extensions=extensions,
            signing_key=signing_key,
            blob=data,
        )

    networkRepresentation = asbytes

    def __str__(self):
        """Summarise the certificate: type, SHA256 fingerprints, key id, serial, CA."""
        fp_method = 'sha256'

        return '%s [%s %s:%s key_id %r serial %s CA %s %s:%s]' % (
            type(self).__name__,
            self.type,
            fp_method.upper(),
            self.hexFingerprint(fp_method),
            self.key_id,
            self.serial,
            self.signing_key.type,
            fp_method.upper(),
            self.signing_key.hexFingerprint(fp_method),
        )

