"""

    File: flv.py
    Description: 
    
        This module contains information and functions needed to write or stream FLV files.

    Author: Kyle Vogt
    Date  : June 12th, 2007
    Copyright (c) 2007, Justin.tv, Inc.
    
"""

import time, struct, os, sys, random, traceback, tempfile, fcntl
from array import array
import traceback
import StringIO
import subprocess
import circularbuffer as cb
import math
import errno

FLV_AUDIO_TAG = 8
FLV_VIDEO_TAG = 9
FLV_META_TAG  = 18

def _safe_read(file_object, size):
    "Read up to `size` bytes from file_object, retrying when interrupted (EINTR)"
    while True:
        try:
            chunk = file_object.read(size)
        except IOError as err:
            # An interrupted system call is not a real failure: try again.
            if err.errno == errno.EINTR:
                continue
            raise
        return chunk

# This function is provided so that EINTR errors can later be handled for
# write operations too. Existing write call sites are deliberately left
# unchanged for now, to avoid breaking code unrelated to the read path.
def _safe_write(file_object, data):
    "Write `data` to file_object, retrying when interrupted (EINTR)"
    while True:
        try:
            result = file_object.write(data)
        except IOError as err:
            # An interrupted system call is not a real failure: try again.
            if err.errno == errno.EINTR:
                continue
            raise
        return result

def writeHeader():
    "Return the 13 bytes that start every FLV stream produced by this module"
    # 9 byte FLV file header ('FLV', version byte, flags byte, 4 byte header
    # length) immediately followed by the first PreviousTagSize field, which
    # is always 0.
    return 'FLV\x01\x05\x00\x00\x00\t\x00\x00\x00\x00'
        
def thumb(input_file_path, output_file_path, offset = 2.0, **kwargs):
    """
        Use ffmpeg to generate a thumbnail preview from the given keyframe.

        input_file_path  -- video file handed to ffmpeg with -i
        output_file_path -- image file ffmpeg is asked to write
        offset           -- seek position passed to ffmpeg's -ss option
        kwargs           -- forwarded unchanged to subprocess.call (e.g.
                            stdout / stderr / cwd)

        Raises AssertionError("error on generate thumbnail") when ffmpeg
        cannot be started, exits with a non-zero status, or the output file
        is not readable afterwards.

        The command is run as an argument list without a shell, so the paths
        are passed to ffmpeg verbatim: spaces and shell metacharacters are
        safe, and no shell expansion (globs, ~, quoting) is applied to them.
    """
    command = [
        "ffmpeg", "-an",
        "-ss", "%s" % (offset,),
        "-i", input_file_path,
        "-vframes", "1",
        "-f", "image2",
        output_file_path,
    ]
    # The call must live outside of an assert statement: asserts are removed
    # under `python -O`, which would silently skip running ffmpeg at all.
    try:
        status = subprocess.call(command, shell=False, **kwargs)
    except OSError:
        # ffmpeg could not be started (e.g. it is not installed); report it
        # through the same exception callers already get for a failed run.
        status = None
    if status != 0 or not os.access(output_file_path, os.R_OK):
        raise AssertionError("error on generate thumbnail")
    
def writeTag(tag, timecode):
    """
        Convert a tag into an FLV tag and return the serialized bytes.

        tag      -- dict with 'type' (tag type byte, 0..255) and 'data'
                    (the tag payload as a byte string)
        timecode -- timestamp to store in the tag; negative values are
                    clamped to 0 and the value is truncated with int()

        The result is the 11 byte tag header, the payload, and the trailing
        4 byte PreviousTagSize field (payload length + 11).

        Raises ValueError when the payload does not fit the 24 bit size
        field of the tag header, or when 'type' is not in 0..255.
    """
    payload = tag['data']
    size = len(payload)
    # The size field is only three bytes wide.  Dropping the top byte of a
    # larger length would silently produce a corrupt stream, so refuse.
    if size > 0xFFFFFF:
        raise ValueError("FLV tag data too large: %d bytes (max %d)" % (size, 0xFFFFFF))
    if timecode < 0: timecode = 0
    tstamp = struct.pack("!I", int(timecode))
    return b''.join((
        # single tag type byte; out of range values raise ValueError
        bytes(bytearray((tag['type'],))),
        # 24 bit big-endian payload length
        struct.pack("!I", size)[1:],
        # lower 24 bits of the timestamp, then the upper 8 bits (extension byte)
        tstamp[1:],
        tstamp[:1],
        # three zero bytes (stream id)
        b'\x00' * 3,
        payload,
        # PreviousTagSize: header (11 bytes) + payload
        struct.pack("!l", size + 11),
    ))
        
