# cython: language_level=3

from libc.stdint cimport uint64_t, uint32_t
from libcpp cimport bool

import struct


# KRL file magic: the eight bytes "SSHKRL\n\0" read as a big-endian uint64.
MAGIC = 0x5353484b524c0a00ULL  # "SSHKRL\n\0"
# The only KRL format version load_header accepts.
FORMAT_VERSION = 1
# Fixed-size header prefix, big-endian: magic (u64), format_version (u32),
# krl_version, generated_date, flags (u64 each) -> 36 bytes in total.
cdef HEADER_STRUCT = struct.Struct('>QLQQQ')
# Big-endian uint32 length prefix of the strings parsed by read_be_string.
cdef STRING_LENGTH_STRUCT = struct.Struct('>L')
# One big-endian uint64 certificate serial (SerialList entries).
cdef SERIAL_STRUCT = struct.Struct('>Q')
# A big-endian (serial_min, serial_max) uint64 pair (SerialRange payload).
cdef SERIAL_RANGE_STRUCT = struct.Struct('>QQ')


cdef read_be_string(const unsigned char[:] data):
    """Split one length-prefixed string off the front of ``data``.

    The string is encoded as a big-endian uint32 length followed by that
    many bytes.  Returns ``(payload, rest)``; both are memoryview slices of
    the input, so nothing is copied.

    Raises ValueError if ``data`` is shorter than the 4-byte prefix or than
    the payload length the prefix announces.
    """
    buf = memoryview(data)

    prefix = buf[:4]
    if len(prefix) < 4:
        raise ValueError("Insufficient data to parse string")
    (length,) = STRING_LENGTH_STRUCT.unpack(prefix)

    body = buf[4:]
    payload = body[:length]
    if len(payload) < length:
        raise ValueError("Insufficient data to parse string")
    return payload, body[length:]


cdef read_enum(byte, constructor):
    """Build an enum member from a single byte taken out of a buffer.

    Indexing a buffer yields an ``int`` on Python 3 but a length-1 ``str``
    on Python 2; the latter is converted with ``ord`` before being handed
    to ``constructor``.
    """
    value = byte if isinstance(byte, int) else ord(byte)
    return constructor(value)


# struct Header:
#     uint64 magic
#     uint32 format_version
#     uint64 krl_version
#     uint64 generated_date
#     uint64 flags
#     string reserved
#     string comment
cdef class Header:
    """Decoded KRL header fields, stored exactly as parsed."""

    def __init__(
        self,
        uint64_t magic,
        uint32_t format_version,
        uint64_t krl_version,
        uint64_t generated_date,
        uint64_t flags,
        bytes reserved,
        bytes comment,
    ):
        # Fixed-width integer fields.
        self.magic = magic
        self.format_version = format_version
        self.krl_version = krl_version
        self.generated_date = generated_date
        self.flags = flags
        # Variable-length string fields.
        self.reserved = reserved
        self.comment = comment


cdef class Section:
    """Common base for top-level KRL sections; remembers the section tag."""

    def __init__(self, SectionType section_type):
        # The subclass passes its own tag here.
        self.section_type = section_type


cdef class CertificateSubsection:
    """Common base for subsections nested in a Certificates section.

    Subclasses override ``cert_valid``; this base version revokes nothing.
    """

    def __init__(self, CertificateSectionType section_type):
        self.section_type = section_type

    cpdef bool cert_valid(self, uint64_t serial, bytes key_id):
        """Report whether a certificate survives this subsection.

        The base implementation never revokes, so it always returns True.
        """
        return True


cdef class CertificatesSectionSerialList(CertificateSubsection):
    """Subsection revoking an explicit list of certificate serials."""

    def __init__(self, list serials):
        super().__init__(section_type=CertificateSectionType.SerialList)
        self.serials = serials

    cpdef bool cert_valid(self, uint64_t serial, bytes key_id):
        """Return False exactly when ``serial`` is in the revoked list."""
        revoked = serial in self.serials
        return not revoked

    def __repr__(self):
        template = "CertificatesSectionSerialList[{!r}]"
        return template.format(self.serials)


cdef class CertificatesSectionSerialRange(CertificateSubsection):
    """Subsection revoking every serial in an inclusive range."""

    def __init__(self, uint64_t serial_min, uint64_t serial_max):
        super().__init__(section_type=CertificateSectionType.SerialRange)
        self.serial_min = serial_min
        self.serial_max = serial_max

    cpdef bool cert_valid(self, uint64_t serial, bytes key_id):
        """Return False when serial_min <= serial <= serial_max."""
        return not (self.serial_min <= serial <= self.serial_max)

    def __repr__(self):
        template = "CertificatesSectionSerialRange({}, {})"
        return template.format(self.serial_min, self.serial_max)


