import struct
import flv
import circularbuffer as cb
import sys
from time import gmtime, strftime

def has_cts(tag):
    """Return True when `tag` is an AVC video tag flagged as carrying AVC data.

    Only such tags are treated by this module as having a 'cts' entry.
    """
    if not (tag['type'] == flv.FLV_VIDEO_TAG):
        return False
    if not (flv.getVideoCodec(tag) == 'avc'):
        return False
    # Deliberately an equality test against True (not plain truthiness).
    return tag['has_avc_data'] == True

# TODO: New, updated serialization code.  Should likely replace writeTag in flv.py
def serialize_avc_video_tag(tag):
    """Return the payload of an AVC video tag with its CTS field rewritten.

    The first two bytes of tag['data'] (FrameType/CodecID and AVCPacketType)
    and everything from byte 5 onwards are copied unchanged; bytes 2-4 are
    replaced by tag['cts'] encoded as a big-endian 24-bit value.

    The CTS is masked to 24 bits before packing, so a negative value is
    written in two's complement form instead of making struct.pack("!I")
    raise struct.error.  Non-negative values are encoded exactly as before.
    """
    payload = tag['data']

    # Pack as 32-bit big-endian and drop the most significant byte.
    cts_bytes = struct.pack("!I", int(tag['cts']) & 0xFFFFFF)[1:4]

    return payload[0:2] + cts_bytes + payload[5:]

#x = 0
def serialize_tag(tag):
    '''Serialize one tag: 11-byte header, payload, then the trailing size.

    The header is built from the tag's current 'type', 'time_stamp' and
    'data' entries; AVC video tags get their payload re-serialized so that
    the current 'cts' value is written out.
    '''

    # Debug aid (disabled):
    #    global x
    #    print '#%s %s: dts %s, len %s' % (x, tag['type'], tag['time_stamp'], len(tag['data']))
    #    x += 1

    # Tag type followed by the 24-bit payload length (header not included).
    # TODO: note that we are writing tag['type'] as a whole byte rather than
    # 5 bits, so it also covers the reserved and filter bits, assuming
    # they are zero.  this is wrong!
    size_field = struct.pack("!I", len(tag['data']))[1:]
    header = chr(tag['type']) + size_field

    # DTS: the lower 24 bits come first, the upper 8 bits follow.
    packed_dts = struct.pack("!I", int(tag['time_stamp']))
    header += str(packed_dts[1:])
    header += str(packed_dts[0])

    # StreamID: always three zero bytes.
    header += chr(0) * 3

    # Payload; AVC video tags need their CTS field rewritten.
    payload = serialize_avc_video_tag(tag) if has_cts(tag) else tag['data']

    # Mandatory trailing size: payload length plus the 11 header bytes.
    trailer = struct.pack("!l", int(len(payload) + 11))

    return header + payload + trailer

