import enum
import logging

from PyCRC.CRC16 import CRC16

from .util.parsing import BufferReader, StreamParser


LOGGER = logging.getLogger(__name__)


class ChecksumMismatchError(Exception):
    """Raised when a packet's computed CRC16 differs from the one it carries."""


class InvalidPacketError(Exception):
    """Signals that the bytes being parsed do not form a valid packet.

    The mandatory keyword-only ``consumed_bytes`` argument holds the bytes
    that caused the error but may be valid by themselves; it is stored on the
    exception as the ``consumed_bytes`` attribute. All remaining arguments are
    forwarded to ``Exception``.
    """

    def __init__(self, *args, consumed_bytes, **kwargs):
        self.consumed_bytes = consumed_bytes
        super().__init__(*args, **kwargs)


@enum.unique
class WialonCombinePacketType(enum.Enum):
    """Kinds of packets distinguished by the parser in this module."""

    LOGIN = 'login'
    DATA = 'data'
    KEEP_ALIVE = 'keep_alive'


class WialonCombineGenericPacket(object):
    """One parsed packet together with the raw bytes it was read from.

    Attributes (declared through ``__slots__``):
        raw: the bytes the packet was parsed from.
        type: a ``WialonCombinePacketType`` member.
        seq: the sequence value read from the packet.
        len: the length field (``None`` for keep-alive packets).
        data: the payload (``None`` for keep-alive packets).
        crc16: the checksum carried by the packet (``None`` for keep-alive
            packets).
    """

    __slots__ = ['raw', 'type', 'seq', 'len', 'data', 'crc16']

    # Shared calculator instance; not part of the per-packet slots.
    _crc16 = CRC16()

    def __init__(self, raw, type_, seq, len_, data, crc16):
        self.raw, self.type, self.seq = raw, type_, seq
        self.len, self.data, self.crc16 = len_, data, crc16

    @classmethod
    def try_parse_from_buffer(cls, buf):
        """Parse the next packet out of ``buf`` and verify its checksum.

        Whenever the parser reports an invalid packet, the bytes it attached
        to the error are pushed back to the front of ``buf`` and parsing is
        retried. Returns the packet, or ``None`` when the parser produced
        none. Keep-alive packets are returned without a checksum check.

        Raises:
            ChecksumMismatchError: the packet's checksum does not match.
        """
        while True:
            try:
                packet = WialonCombinePacketParser().parse(buf)
            except InvalidPacketError as error:
                # These bytes may start a valid packet; give them back.
                buf.appendleft(error.consumed_bytes)
            else:
                break

        if packet is None:
            return packet
        if packet.type is WialonCombinePacketType.KEEP_ALIVE:
            return packet

        packet.check_crc16()
        return packet

    def check_crc16(self):
        """Recompute the CRC16 over ``raw`` and compare it with ``crc16``.

        Raises:
            ChecksumMismatchError: the computed and stored values differ.
        """
        # Everything except the trailing 2 bytes, which are the checksum itself.
        body = self.raw[:-2]
        computed = self._crc16.calculate(body)
        if computed != self.crc16:
            LOGGER.warning('Checksum mismatch: %s != %s', computed, self.crc16)
            raise ChecksumMismatchError


