"""

    File: decoder_py.py
    Description: 
    
        This module decomposes a raw RTMP stream into AMF packets.  It keeps a buffer of incoming data for each 
        stream and, once a packet is complete, appends it to the decoder's packet, event or ping list.

    Author: Kyle Vogt
    Date  : June 5th, 2007
    Copyright (c) 2007, Justin.tv, Inc.
    
"""

import struct, traceback
from twisted.python import log
from flvtools.rtmp import amf_py as amf
import os, traceback, re

class Decoder():
    """Incremental decoder that splits a raw RTMP byte stream into packets.

    Raw data is fed to process(); complete packets are appended to the
    packets, events and pings lists depending on their type.
    """
    
    # Parser states used by process(): waiting for a chunk header, or
    # collecting the body bytes that follow one.
    READ_HEADER = 1
    READ_BODY = 2
    
    # Regex fragments for the individual header fields.  They are combined
    # below into one pattern per header type.
    # Channel id: \x00 + 1 byte, \x01 + 2 bytes, or any single byte.
    CHANNEL_ID =        r'(?P<channel_id>(\x00.{1})|(\x01.{2})|(.{1}))'
    # 3-byte timecode; the extra_tc group is set only when it is FF FF FF.
    PACKET_TIMECODE =   r'(?P<packet_timecode>(?P<extra_tc>\xFF{3})|(.{3}))'
    PACKET_TYPE =       r'(?P<packet_type>.{1})'
    PACKET_SIZE =       r'(?P<packet_size>.{3})'
    STREAM_ID =         r'(?P<stream_id>.{4})'
    # Conditional group: consumes 4 more bytes only if extra_tc matched.
    LONG_TIMECODE =     r'(?P<long_tc>(?(extra_tc).{4}))'

    # One compiled pattern per header type; readHeader() picks HEADER_<n>
    # from the top two bits of the first byte.  Each successive type carries
    # fewer fields.  re.DOTALL lets '.' match newline bytes as well.
    HEADER_0 = re.compile('(?P<header>' + CHANNEL_ID + PACKET_TIMECODE + PACKET_SIZE + PACKET_TYPE + STREAM_ID + LONG_TIMECODE + ')', re.DOTALL)
    HEADER_1 = re.compile('(?P<header>' + CHANNEL_ID + PACKET_TIMECODE + PACKET_SIZE + PACKET_TYPE + LONG_TIMECODE + ')', re.DOTALL)
    HEADER_2 = re.compile('(?P<header>' + CHANNEL_ID + PACKET_TIMECODE + LONG_TIMECODE + ')', re.DOTALL)
    HEADER_3 = re.compile('(?P<header>' + CHANNEL_ID + ')', re.DOTALL)
    
    def __init__(self, max_buffer = 5000000):
        """Create an empty decoder.

        max_buffer is the largest number of undecoded bytes process() will
        hold before it logs an error and resets the input buffer.
        """
        # AMF helper used by processPacket() to decode event payloads
        self.amf = amf.Amf()
        self.max_buffer = max_buffer

        # Packet type codes, grouped by how processPacket() routes them
        self.validPackets = [0x08, 0x09, 0x12, 0x14, 0x16]
        self.eventPackets = [0x12, 0x14]
        self.statusPackets = [0x04, 0x05]

        # Output lists filled by processPacket()
        self.packets = []
        self.events = []
        self.pings = []

        # Per-channel stream buffers and the one currently being read
        self.streams = {}
        self.stream = None

        # Counters and chunking state
        # NOTE(review): 3073 presumably accounts for bytes consumed before
        # the decoder sees any data (e.g. a handshake) -- confirm.
        self.bytesRead = 3073
        self.chunkSize = 128
        self.count = 0

        # Start with an empty input buffer in header-reading mode
        self.reset()
        
    def reset(self):
        "Throw away buffered input and go back to waiting for a header"
        self.mode = self.READ_HEADER
        self.left_in_chunk = 0
        self.pointer = 0
        self.data = ''
                
    def process(self, data):
        """Process raw RTMP data and parse it into AMF streams.

        The new bytes are appended to the input buffer, then headers and
        chunk bodies are consumed alternately until the buffer runs out.
        Every completed packet is handed to processPacket().

        Returns True when the data was accepted (whether or not a packet
        was completed) and False when the buffer grew past max_buffer and
        the decoder had to be reset.
        """
        # Keep track of total bytes read
        self.data += data
        self.bytesRead += len(data)

        # Check to see if data is corrupt and filling up our buffer
        if len(self.data) > self.max_buffer:
            log.msg("Error: decoder scanned %i bytes without finding a packet" % self.max_buffer)
            self.reset()
            return False

        while True:
            self.available_bytes = len(self.data) - self.pointer
            # Do we need to read the header?
            if self.mode == self.READ_HEADER:
                # Is there at least one byte to look at?
                if self.available_bytes < 1:
                    return True
                self.stream = self.readHeader()
                if not self.stream:
                    # Header is incomplete; wait for more data
                    return True
                self.pointer += self.headerLength
                self.available_bytes -= self.headerLength
                self.mode = self.READ_BODY
            # Do we need to read the body?
            if self.mode == self.READ_BODY:
                # How much is left to go?  Clamp at zero: if the stream
                # already holds more than the declared length, a negative
                # count would move the pointer backwards and corrupt the
                # input buffer.
                remaining = max(0, self.stream['length'] - len(self.stream['data']))
                bytes_to_read = min(self.chunkSize, remaining)
                # Do we have enough bytes right now?
                if self.available_bytes < bytes_to_read:
                    return True
                # Read up to an entire chunk
                self.stream['data'] += self.data[self.pointer : self.pointer + bytes_to_read]
                self.pointer += bytes_to_read
                self.available_bytes -= bytes_to_read
                remaining -= bytes_to_read
                # Is the packet complete?
                if not remaining:
                    self.processPacket()
                    # Drop the consumed bytes from the front of the buffer
                    if len(self.data) > self.pointer:
                        self.data = self.data[self.pointer:]
                        self.pointer = 0
                    self.mode = self.READ_HEADER
                # If we read a full chunk, it's time to read a 1-byte header
                if bytes_to_read == self.chunkSize:
                    self.mode = self.READ_HEADER
                # Are we out of data to work with?
                if not self.available_bytes:
                    self.data = ''
                    self.pointer = 0
                    # Report success like every other non-error exit
                    return True
                                
    def readHeader(self):
        """Parse the chunk header at self.pointer and return its stream buffer.

        The top two bits of the first byte select which HEADER_<n> pattern
        applies; the low six bits carry the channel id, where 0 and 1
        announce the 2- and 3-byte channel id forms.  Fields present in the
        header are copied into the channel's stream dict, which is created
        on first use.  self.headerLength is set to the bytes consumed.

        Returns the stream dict, or False when the buffered data does not
        yet hold a complete header.
        """
        # TODO: Process long timecodes on type 3 headers

        # 18 bytes is the longest header the patterns can match
        window = self.data[self.pointer : self.pointer + 18]
        if not window:
            return False
        first = ord(window[0])
        bits = first >> 6
        low = first & 0x3F

        # The patterns detect the extended channel id forms by comparing
        # the whole first byte with \x00 / \x01, which is only true for
        # header type 0.  Blank out the type bits before matching so the
        # extended forms are recognised for every header type.
        header_re = (self.HEADER_0, self.HEADER_1, self.HEADER_2, self.HEADER_3)[bits]
        result = header_re.match(chr(low) + window[1:])
        if not result:
            return False

        d = result.groupdict()
        raw_id = d['channel_id']
        if low < 2 and len(raw_id) == 1:
            # An extended channel id was announced but the pattern fell
            # back to the 1-byte alternative because too few bytes are
            # buffered; wait for more data instead of misparsing.
            return False
        if len(raw_id) == 2:
            # Extended ids keep their original str key type
            channel_id = str(ord(raw_id[1]) + 64)
        elif len(raw_id) == 3:
            # RTMP defines this form as third byte * 256 + second byte + 64
            channel_id = str(ord(raw_id[1]) + (ord(raw_id[2]) << 8) + 64)
        else:
            channel_id = low

        # Look up channel
        if channel_id in self.streams:
            stream = self.streams[channel_id]
        else:
            stream = self.createStream(channel_id)

        # Process header
        if 'packet_timecode' in d:
            # Use the 4-byte extended timecode when present, otherwise the
            # 3-byte one padded to 4 bytes
            timecode = d['long_tc'] or (chr(0) + d['packet_timecode'])
            stream['timecode'] = struct.unpack("!I", timecode)[0]
            stream['absflag'] = False
        if 'packet_type' in d:
            stream['type'] = ord(d['packet_type'])
        if 'packet_size' in d:
            stream['length'] = struct.unpack("!I", chr(0) + d['packet_size'])[0]
        if 'stream_id' in d:
            # Only type 0 headers carry a stream id; they are flagged absolute
            stream['id'] = struct.unpack("<I", d['stream_id'])[0]
            stream['absflag'] = True
        self.headerLength = len(d['header'])
        return stream
        
    def processPacket(self):
        "Route the finished packet in self.stream to the right output list"
        stream = self.stream
        kind = stream['type']
        payload = stream['data']

        if kind == 0x01:
            # Chunk size change; payloads shorter than 4 bytes are ignored
            if len(payload) >= 4:
                newSize = self.readChunkSize(payload)
                if 64 <= newSize <= 65536:
                    self.chunkSize = newSize
                    log.msg("Chunk size set to %i" % newSize)
                else:
                    log.msg("Error: decoder found corrupt data (tried to set chunk size to %s)" % newSize)
        elif kind in self.validPackets:
            # Useful packets: snapshot the stream state with a running count
            packet = {'count' : self.count}
            packet.update(stream)
            packet['key'] = (packet['type'] == 0x09 and self.isKeyframe(packet))
            self.count += 1
            if kind in self.eventPackets:
                # Event packet: decode the AMF payload
                self.events.append({'id' : stream['id'], 'data' : self.amf.decode(stream['data'])})
            else:
                # Regular A/V packet
                self.packets.append(packet)
        elif kind in self.statusPackets:
            # Ping and other status packets are passed along undecoded
            self.pings.append({'type' : kind, 'data' : payload})
        # Any other packet type is dropped

        # Clean up for next packet
        stream['data'] = ''
            
    def createStream(self, objectId):
        "Build an empty stream buffer, register it under objectId and return it"
        fresh = dict(
            data='',
            type=0,
            length=0,
            absflag=False,
            timecode=0,
            id=0,
            object=int(objectId),
        )
        self.streams[objectId] = fresh
        return fresh
                        
    def readChunkSize(self, data):
        "Decode the first four bytes of data as a signed big-endian chunk size"
        (new_size,) = struct.unpack("!i", data[:4])
        return new_size

    def isKeyframe(self, tag):
        "Boolean function: is this tag a non-empty video tag holding a keyframe?"
        if not (len(tag['data']) > 0):
            return False
        if not (tag['type'] == 0x09):
            return False
        return self.getVideoFrame(tag) == 'keyframe'

    def getVideoFrame(self, tag):
        """Classify the video frame held in tag['data'].

        The frame type is the upper four bits of the first data byte.
        Returns 'keyframe' (1), 'iframe' (2), 'unsupported' (3), or 'err'
        for any other value and for empty data.
        """
        # Slice rather than index: an empty payload yields an empty slice
        # instead of an IndexError, and ord() accepts a one-element slice
        # of either a text or a bytes string.
        lead = tag['data'][:1]
        if not lead:
            return 'err'
        frame = (ord(lead) & 0xF0) >> 4
        names = {1: 'keyframe', 2: 'iframe', 3: 'unsupported'}
        return names.get(frame, 'err')