class TSNormalizationChain(object):
    """Runs tags through probe, input, chain and output filters.

    Tags pass through a fixed-size circular buffer.  Input filters see each
    tag right before it enters the buffer, chain filters see the whole buffer
    after every insertion, and output filters see each tag right before it
    leaves the chain.  Probe filters see the first buffer-full of tags before
    any of them is processed.
    """

    def __init__(self, tag_buffer_length = 60):
        self.probe_filters = list()
        self.input_filters = list()
        self.output_filters = list()
        self.chain_filters = list()

        # Tags consumed while probing, waiting to be fed into the buffer.
        self.probe_buf = list()
        self.tag_buf = cb.CB(tag_buffer_length)

        # Total number of tags handed to handle_tag().
        self.num_tags = 0

    def push_probe_filter(self, ts_filter):
        """Register a filter that inspects the first buffer-full of tags."""
        self.probe_filters.append(ts_filter)

    def push_input_filter(self, ts_filter):
        """Register a filter applied to each tag before it is buffered."""
        self.input_filters.append(ts_filter)

    def push_output_filter(self, ts_filter):
        """Register a filter applied to each tag before it is emitted."""
        self.output_filters.append(ts_filter)

    def push_chain_filter(self, ts_filter):
        """Register a filter applied to the buffer after each insertion."""
        self.chain_filters.append(ts_filter)

    def probe(self, tag_iterable):
        """Feed up to one buffer-full of tags to the probe filters.

        `tag_iterable` must be an iterator; the tags consumed from it are
        kept in self.probe_buf.  Does nothing when no probe filter is
        registered.
        """
        if not self.probe_filters:
            return

        for _ in range(self.tag_buf.buffer_size()):
            try:
                tag = next(tag_iterable)
            except StopIteration:
                # A stream shorter than the buffer must not leak
                # StopIteration into normalize(): that would end the
                # generator (or raise RuntimeError under PEP 479) before
                # the probed tags are emitted.
                break
            self.probe_buf.append(tag)
            for filt in self.probe_filters:
                filt.perform(tag)

    def normalize(self, tag_iterable):
        """Generator yielding every tag of `tag_iterable` after filtering."""
        # Work on a single iterator so that probing and the main loop share
        # one position, even when a plain sequence is passed in.
        tag_iter = iter(tag_iterable)
        self.probe(tag_iter)

        # add probed tags to circular buffer
        for new_tag in self.probe_buf:
            tail_tag = self.handle_tag(new_tag)
            if tail_tag is not None:
                yield tail_tag
        self.probe_buf = list()

        # iterate over the rest of the stream
        for new_tag in tag_iter:
            tail_tag = self.handle_tag(new_tag)
            if tail_tag is not None:
                yield tail_tag

        # Flush the circular buffer.  These tags were never evicted, so the
        # output filters have not seen them yet.
        for idx in range(len(self.tag_buf)):
            tail_tag = self.tag_buf[idx]
            for filt in self.output_filters:
                filt.perform(tail_tag)
            yield tail_tag

    def handle_tag(self, new_tag):
        """Insert `new_tag` into the buffer.

        Returns the tag that is about to be evicted (after running the
        output filters on it), or None while the buffer is not yet full.
        """
        self.num_tags += 1

        tail_tag = None
        if self.tag_buf.filled():
            tail_tag = self.tag_buf.tail()
            # run output filters on to be evicted tag
            for filt in self.output_filters:
                filt.perform(tail_tag)

        # run input filters on new tag
        for filt in self.input_filters:
            filt.perform(new_tag)

        # insert new tag in circular buffer
        self.tag_buf.push(new_tag)

        # run chain filters on updated circular buffer
        for filt in self.chain_filters:
            filt.perform(self.tag_buf)

        return tail_tag


#############################################################################
# FILTERS
#
# Four types of filters: probe, input, output, and chain.  Input and output
# filters deal with single frames, right after they are read, or immediately
# before they are serialized.  Chain filters have access to the shared
# circular buffer of tags, allowing them to make more complex calculations,
# reorder, or even delete tags.  Probe filters only run on the beginning of
# the file, until the TSNormalizationChain's circular buffer is full.


# Probe/input filter to detect the needed CTS shift for a file.  Can be used
# in probing mode (max_tag_seq >= 0) or continuous/input mode (max_tag_seq =
# CONTINUOUS_SEEK).  Note that if you want to use it as a probing filter,
# max_tag_seq should be less than or equal to the TSNormalizationChain's
# circular buffer size, else the probe will not complete before
# serialization begins.
class DetectCTSShiftFilter(object):
    """Tracks the largest negative CTS seen, i.e. the shift needed to fix it.

    With max_tag_seq set to CONTINUOUS_SEEK every CTS-bearing tag is
    inspected; otherwise inspection stops once more than max_tag_seq such
    tags have been seen.
    """
    CONTINUOUS_SEEK = -1

    def __init__(self, max_tag_seq = 60):
        self.max_shift = 0
        self.max_tag_seq = max_tag_seq
        # Number of CTS-bearing tags inspected so far.
        self.tags_seen = 0

    def perform(self, tag):
        """Update the detected shift from `tag` if it carries a CTS."""
        if not has_cts(tag):
            return

        probing = self.max_tag_seq != DetectCTSShiftFilter.CONTINUOUS_SEEK
        # NOTE(review): the comparison is strict, so up to max_tag_seq + 1
        # CTS-bearing tags are inspected in probing mode.
        if probing and self.tags_seen > self.max_tag_seq:
            return

        # Count the inspected tag; the configured limit itself must stay
        # untouched, otherwise the probing window never closes.
        self.tags_seen += 1

        cts = tag['cts']
        if cts < 0 and abs(cts) > self.max_shift:
            self.max_shift = abs(cts)
            sys.stderr.write("New shift amnt %s at DTS %s.\n" % (self.max_shift, tag['time_stamp']))

    def shift_amount(self):
        """Return the largest absolute negative CTS observed so far."""
        return self.max_shift

class DummyShiftFilter(object):
    """Shift source that always reports a fixed, caller-supplied amount."""

    def __init__(self, shift_amnt):
        self.shift = shift_amnt

    def perform(self, tag=None):
        """No-op.

        Accepts (and ignores) a tag so the object can be registered wherever
        a filter is called as perform(tag); calling it without an argument
        keeps working.
        """
        pass

    def shift_amount(self):
        """Return the fixed shift amount given at construction."""
        return self.shift