cdef class CertificatesSectionSerialBitmap(CertificateSubsection):
    """Subsection revoking certificate serials marked in a bitmap.

    Bit numbering follows the OpenSSH KRL format (PROTOCOL.krl), where the
    bitmap is transmitted as an ``mpint``.  ``revoked_keys_bitmap`` holds
    that big-endian integer.  Bit ``i``, counted from the integer's least
    significant bit (the low bit of the *last* byte), marks serial
    ``serial_offset + i`` as revoked.  Because bits are counted from the
    low end, leading zero padding bytes do not shift the mapping.
    """

    def __init__(self, uint64_t serial_offset, bytes revoked_keys_bitmap):
        super().__init__(section_type=CertificateSectionType.SerialBitmap)
        self.serial_offset = serial_offset
        self.revoked_keys_bitmap = revoked_keys_bitmap

    cpdef bool cert_valid(self, uint64_t serial, bytes key_id):
        """Return False if ``serial`` is marked as revoked in the bitmap."""
        cdef uint64_t delta
        cdef Py_ssize_t byte_idx
        cdef const unsigned char[:] bitmap

        if serial < self.serial_offset:
            return True
        # Subtract before comparing: ``serial_offset + bit_count`` could wrap
        # around 2**64 when the offset is near the top of the uint64 range.
        delta = serial - self.serial_offset
        if delta >= <uint64_t>len(self.revoked_keys_bitmap) * 8:
            return True

        bitmap = self.revoked_keys_bitmap
        # Bit 0 lives in the last byte, so index bytes from the end.
        byte_idx = bitmap.shape[0] - 1 - <Py_ssize_t>(delta // 8)
        return (bitmap[byte_idx] >> (delta % 8)) & 1 == 0

    def __repr__(self):
        return "CertificatesSectionSerialBitmap({}, {!r})".format(self.serial_offset, self.revoked_keys_bitmap)


cdef class CertificatesSectionKeyId(CertificateSubsection):
    """Subsection revoking certificates by their key ID string."""

    def __init__(self, list key_ids):
        super().__init__(section_type=CertificateSectionType.KeyId)
        self.key_ids = key_ids

    cpdef bool cert_valid(self, uint64_t serial, bytes key_id):
        """Return False exactly when ``key_id`` is in the revoked list."""
        revoked = key_id in self.key_ids
        return not revoked

    def __repr__(self):
        rendered = ", ".join(map(repr, self.key_ids))
        return "CertificatesSectionKeyId({})".format(rendered)


cdef class CertificatesSection(Section):
    """Certificates section: revocation rules scoped to a single CA key."""

    def __init__(self, bytes ca_key, bytes reserved, list cert_sections):
        super().__init__(section_type=SectionType.Certificates)
        self.ca_key = ca_key
        self.reserved = reserved
        self.cert_sections = cert_sections

    cpdef bool cert_valid(self, bytes ca_key, uint64_t serial, bytes key_id):
        """Check a certificate against this section's subsections.

        A certificate signed by a different CA key is never revoked here.
        Otherwise it is valid only if no subsection reports it as revoked;
        subsections are consulted in order and the first revocation wins.
        """
        if ca_key == self.ca_key:
            for subsection in self.cert_sections:
                if not subsection.cert_valid(serial, key_id):
                    return False
        return True

    def __repr__(self):
        template = "CertificatesSection(ca_key={!r}, reserved={!r}, cert_sections={!r})"
        return template.format(self.ca_key, self.reserved, self.cert_sections)


cdef class ExplicitKeySection(Section):
    """ExplicitKey section: holds the raw public key blobs it lists."""

    def __init__(self, list public_key_blobs):
        super().__init__(section_type=SectionType.ExplicitKey)
        self.public_key_blobs = public_key_blobs

    cpdef bool key_valid(self, object key):
        """Always report ``key`` as valid.

        Matching would require serializing ``key`` into the same blob form
        as ``public_key_blobs``.  That is deliberately skipped because plain
        keys are not used in our KRL, so this section is effectively ignored.
        """
        # TODO: compare against public_key_blobs once key serialization is
        # handled here.
        return True


cdef class FingerprintSHA1Section(Section):
    """FingerprintSHA1 section: revokes keys by their SHA1 hash."""

    def __init__(self, list public_key_hashes):
        super().__init__(section_type=SectionType.FingerprintSHA1)
        self.public_key_hashes = public_key_hashes

    cpdef bool key_valid(self, object key):
        """Return False when ``key``'s SHA1 fingerprint is listed here."""
        # NOTE(review): public_key_hashes hold the raw bytes parsed from the
        # KRL; confirm key.fingerprint('sha1') returns the same representation.
        digest = key.fingerprint('sha1')
        return digest not in self.public_key_hashes


cdef class FingerprintSHA256Section(Section):
    """FingerprintSHA256 section: revokes keys by their SHA256 hash."""

    def __init__(self, list public_key_hashes):
        super().__init__(section_type=SectionType.FingerprintSHA256)
        self.public_key_hashes = public_key_hashes

    cpdef bool key_valid(self, object key):
        """Return False when ``key``'s SHA256 fingerprint is listed here."""
        # NOTE(review): public_key_hashes hold the raw bytes parsed from the
        # KRL; confirm key.fingerprint('sha256') returns the same representation.
        digest = key.fingerprint('sha256')
        return digest not in self.public_key_hashes


cdef class SignatureSection(Section):
    """Signature section: signing key, signature, and the signed blob.

    ``blob`` is passed as None by load_section; it is meant to hold the
    data preceding the signature (see the TODO in loads).
    """

    def __init__(self, bytes signature_key, bytes signature, bytes blob):
        super().__init__(section_type=SectionType.Signature)
        self.signature_key = signature_key
        self.signature = signature
        self.blob = blob


cdef class KRL:
    """A parsed KRL: header, ordinary sections, and an optional signature."""

    def __init__(self, Header header, list sections, SignatureSection signature = None):
        self.header = header
        self.sections = sections
        self.signature = signature

    cpdef bool cert_valid(self, object cert):
        """Return False if any section revokes ``cert``.

        ``cert`` must expose ``serial``, ``key_id`` and
        ``signing_key.networkRepresentation()``.  Certificates sections are
        checked with the CA key blob, serial and key ID; every other section
        is asked through ``key_valid(cert)``.  Sections are visited in order
        and the first revocation short-circuits.
        """
        cdef uint64_t serial = cert.serial
        cdef bytes key_id = cert.key_id
        cdef bytes ca_key = cert.signing_key.networkRepresentation()

        for section in self.sections:
            if isinstance(section, CertificatesSection):
                revoked = not section.cert_valid(ca_key, serial, key_id)
            else:
                revoked = not section.key_valid(cert)
            if revoked:
                return False
        return True


cpdef load_header(const unsigned char[:] data) except +:
    """Parse the KRL header from the front of ``data``.

    The header is the fixed-size prefix described by HEADER_STRUCT (magic,
    format_version, krl_version, generated_date, flags), followed by two
    length-prefixed strings: ``reserved`` and ``comment``.

    Returns ``(Header, rest)``, where ``rest`` is a memoryview over the
    bytes that follow the header.

    Raises ValueError if the data is truncated, the magic differs from
    MAGIC, or format_version differs from FORMAT_VERSION.
    """
    cdef view = memoryview(data)
    # Size of the fixed integer prefix (36 bytes for '>QLQQQ').
    cdef Py_ssize_t fixed_size = HEADER_STRUCT.size

    ints = view[:fixed_size]
    if len(ints) < fixed_size:
        raise ValueError("Insufficient data to parse KRL header")

    magic, format_version, krl_version, generated_date, flags = HEADER_STRUCT.unpack(ints)
    if magic != MAGIC:
        raise ValueError("KRL header magic mismatch")

    if format_version != FORMAT_VERSION:
        raise ValueError("Unsupported KRL format version {}".format(format_version))

    view = view[fixed_size:]

    reserved, view = read_be_string(view)
    comment, view = read_be_string(view)

    return Header(
        magic,
        format_version,
        krl_version,
        generated_date,
        flags,
        reserved.tobytes(),
        comment.tobytes(),
    ), view


cpdef load_section(const unsigned char[:] data) except +:
    """Parse one KRL section from the front of ``data``.

    A section is a type byte followed by a length-prefixed ``section_data``
    string.  Returns ``(section, rest)``, where ``rest`` is a memoryview
    over the bytes that follow the parsed section.

    Raises:
        ValueError: on empty, truncated or malformed data, or on a section
            or certificate-subsection type this parser does not handle.
        NotImplementedError: for Certificates SerialBitmap subsections.
    """
    cdef view = memoryview(data)
    if not len(view):
        raise ValueError("Insufficient data to parse KRL section")
    cdef SectionType section_type = read_enum(view[0], SectionType)
    section_data, view = read_be_string(view[1:])

    if section_type == SectionType.Certificates:
        # section_data: string ca_key, string reserved, then a sequence of
        # (byte subsection type, string subsection data) pairs.
        ca_key, section_data = read_be_string(section_data)
        reserved, section_data = read_be_string(section_data)
        subsections = []

        while section_data:
            cert_section_type = read_enum(section_data[0], CertificateSectionType)
            subsection_data, section_data = read_be_string(section_data[1:])

            if cert_section_type == CertificateSectionType.SerialList:
                # Concatenated big-endian uint64 serials.
                serials = []
                while len(subsection_data) >= 8:
                    serial, = SERIAL_STRUCT.unpack(subsection_data[:8])
                    serials.append(serial)
                    subsection_data = subsection_data[8:]
                if len(subsection_data):
                    raise ValueError(
                        "Trailing {} bytes in Certificates SerialList section".format(
                            len(subsection_data),
                        )
                    )
                subsections.append(CertificatesSectionSerialList(serials))
            elif cert_section_type == CertificateSectionType.SerialRange:
                if len(subsection_data) != 16:
                    raise ValueError("Certificates SerialRange section has invalid size")
                serial_min, serial_max = SERIAL_RANGE_STRUCT.unpack(subsection_data)
                subsections.append(CertificatesSectionSerialRange(serial_min, serial_max))
            elif cert_section_type == CertificateSectionType.SerialBitmap:
                # TODO
                raise NotImplementedError("Certificates SerialBitmap section is not supported")
            elif cert_section_type == CertificateSectionType.KeyId:
                # Concatenated length-prefixed key IDs.  ">= 4" (not "> 4")
                # so that an empty key ID, encoded as a bare zero-length
                # prefix, is parsed instead of being rejected as trailing data.
                key_ids = []
                while len(subsection_data) >= 4:
                    key_id, subsection_data = read_be_string(subsection_data)
                    key_ids.append(key_id.tobytes())
                if len(subsection_data):
                    raise ValueError(
                        "Trailing {} bytes in Certificates KeyId section".format(
                            len(subsection_data),
                        )
                    )
                subsections.append(CertificatesSectionKeyId(key_ids))
            else:
                # Fail loudly rather than silently dropping revocation rules.
                raise ValueError(
                    "Unsupported Certificates subsection type {}".format(cert_section_type)
                )
        return CertificatesSection(ca_key.tobytes(), reserved.tobytes(), subsections), view

    elif section_type in (
        SectionType.ExplicitKey,
        SectionType.FingerprintSHA1,
        SectionType.FingerprintSHA256,
    ):
        if section_type == SectionType.ExplicitKey:
            constructor = ExplicitKeySection
        elif section_type == SectionType.FingerprintSHA1:
            constructor = FingerprintSHA1Section
        else:
            constructor = FingerprintSHA256Section

        # section_data is a sequence of length-prefixed blobs.
        blobs = []
        while len(section_data):
            blob, section_data = read_be_string(section_data)
            blobs.append(blob.tobytes())
        return constructor(blobs), view

    elif section_type == SectionType.Signature:
        # Signature has special structure in OpenSSH.
        # It is not a nested structure but a flat one: section_data is the
        # signature key, and the signature follows as a separate string.
        signature_key = section_data
        signature, view = read_be_string(view)
        return SignatureSection(signature_key.tobytes(), signature.tobytes(), None), view

    # Previously fell through and returned None, which only failed later
    # with an opaque tuple-unpacking error in the caller.
    raise ValueError("Unsupported KRL section type {}".format(section_type))


cpdef loads(const unsigned char[:] data) except +:
    """Parse a complete binary KRL into a KRL object.

    The header is parsed first; sections are then consumed until the input
    is exhausted.  Signature sections are kept apart from the ordinary
    sections, and if several are present only the last one is retained.
    """
    cdef remaining = memoryview(data)
    cdef Header header
    cdef list sections = []
    cdef SignatureSection signature = None

    header, remaining = load_header(remaining)
    while len(remaining):
        section, remaining = load_section(remaining)
        if not isinstance(section, SignatureSection):
            sections.append(section)
            continue
        # TODO fill the section.blob with data prior to signature itself
        signature = section

    return KRL(header, sections, signature)