def writeMetaTag(metadata):
    """
        Wrap raw metadata in an FLV tag and return it.

        metadata -- already serialized metadata payload (byte string)

        The tag is written with type 0x12, a zero timestamp and a zero
        stream id, followed by the 4 byte PreviousTagSize field.

        Raises ValueError when the payload does not fit the 24 bit size
        field of the tag header.
    """
    size = len(metadata)
    # The size field is only three bytes wide.  Dropping the top byte of a
    # larger length would silently produce a corrupt stream, so refuse.
    if size > 0xFFFFFF:
        raise ValueError("FLV tag data too large: %d bytes (max %d)" % (size, 0xFFFFFF))
    return b''.join((
        b'\x12',                        # tag type
        struct.pack("!I", size)[1:],    # 24 bit big-endian payload length
        b'\x00' * 4,                    # timestamp + timestamp extension
        b'\x00' * 3,                    # stream id
        metadata,
        struct.pack("!l", size + 11),   # PreviousTagSize: header + payload
    ))
    
def readTag(data, offset = 0, add_avc_data = False):
    """
        Read an FLV tag from a data buffer.

        data         -- byte string holding FLV tag data
        offset       -- index in `data` where the tag header starts
        add_avc_data -- when true, AVC video tags are additionally passed
                        through addAVCTagData()

        Returns a dict with the keys 'data', 'type', 'time_stamp' and
        'has_avc_data' (plus whatever addAVCTagData() adds), or False when
        the offset is negative, the buffer does not yet hold the complete
        tag, or parsing fails.  Only the 11 byte header and the payload have
        to be present; the 4 byte PreviousTagSize that follows a tag is not
        required and is not consumed here.
    """
    try:
        if offset < 0: return False
        # Compare lengths instead of slicing so that the (possibly large)
        # remainder of the buffer is not copied on every call.
        available = len(data) - offset
        if available < 8: return False

        # mask off the reserved and filter bits from bytes[0]; last five bits contain tagtype
        tag_type = ord(data[offset: offset + 1]) & 31
        # read the length of the packet, 0x0 + bytes[1..3]
        length = struct.unpack("!I", b'\x00' + data[offset + 1: offset + 4])[0]
        # grab the timestamp, bytes[7] + bytes[4..6]
        time_stamp = struct.unpack("!I", data[offset + 7: offset + 8] + data[offset + 4: offset + 7])[0]

        # skip bytes[8..10] as they are always zero (StreamID); consumed total of 11 bytes
        # next is the Audio/VideoTagHead, followed by optional EncryptionHeader and FilterParams
        if available < 11 + length:
            return False

        tag_data = data[offset + 11 : offset + 11 + length]
        tag = {'data' : tag_data, 'type' : tag_type, 'time_stamp' : time_stamp}
        if add_avc_data and tag['type'] == FLV_VIDEO_TAG and getVideoCodec(tag) == 'avc':
            addAVCTagData(tag)
        else:
            tag['has_avc_data'] = False
    except Exception:
        # Malformed input is reported as "no tag"; KeyboardInterrupt and
        # SystemExit are deliberately not swallowed.
        print('Error reading FLV tag')
        return False

    return tag