class WialonCombinePacketParser(object):
    """Extracts a single packet from a buffer of received bytes.

    A packet begins with ``PACKET_START_MARKER``, followed by a type byte
    (looked up in ``PACKET_TYPES``) and a sequence value. Every type except
    keep-alive additionally carries a length, a payload of that many bytes
    and a checksum.
    """

    PACKET_START_MARKER = b'$$'
    PACKET_TYPES = {
        0: WialonCombinePacketType.LOGIN,
        1: WialonCombinePacketType.DATA,
        2: WialonCombinePacketType.KEEP_ALIVE,
    }

    def parse(self, buf):
        """Parse one packet from ``buf``.

        Returns:
            A ``WialonCombineGenericPacket``, or ``None`` when ``buf`` does
            not yet hold a complete packet. In the latter case the bytes
            accumulated so far are pushed back to the front of ``buf``.

        Raises:
            InvalidPacketError: the type byte is not in ``PACKET_TYPES``.
        """
        reader = BufferReader(buf)

        if len(reader) < len(self.PACKET_START_MARKER):
            return None

        # Bound up-front so the StopIteration handler can never hit an
        # unbound name if the exception fires before the parser exists.
        sp = None

        try:
            with reader.iterator() as it:
                sp = StreamParser(it)

                self._skip_to_packet_start(sp)

                type_ = self._parse_type(sp)
                seq = self._parse_seq(sp)

                if type_ is WialonCombinePacketType.KEEP_ALIVE:
                    # Keep-alive packets end after the sequence value.
                    len_ = None
                    data = None
                    crc = None
                else:
                    len_ = self._parse_len(sp)
                    data = self._parse_data(sp, len_)
                    crc = self._parse_crc(sp)

        except StopIteration:
            # Current buffer contains incomplete packet.
            # Return the accumulated data back to the buffer and wait for more data.
            if sp is not None and sp.accumulated_data:
                buf.appendleft(sp.accumulated_data)
            return None

        return WialonCombineGenericPacket(
            raw=bytes(sp.accumulated_data),
            type_=type_,
            seq=seq,
            len_=len_,
            data=data,
            crc16=crc,
        )

    def _skip_to_packet_start(self, sp):
        """Forward the iterator to the first packet start marker.

        Bytes in front of the marker are discarded and logged. When the input
        runs out before a complete marker is seen, trailing bytes that could
        be the beginning of a marker are stored in ``sp.accumulated_data`` so
        that ``parse`` hands them back to the buffer instead of losing the
        first half of a marker that is split across two reads.
        """
        marker = self.PACKET_START_MARKER
        window = bytearray()   # sliding window of the most recent bytes
        skipped = bytearray()  # bytes dropped in front of the marker

        while True:
            try:
                byte = sp.parse_byte(accumulate=False)
            except StopIteration:
                dropped, partial = self._split_partial_marker(window)
                skipped.extend(dropped)
                # Preserved so the caller can return it to the buffer.
                sp.accumulated_data.extend(partial)
                break

            if len(window) >= len(marker):
                # Window is full: its oldest byte cannot belong to a marker.
                skipped.append(window.pop(0))

            window.append(byte)

            if window == marker:
                # Packet start marker found.
                sp.accumulated_data.extend(window)
                break

        if skipped:
            LOGGER.info('skipped chunk: %s', repr(bytes(skipped)))

    def _split_partial_marker(self, window):
        """Split ``window`` into ``(garbage, partial_marker)``.

        ``partial_marker`` is the longest suffix of ``window`` that is a
        proper prefix of ``PACKET_START_MARKER`` (possibly empty).
        """
        marker = self.PACKET_START_MARKER
        longest = min(len(window), len(marker) - 1)
        for keep in range(longest, 0, -1):
            if window[-keep:] == marker[:keep]:
                return window[:-keep], window[-keep:]
        return window, bytearray()

    def _parse_type(self, sp):
        """Read the type byte and map it to a ``WialonCombinePacketType``.

        Raises:
            InvalidPacketError: the byte is not a known type; the offending
                byte is attached as ``consumed_bytes``.
        """
        type_byte = sp.parse_byte()
        type_ = self.PACKET_TYPES.get(type_byte)

        if type_ is None:
            LOGGER.error('Unknown packet type: %d', type_byte)
            raise InvalidPacketError(consumed_bytes=bytearray([type_byte]))

        return type_

    def _parse_seq(self, sp):
        """Read the sequence value via ``sp.parse_short()``."""
        return sp.parse_short()

    def _parse_len(self, sp):
        """Read the length field via ``sp.parse_extensible_short()``."""
        return sp.parse_extensible_short()

    def _parse_data(self, sp, len_):
        """Read ``len_`` payload bytes and return them as ``bytes``."""
        data = bytearray()
        # Explicit loop (not a generator expression) so that StopIteration
        # raised by parse_byte propagates unchanged to ``parse``.
        for _ in range(len_):
            data.append(sp.parse_byte())
        return bytes(data)

    def _parse_crc(self, sp):
        """Read the trailing checksum via ``sp.parse_short()``."""
        return sp.parse_short()