# Input/output filter that applies a CTS shift.  shift_obj must handle the 
# method .shift_amount which returns an integer.  Only applies to AVC video
# tags
class ShiftCTSFilter(object):
    """Adds the amount reported by `shift_obj` to the CTS of AVC video tags."""

    def __init__(self, shift_obj):
        # Any object exposing shift_amount().
        self.shift_obj = shift_obj

    def perform(self, tag):
        """Shift tag['cts'] in place; tags without a CTS are left alone."""
        if has_cts(tag):
            shifted = tag['cts'] + int(self.shift_obj.shift_amount())
            if shifted < 0:
                # Even after shifting the CTS is negative: warn and clamp.
                sys.stderr.write("Shift insufficient at DTS %s; CTS negative.\n" % tag['time_stamp'])
                shifted = 0
            tag['cts'] = shifted

# Input/output filter that applies a DTS shift to tags
class ShiftDTSFilter(object):
    """Adds the amount reported by `shift_obj` to a tag's DTS ('time_stamp').

    When `tag_type` is ANY_TAG every tag is shifted; otherwise only tags
    whose 'type' equals `tag_type`.
    """
    ANY_TAG = -1

    def __init__(self, shift_obj, tag_type = ANY_TAG):
        self.tag_type = tag_type
        self.shift_obj = shift_obj

    def perform(self, tag):
        """Shift tag['time_stamp'] in place if the tag type matches."""
        wanted = self.tag_type
        if wanted != ShiftDTSFilter.ANY_TAG and tag['type'] != wanted:
            return

        delta = self.shift_obj.shift_amount()
        tag['time_stamp'] += delta

# Chain filter that reorders tags by DTS
class OrderByDTSFilter(object):
    """Chain filter that moves the newest tag to its DTS-ordered position."""

    def __init__(self):
        # Running statistics.
        self.largest_front_insert = 0
        self.front_inserts = 0
        self.reorders = 0

    def perform(self, tag_buf):
        """Insert the buffer's head tag where its 'time_stamp' belongs.

        One step of an insertion sort: cheap while the stream is already in
        order, because no buffered tag then has a larger DTS than the head.
        """
        if tag_buf.num_pushed() < 2:
            return

        newest = tag_buf.head()
        for pos, queued in enumerate(tag_buf):
            if not (queued['time_stamp'] > newest['time_stamp']):
                continue

            if pos == 0:
                self.handle_front_insert(newest, queued)

            # Drop the newest tag into this slot and push every later tag
            # one position further along.
            carried = newest
            for slot in xrange(pos, len(tag_buf)):
                displaced = tag_buf[slot]
                tag_buf[slot] = carried
                carried = displaced
                self.reorders += 1

            break

    # XXX: we should drop instead!
    def handle_front_insert(self, new_tag, old_tag):
        """Record and report an insertion at the very front of the buffer."""
        # We are at the beginning of the tag window - the window may be too
        # small to reorder far enough!
        self.front_inserts += 1
        ts_jump_size = old_tag['time_stamp'] - new_tag['time_stamp']
        if ts_jump_size > self.largest_front_insert:
            self.largest_front_insert = ts_jump_size

        sys.stderr.write("Inserting into front of cb! Cur DTS is %s; front of cb DTS is %s (%s).\n" % (new_tag['time_stamp'], old_tag['time_stamp'], ts_jump_size))

        # Disabled: protect ourselves from a massive ts jump in the stream;
        # one out-of-order frame is not the end of the world.
        #new_tag['time_stamp'] = old_tag['time_stamp'] - 1
        #if has_cts(new_tag):
        #    new_tag['cts'] = 0

# Input/output filter to detect the maximum dts jump
class GapDetector(object):
    """Tracks the largest DTS increase between consecutive tags.

    `stage_name` is only used to label the messages written to stderr.
    """

    def __init__(self, stage_name=''):
        self.name = stage_name
        self.last_tag_dts = 0
        self.max_gap = 0
        self.max_gap_dts = 0

    def perform(self, tag):
        """Compare the tag's DTS with the previous one and log a new record."""
        dts = tag['time_stamp']
        gap = dts - self.last_tag_dts

        if gap > self.max_gap:
            # Wall-clock time (UTC) at which the record gap was noticed.
            now = strftime("%Y-%m-%d %H:%M:%S", gmtime())
            sys.stderr.write("%s: new maximum ts gap at DTS %s of %s. %s.\n" % (self.name, dts, gap, now))
            self.max_gap, self.max_gap_dts = gap, dts

        self.last_tag_dts = dts

# debug util
def print_hex(data):
    """Debug helper: print `data` as one hex literal, two digits per character."""
    digits = "".join("%02x" % ord(ch) for ch in data)
    print("0x" + digits)