def readFile(file_object, chunk_size = 8192, avc_info = False):
    """
        Generator yielding every tag of an FLV stream read from file_object.

        file_object -- object with a read(size) method positioned at the
                       start of the FLV data
        chunk_size  -- number of bytes requested per read
        avc_info    -- passed to readTag() as add_avc_data

        Each yielded tag is the dict built by readTag() with an additional
        'byte_offset' key: the position of the tag in the stream, counting
        the 13 byte header.  Prints 'Bad FLV header!' and yields nothing when
        the stream does not start with 'FLV'.  Iteration ends when read()
        returns no more data.
    """
    if _safe_read(file_object, 13)[:3] != 'FLV':
        print('Bad FLV header!')
        return
    data = _safe_read(file_object, chunk_size)
    #pos in data array
    pointer = 0
    #number of bytes consumed so far (excluding the 13 byte header)
    counter = 0
    while True:
        tag = readTag(data, pointer, add_avc_data = avc_info)
        if tag:
            tag['byte_offset'] = 13 + counter
            # 11 byte tag header + payload + 4 byte PreviousTagSize
            step = 15 + len(tag['data'])
            counter += step
            pointer += step
            yield tag
            continue

        newdata = _safe_read(file_object, chunk_size)
        if not newdata:
            return
        # readTag() accepts a tag before its trailing PreviousTagSize has
        # arrived, so `pointer` may be past the end of the buffer.  Only drop
        # what is really there and keep the remainder of the pointer, so the
        # missing bytes are skipped once they show up instead of being
        # parsed as the start of the next tag.
        consumed = min(pointer, len(data))
        data = data[consumed:] + newdata
        pointer -= consumed
                
                
def readFileNoBlock(fd, chunk_size = 8192, timeout=5):
    """
        Generator yielding the tags of an FLV stream read from a
        non-blocking file descriptor.

        fd         -- OS level file descriptor; O_NONBLOCK is set on it
        chunk_size -- number of bytes requested per os.read()
        timeout    -- seconds of consecutive empty reads after which the
                      generator stops

        Yields a tag dict (see readTag(), plus 'byte_offset') whenever a
        complete tag is available, and False whenever the caller should try
        again later.  Prints a message and stops when the stream does not
        start with 'FLV'.

        Note that the timeout only measures reads that return no data; an
        OSError raised by os.read() yields False without starting or
        checking the timer.
    """
    #set O_NONBLOCK
    fcntl.fcntl(fd, fcntl.F_SETFL, os.O_NONBLOCK | fcntl.fcntl(fd, fcntl.F_GETFL))

    #spin on header read
    header = ""
    no_data_time = None
    while len(header) < 13:
        try:
            d = os.read(fd, 13 - len(header))
        except OSError:
            yield False
            continue
        if d:
            no_data_time = None
            header += d
            if len(header) < 13:
                yield False
            continue
        # empty read: start the timer, or give up once it has expired
        if no_data_time is None:
            no_data_time = time.time()
        elif time.time() - no_data_time > timeout:
            # plain return ends the generator (raising StopIteration inside
            # a generator is an error on newer Python versions)
            return
        yield False

    if header[:3] != 'FLV':
        print('Bad FLV header:%s'%header)
        return

    data = ""
    #pos in data array
    pointer = 0
    #number of bytes consumed so far (excluding the 13 byte header)
    counter = 0
    while True:
        tag = readTag(data, pointer)
        if tag:
            tag['byte_offset'] = 13 + counter
            # 11 byte tag header + payload + 4 byte PreviousTagSize
            step = 15 + len(tag['data'])
            counter += step
            pointer += step
            yield tag
            continue

        try:
            newdata = os.read(fd, chunk_size)
        except OSError:
            yield False
            continue

        if newdata:
            no_data_time = None
            # readTag() accepts a tag before its trailing PreviousTagSize
            # has arrived, so `pointer` may be past the end of the buffer.
            # Short reads are normal here, so only drop what is really
            # there and keep the remainder of the pointer; otherwise the
            # late bytes would be parsed as the start of the next tag.
            consumed = min(pointer, len(data))
            data = data[consumed:] + newdata
            pointer -= consumed
            continue

        # empty read: start the timer, or give up once it has expired
        if no_data_time is None:
            no_data_time = time.time()
        elif time.time() - no_data_time > timeout:
            return
        yield False
                

def isAudioConfig(tag):
    "Returns whether the tag is an AAC audio config tag (second data byte is 0x00) or not"
    # Both of the first two payload bytes are inspected, so shorter (empty
    # or truncated) audio tags are reported as "not a config tag" instead of
    # raising IndexError.  8 is the audio tag type.
    return (tag['type'] == 8
            and len(tag['data']) > 1
            and getAudioCodec(tag) == 'aac'
            and tag['data'][1] == '\x00')

def isKeyframe(tag):
    "Boolean function to determine whether the given FLV tag is a video keyframe"
    # Empty tags carry no frame header byte to inspect.
    if len(tag['data']) == 0:
        return False
    # Only video tags (type 0x09) can be keyframes.
    if tag['type'] != 0x09:
        return False
    return getVideoFrame(tag) == 'keyframe'
    
def isVideoConfig(tag):
    "Returns whether the tag is an AVC video config tag (second data byte is 0x00) or not"
    # Both of the first two payload bytes are inspected, so shorter (empty
    # or truncated) video tags are reported as "not a config tag" instead of
    # raising IndexError.  9 is the video tag type.
    return (tag['type'] == 9
            and len(tag['data']) > 1
            and getVideoCodec(tag) == 'avc'
            and tag['data'][1] == '\x00')

def getVideoCodec(tag):
    "Returns a short name for the video codec of the tag, or 'err' if unknown"
    # The codec id lives in the low four bits of the first payload byte.
    codec_id = ord(tag['data'][0]) & 0x0F
    codec_names = {
        1: "jpeg",
        2: "sor",
        3: "screen",
        4: "vp6",
        5: "vp6",
        6: "screen",
        7: "avc",
    }
    return codec_names.get(codec_id, "err")
    
def getVideoFrame(tag):
    "Returns the frame kind of the tag: 'keyframe', 'interframe', 'info' or 'err'"
    # The frame type lives in the high four bits of the first payload byte.
    frame_id = (ord(tag['data'][0]) & 0xF0) >> 4
    frame_names = {
        1: 'keyframe',
        2: 'interframe',
        3: 'interframe',
        4: 'keyframe',
        5: 'info',
    }
    return frame_names.get(frame_id, 'err')
    
def getAudioCodec(tag):
    "Returns a short name for the audio format of the tag, or 'err' if unknown"
    # The sound format lives in the high four bits of the first payload byte.
    format_id = (ord(tag['data'][0]) & 0xF0) >> 4
    format_names = {
        0: 'uncompressed',
        1: 'adpcm',
        2: 'mp3',
        3: 'uncompressed_lendian',
        4: 'nellymoser_16kz_mono',
        5: 'nellymoser_8kz_mono',
        6: 'nellymoser',
        10: 'aac',
    }
    return format_names.get(format_id, 'err')

def getTagType(tag):
    "Returns a human readable name for the type of the tag, or 'err' if unknown"
    known_types = (
        (FLV_AUDIO_TAG, "Audio Tag"),
        (FLV_VIDEO_TAG, "Video Tag"),
        (FLV_META_TAG,  "Meta Tag"),
    )
    for type_id, label in known_types:
        if tag['type'] == type_id:
            return label
    return "err"

def addAVCTagData(tag):
    """
        Annotate an AVC video tag in place with its composition time.

        On success the tag gains 'avc_pkt_type' (the result of
        getVideoFrame()), 'cts' (signed composition time read from payload
        bytes 2..4), 'pts' (cts + time_stamp) and 'has_avc_data' = True.
        For non-AVC tags, or when the payload cannot be parsed, a message is
        printed and 'has_avc_data' is set to False.
    """
    if getVideoCodec(tag) != "avc":
        print("Non-AVC stream does not have composition times.")
        tag['has_avc_data'] = False
        return

    try:
        tag['avc_pkt_type'] = getVideoFrame(tag)
        payload = tag['data']
        # The composition time is a 24 bit big-endian value.  It is padded
        # to 32 bits before unpacking: with 0xFF when its top bit is set
        # (sign extension of a negative value), with 0x00 otherwise.
        is_negative = (ord(payload[2]) & 0x80) != 0
        if is_negative:
            cts = struct.unpack("!i", chr(0xFF) + payload[2:5])[0]
        else:
            cts = struct.unpack("!I", chr(0x00) + payload[2:5])[0]
        tag['cts'] = cts
        # the tag's time_stamp is the dts
        tag['pts'] = tag['cts'] + tag['time_stamp']
        tag['has_avc_data'] = True
    except Exception as e:
        print("Error reading composition time from AVC stream, dts %s: %s" % (tag['time_stamp'], e))
        tag['has_avc_data'] = False
        tag['cts'] = None
        tag['avc_pkt_type'] = None
    
def concatinate_flv_files(in_list, out_file):
    """
        Concatenate several FLV streams into one.

        in_list  -- iterable of file objects, each passed to readFile()
        out_file -- object with a write() method receiving the result

        A single FLV header is written, followed by the tags of every input
        in order.  Only the first video config tag and the first audio
        config tag over all inputs are kept.  The timestamps of each input
        are shifted so that it starts 10 after the last tag written from the
        inputs before it.  Inputs that contribute no tags leave the time
        offset untouched.
    """
    video_config_written = False
    audio_config_written = False

    out_buff = StringIO.StringIO()
    out_buff_c = 0
    time_offset = 0
    out_buff.write(writeHeader())
    for f in in_list:
        # original (unshifted) timestamp of the last tag written from this
        # input; stays None when the input contributes nothing
        last_stamp = None
        for t in readFile(f):
            if isVideoConfig(t):
                if video_config_written:
                    continue
                video_config_written = True
            if isAudioConfig(t):
                if audio_config_written:
                    continue
                audio_config_written = True

            last_stamp = t['time_stamp']
            out_buff.write(writeTag(t, t['time_stamp'] + time_offset))
            out_buff_c += 1
            # hand the buffered tags over in batches of a bit over 1024
            if out_buff_c > 1024:
                out_file.write(out_buff.getvalue())
                out_buff_c = 0
                out_buff = StringIO.StringIO()
        if last_stamp is not None:
            time_offset += last_stamp + 10

    out_file.write(out_buff.getvalue())
          
                      
def repair_time_stamps(in_file, out_file, sort = False, sort_buffer_length = 100, shift = False, max_delta = 1500, min_delta = -1500, new_delta = 1, out_buffer_length = 1000):
    """
        Copy an FLV stream from in_file to out_file, optionally re-ordering
        tags by timestamp and/or removing timestamp jumps.

        in_file            -- file object passed to readFile()
        out_file           -- object with a write() method
        sort               -- when true, tags pass through a sorted window
                              and are written in timestamp order
        sort_buffer_length -- number of tags kept in the sorted window
        shift              -- when true, the first tag is moved to time 0
                              and jumps between neighbouring tags are
                              removed (see max_delta / min_delta)
        max_delta          -- a timestamp more than this much after the
                              previous tag counts as a jump
        min_delta          -- a timestamp at or below this much relative to
                              the previous tag counts as a jump
        new_delta          -- distance to the previous tag that replaces a
                              jump
        out_buffer_length  -- number of writes collected before they are
                              handed to out_file; 0 or less disables
                              buffering

        An input without any tags produces an output holding only the FLV
        header.
    """
    #dict is hack to expose vars to nested scopes
    l = {}
    l['out_buff'] = StringIO.StringIO()
    l['out_buff_c']= 0
    l['sort_buff'] = []

    def write_buffered(data):
        "Write data to out_file, through the output buffer if enabled"
        if out_buffer_length > 0:
            l['out_buff'].write(data)
            l['out_buff_c'] += 1
            if l['out_buff_c'] > out_buffer_length:
                out_file.write(l['out_buff'].getvalue())
                l['out_buff_c'] = 0
                l['out_buff'] = StringIO.StringIO()
        else:
            out_file.write(data)

    def write_tag(tag):
        "Emit a tag (via the sorted window if sorting) and record l['p_tag']"
        if sort:
            #insert tag into sorted buffer
            if not l['sort_buff'] or tag['time_stamp'] >= l['sort_buff'][-1]['time_stamp']:
                l['sort_buff'].append(tag)
            elif tag['time_stamp'] < l['sort_buff'][0]['time_stamp']:
                l['sort_buff'].insert(0, tag)
            else:
                # binary search for the insert position
                s = 0
                e = len(l['sort_buff']) -1

                while e-s > 1:
                    m = int((s + e)/2)
                    if tag['time_stamp'] > l['sort_buff'][m]['time_stamp']:
                        s = m
                    else:
                        e = m
                l['sort_buff'].insert(e, tag)
            # The reference for the next delta check is the newest tag in the
            # window.  Take it before evicting, so that a window that drains
            # completely (sort_buffer_length of 0) does not raise IndexError.
            l['p_tag'] = l['sort_buff'][-1]
            if len(l['sort_buff']) > sort_buffer_length:
                oldest = l['sort_buff'].pop(0)
                write_buffered(writeTag(oldest, oldest['time_stamp']))
        else:
            l['p_tag'] = tag
            write_buffered(writeTag(tag, tag['time_stamp']))

    def flush_buffers():
        "Write out whatever is left in the sorted window and output buffer"
        if sort:
            for tag in l['sort_buff']:
                write_buffered(writeTag(tag, tag['time_stamp']))
        if out_buffer_length > 0:
            out_file.write(l['out_buff'].getvalue())
            l['out_buff_c'] = 0
            l['out_buff'] = StringIO.StringIO()

    tags = readFile(in_file, chunk_size=1024)

    write_buffered(writeHeader())

    adjust = 0.0
    t = next(tags, None)
    if t is None:
        # No tags at all: make sure the buffered header still reaches
        # out_file instead of failing with StopIteration.
        flush_buffers()
        return

    if shift:
        # re-base the stream so that the first tag is at time 0
        adjust = -t['time_stamp']
        t['time_stamp'] = 0
    write_tag(t)

    for t in tags:
        t['time_stamp'] += adjust
        diff = t['time_stamp'] - l['p_tag']['time_stamp']

        if shift and ( diff <= min_delta or diff > max_delta):
            # jump detected: place this tag new_delta after the previous one
            # and carry the correction over to all following tags
            new_adjust = (l['p_tag']['time_stamp'] + new_delta) - t['time_stamp']
            t['time_stamp'] += new_adjust
            adjust += new_adjust
        write_tag(t)
    flush_buffers()

# XXX: Replaced by the ts_filters.py sub-lib; only fixes DTS ordering issues
def normalize_timestamps(istream, ostream, tag_buffer_length = 20, out_buf_length = 16384):
    """
        Copy the FLV stream from istream to ostream, passing every tag
        through a TSNormalizationChain with the cts_dts_agreement_filter.

        NOTE(review): BufferedWriter, TSNormalizationChain,
        cts_dts_agreement_filter and serialize_tag are not defined in the
        part of this module visible here -- confirm where they come from.
    """
    writer = BufferedWriter(ostream, out_buf_length)
    writer.write(writeHeader())

    # TODO: should arguably be in ts_gen.normalize()
    source_tags = readFile(istream, chunk_size = 1024, avc_info = True)

    chain = TSNormalizationChain(tag_buffer_length)
    chain.push_filter(cts_dts_agreement_filter)

    for normalized in chain.normalize(source_tags):
        writer.write(serialize_tag(normalized))

    writer.flush()

def print_tag_str(tag):
    "Returns a one line description of the tag: type, frame kind, dts, cts and size"
    fields = (
        getTagType(tag),
        getVideoFrame(tag),
        tag['time_stamp'],
        tag.get('cts', None),
        len(tag['data']),
    )
    return "%s (%s): dts %s, cts %s, size %s" % fields

def print_timestamps(istream):
    "Prints one numbered description line for every tag of the FLV stream"
    # TODO: should arguably be in ts_gen.normalize()
    read_tags = readFile(istream, chunk_size = 1024, avc_info = True)

    index = 0
    for tag in read_tags:
        #      #num type: dts, cts size,
        print("#%s: %s " % (index, print_tag_str(tag)))
        index += 1

# shift procedure:
#  * fill tag buffer
#  * change cts on eviction
def shift_cts(istream, ostream, out_buf_length = 16384):
    """
        Copy the FLV stream from istream to ostream, adding the shift
        reported by a DetectShiftCTSFilter to the cts of every tag that has
        one.  A cts that is still negative after the shift is set to 0 and
        a warning is printed.

        NOTE(review): BufferedWriter, TSNormalizationChain, the filter
        classes, has_cts and serialize_tag are not defined in the part of
        this module visible here -- confirm where they come from.
    """
    writer = BufferedWriter(ostream, out_buf_length)
    writer.write(writeHeader())

    # TODO: should arguably be in ts_gen.normalize()
    source_tags = readFile(istream, chunk_size = 1024, avc_info = True)

    # On the size of the TSNormalizationChain circular buffer:
    #
    # For DetectShiftCTSFilter:
    # The length of the normalization chain is the length of our CTS shift
    # pre-detection.  Longer lengths are more likely to work, but introduce
    # delay before we start spitting out results.  With audio/video
    # interleaving, one can expect something like 60 tags in a second.  All
    # the files I've looked at seem to have repeating patterns of negative
    # CTS values that reach their minimum within this tag window.  As such
    # we go with a one second delay.
    #
    # For DynamicShiftCTSFilter and ConstantShiftCTSFilter:
    # The length of the circular buffer can be 1 if no ShiftAudioDTSFilter
    # is in use.  ShiftAudioDTSFilter requires OrderByDTSFilter to reorder
    # audio tags after rewriting their DTS values.  See the note below for
    # ShiftAudioDTSFilter and its buffer size requirements.
    chain = TSNormalizationChain(60)
    detector = DetectShiftCTSFilter()
    #detector = DynamicShiftCTSFilter()
    #detector = ConstantShiftCTSFilter(2000)
    chain.push_filter(detector)

    # Adjust DTS timestamps of audio frames by current shift
    chain.push_filter(ShiftAudioDTSFilter(detector))

    # Note, this filter needs a window size large enough to successfully insert
    # all tags less than the shifted audio tags from the ShiftAudioDTSFilter.
    # If the detected shift is large, this can be several hundred tags.  Normal
    # shift distances should be less than 200ms, which is something like 10-20
    # tags, and the normalization chain above is 60 tags long, which should be
    # more than enough.
    chain.push_filter(OrderByDTSFilter())

    shift_reported = False
    for tag in chain.normalize(source_tags):
        # report the shift once, as soon as the first tag comes out
        if not shift_reported:
            shift_reported = True
            print("Shift amnt: %s" % detector.shift_amount())

        if has_cts(tag):
            shifted = tag['cts'] + detector.shift_amount()
            if shifted < 0:
                print("WARN: shift of %s not enough to compensate for all frames.  Needed %s." % (detector.shift_amount(), abs(tag['cts'])))
                shifted = 0
            tag['cts'] = shifted
        writer.write(serialize_tag(tag))

    writer.flush()


def trim_to_keyframe(in_filename, out_filename, buffer_length=1024):
    """
        Copy an FLV file, dropping every tag before the first keyframe.

        in_filename   -- path of the FLV file to read
        out_filename  -- path of the FLV file to write (overwritten)
        buffer_length -- number of tags collected before they are written
                         to the output file

        The output always starts with an FLV header; if the input contains
        no keyframe, the header is all that is written.
    """
    # FLV is binary data, so both files are opened in binary mode, and the
    # with blocks make sure they are closed (and the output flushed) even
    # when reading fails half way through.
    with open(in_filename, 'rb') as in_file:
        with open(out_filename, 'wb') as out_file:
            out_buff = StringIO.StringIO()
            out_buff_c = 0
            out_buff.write(writeHeader())

            have_key_frame = False
            for t in readFile(in_file):
                if not have_key_frame:
                    if not isKeyframe(t):
                        continue
                    have_key_frame = True
                out_buff.write(writeTag(t, t['time_stamp']))
                out_buff_c += 1
                if out_buff_c > buffer_length:
                    out_file.write(out_buff.getvalue())
                    out_buff_c = 0
                    out_buff = StringIO.StringIO()

            out_file.write(out_buff.getvalue())
    

def get_keyframes(filename):
    """
        Return a list of all keyframes in an flvfile.

        filename -- path of the FLV file to scan

        Every entry is the tag dict produced by readFile() with its 'data'
        key removed, so only the tag's metadata (type, time_stamp,
        byte_offset, ...) is kept in memory.
    """
    key_frames = []
    # FLV is binary data, so open in binary mode; the with block closes the
    # file once all tags have been read.
    with open(filename, 'rb') as flv_file:
        for t in readFile(flv_file):
            if isKeyframe(t):
                del t['data']
                key_frames.append(t)
    return key_frames

