"""

    File: meta.py
    Description: 
    
        Module to hold info about streams and nodes in the cluster.

    Author: Kyle Vogt
    Date  : December 22nd, 2007
    Copyright (c) 2007, Justin.tv, Inc.
    
"""

from twisted.python import log
from twisted.internet import defer, reactor
from flvtools.rtmp import amf_py as amf
from flvtools import flv, utils
import flvtools.factories
import time, traceback

class MetaObject:
    """
    Base class for objects whose attributes are exchanged between nodes.

    Subclasses override setup() to extend the key lists below and to install
    defaults with setDefaults():

      metaKeys          -- attributes returned by getMeta()
      periodicMetaKeys  -- attributes returned by getPeriodicMeta(); update()
                           may overwrite these
      updatableKeys     -- additional attributes update() may overwrite
    """

    def __init__(self, *kw, **kwargs):
        """
        Build the object.

        An optional first positional argument is a dict whose items are copied
        onto the instance; keyword arguments are applied after it. setup()
        runs last, so setDefaults() calls made there only fill in attributes
        that neither source supplied.
        """
        self.metaKeys = ['name', 'type']
        self.periodicMetaKeys = ['name', 'type']
        self.updatableKeys = []
        self.indexes = []
        # 'type' is the lower-cased class name, e.g. 'stream' or 'node'.
        self.type = self.__class__.__name__.lower()
        self.name = 'no_name'
        if len(kw):
            self.__dict__.update(kw[0])
        for arg in kwargs:
            setattr(self, arg, kwargs[arg])
        self.setup()

    def setDefaults(self, defaults):
        "Apply defaults to the MetaObject if they do not yet exist"
        for key, val in defaults.items():
            self.__dict__.setdefault(key, val)

    def setup(self):
        "Remaining MetaObject setup calls (hook for subclasses)"
        pass

    def update(self, o, sender, omit=None):
        """
        Merge the permitted keys of dict `o` into this object.

        Only keys listed in periodicMetaKeys or updatableKeys, and not listed
        in `omit`, are applied. When the current attribute value is a dict,
        the incoming dict is merged into it and every entry whose incoming
        value is None is removed. Otherwise the attribute is replaced, and it
        is created if it did not exist yet.

        `sender` is not used here; subclasses use it to key per-node values.
        """
        # None instead of a shared mutable [] default.
        omit = omit or ()
        # Build the permitted set once rather than once per incoming key.
        allowed = set(self.periodicMetaKeys + self.updatableKeys)
        for key in o:
            if key not in allowed or key in omit:
                continue
            current = self.__dict__.get(key)
            if isinstance(current, dict):
                for k, v in o[key].items():
                    if v is None:
                        # None is a deletion marker.
                        current.pop(k, None)
                    else:
                        current[k] = v
            else:
                self.__dict__[key] = o[key]

    def getSomeMeta(self, keys):
        "Return a dict mapping each of `keys` to its attribute value (None if missing)"
        return dict(zip(keys, [getattr(self, key, None) for key in keys]))

    def getMeta(self):
        "Return the metaKeys attributes plus the current 'server_time' (integer epoch seconds)"
        meta = self.getSomeMeta(self.metaKeys)
        meta['server_time'] = int(time.time())
        return meta

    def getPeriodicMeta(self):
        "Return a dict of the periodicMetaKeys attributes"
        return self.getSomeMeta(self.periodicMetaKeys)

    def __repr__(self):
        return '<%s %s>' % (self.type, self.name)

# Stream objects are created in xmlrpc.py when wowza requests authorization of a stream
class Stream(MetaObject):
    """
    A stream known to this node.

    Holds the stream's shared metadata, the per-node viewer counts, and the
    list of subscribers that incoming packets are fanned out to.
    """

    def setup(self):
        # metaKeys are what show up in live.justin.tv under /streams/all.xml and what are propagated
        # from the origin server to other servers that pick up the stream
        self.metaKeys.extend([
            'origin',
            'up_time',
            'video_codec',
            'embed_count',
            'embed_ratio',
            'stream_count',
            'nodes',
            'node_names',
            'publisher_buffer',
            'password',
            'embed_enabled',
            'play_as',
            'qscore',
            'language',
            'category',
            'subcategory',
            'killed',
            'bitrate',
            'video_width',
            'video_height',
        ])
        # updatableKeys can be assigned to on the origin at any time, and then are propagated out
        self.updatableKeys.extend([
            'play_as',
            'video_codec',
            'video_width',
            'video_height',
        ])
        # periodicMetaKeys are propagated automatically from the origin node every N seconds
        self.periodicMetaKeys.extend([
            'publisher_buffer',
            'local_count',
            'local_embed_count',
            'file_name',
            'connection_count',
            'play_as',
            'killed',
            'bitrate',
        ])
        self.setDefaults({
            'video_codec' : '?',
            'up_time' : int(time.time()),
            'embed_count' : 0,
            'stream_count' : 0,
            'embed_ratio' : 0,
            'file_name' : None,
            'origin' : None,
            'nodes' : [],
            'node_names' : [],
            'publisher_buffer' : 0,
            'local_count' : 0,
            'local_embed_count' : 0,
            'core' : None,
            'local' : False,
            'password' : None,
            'embed_enabled' : True,
            'language' : None,
            'category' : None,
            'subcategory' : None,
            'killed' : False,
            'bitrate' : -1.0,
            'video_width' : '?',
            'video_height' : '?',
        })
        # NOTE(review): these run after __init__ has applied constructor
        # arguments, so values passed in for e.g. play_as or qscore are reset
        # here. Kept as-is; confirm whether that is intended.
        self.meta = {}
        self.user_id = None
        self.hosts = {}            # viewer host -> local viewer count
        self.abs_times = {}        # packet object id -> running absolute timecode
        self.counts = {}           # node name -> viewer count reported by that node
        self.embed_counts = {}     # node name -> embed viewer count reported by that node
        self.objects = []
        self.last_key = None
        self.subscribers = []
        self.queue = []
        self.data_start_time = None
        self.stream_start_time = None
        self.clientBuffers = None
        self.replicated = False
        self.play_as = {}
        self.conn_count = {}       # node name -> connection_count reported by that node
        self.qscore = 0
        self.connection_count = 0
        self.last_update_time = None
        self.sc_integral = 0

    def update(self, o, sender):
        """
        Apply a meta update `o` received from node `sender`.

        The per-node counters local_embed_count, local_count and
        connection_count are recorded under `sender`, and the aggregates are
        recomputed after each one that is present.
        """
        MetaObject.update(self, o, sender, omit=['local_count', 'connection_count'])
        if 'local_embed_count' in o:
            self.embed_counts[sender] = o.get('local_embed_count', 0)
            self.updateCounts()
        if 'local_count' in o:
            self.counts[sender] = o.get('local_count', 0)
            self.updateCounts()
        if 'connection_count' in o:
            self.conn_count[sender] = int(o.get('connection_count', 0))
            self.updateCounts()

    def updateCounts(self):
        "Recompute local and cluster-wide viewer counts, the embed ratio and qscore"
        count = sum(self.hosts.values())
        if self.core:
            self.counts[self.core.name] = count
            self.core.update_counts()
        self.local_count = count
        self.local_embed_count = self.hosts.get('embed', 0)
        self.embed_count = sum(self.embed_counts.values())
        self.stream_count = sum(self.counts.values())
        # Only nodes that reported both a count and an embed count contribute to the ratio.
        local_streams = sum(v for k, v in self.counts.items() if k in self.embed_counts)
        embed_streams = sum(v for k, v in self.embed_counts.items() if k in self.counts)
        if local_streams:
            self.embed_ratio = float(embed_streams) / local_streams
        now = time.time()
        # Accumulate viewer-seconds for viewers on nodes whose location is
        # 'sanfrancisco'. That lookup needs self.core's node table. Without a
        # core (the default), this used to raise AttributeError on every call
        # after the first, so accumulation is skipped instead.
        if self.last_update_time and self.core:
            nodes = self.core.nodes
            sf_viewers = sum(v for k, v in self.counts.items()
                             if k in nodes and nodes[k].location == 'sanfrancisco')
            self.sc_integral += sf_viewers * (now - self.last_update_time)
        self.last_update_time = now
        total_connections = sum(self.conn_count.values())
        if total_connections:
            self.qscore = self.sc_integral / total_connections

    def add_node(self, node):
        "Add a node for this stream (node_names and nodes are kept index-aligned)"
        if node.name not in self.node_names:
            self.node_names.append(node.name)
            self.nodes.append(':'.join([node.public_ip, str(node.rtmp_port)]))

    def remove_node(self, node):
        "Remove a node for this stream: its viewer count, address entry and play_as entry"
        self.counts.pop(node.name, None)
        if node.name in self.node_names:
            i = self.node_names.index(node.name)
            self.node_names.pop(i)
            # add_node keeps the lists aligned; guard in case they drifted.
            if i < len(self.nodes):
                self.nodes.pop(i)
        self.play_as.pop(node.name, None)

    def addPackets(self, packets):
        "Timestamp a batch of packets, hand it to every subscriber, then update the lag estimate"
        for packet in packets:
            self.checkTime(packet)
            self.checkFrame(packet)
        for subscriber in self.subscribers:
            subscriber.process(self, packets)
        if packets:
            self.checkLag(packets[-1])

    def checkTime(self, packet):
        """
        Annotate `packet` with 'abstime' and 'streamtime'.

        A timecode is taken as absolute when 'absflag' is set or when the
        packet's object id has not been seen before. Otherwise it is added to
        that object's running absolute time.
        """
        if packet['absflag'] or packet['object'] not in self.abs_times:
            self.abs_times[packet['object']] = packet['timecode']
            if len(self.abs_times) > 1: self.interleaved = False
        else:
            self.abs_times[packet['object']] += packet['timecode']
        packet['abstime'] = self.abs_times[packet['object']]
        packet['streamtime'] = time.time() - self.up_time

    def checkFrame(self, packet):
        "If this packet is a keyframe, remember it and detect the video codec once"
        if packet['key']:
            self.last_key = packet
            if self.video_codec == '?':
                self.video_codec = flv.getVideoCodec(packet)

    def subscribe(self, obj):
        "Register `obj`; its process(stream, packets) is called for each packet batch"
        self.subscribers.append(obj)

    def unsubscribe(self, obj):
        "Send `obj` a final stop notification and remove its subscription"
        obj.process(self, [], stop=True)
        if obj in self.subscribers:
            self.subscribers.remove(obj)

    def shutdown(self):
        "Remove all subscribers, clear local viewers and cancel any pending buffer query"
        while len(self.subscribers):
            self.unsubscribe(self.subscribers.pop())
        self.hosts = {}
        self.updateCounts()
        d = getattr(self, 'clientDeferred', None)
        if d is not None and not d.called:
            try:
                d.cancel()
            except Exception:
                log.err()

    def addViewer(self, host):
        "Add a viewer from `host` (connection_count is cumulative and never decremented)"
        count = self.hosts.get(host, 0)
        self.hosts[host] = count + 1
        self.connection_count += 1
        self.updateCounts()

    def removeViewer(self, host):
        "Remove a viewer from `host`, dropping the host entry when it reaches zero"
        if host in self.hosts:
            if self.hosts[host] == 1:
                del self.hosts[host]
            else:
                self.hosts[host] -= 1
        self.updateCounts()

    def measureClientBuffer(self):
        """
        Send a buffer query to every subscriber that has an 'nc' connection.

        Returns a Deferred that fires about one second later with
        clientBufferStats(). Replies are presumably appended to
        self.clientBuffers by code outside this file (TODO confirm).
        """
        d = self.clientDeferred = defer.Deferred()
        self.clientBuffers = []
        for subscriber in self.subscribers:
            if hasattr(subscriber, "nc"):
                obj = {    'level' : 'status',
                        'code' : 'NetStream.ClientBuffer.Query' }
                subscriber.nc.invoke("onStatus", 0.0, obj, subscriber.streamId, channel=amf.EVENT_CHANNEL)
        # Bind this query's own Deferred. If a second query starts within the
        # second, it must not fire this one again (AlreadyCalledError).
        reactor.callLater(1, self.gotClientBuffers, d)
        return d

    def gotClientBuffers(self, d=None):
        "Fire buffer-query Deferred `d` (default: the latest one) unless it already fired or was cancelled"
        if d is None:
            d = self.clientDeferred
        if not d.called:
            d.callback(self.clientBufferStats())

    def clientBufferStats(self):
        "Return [average, min, max] of collected client buffers, or [] if none"
        if self.clientBuffers:
            return [func(self.clientBuffers) for func in [utils.avg, min, max]]
        else:
            return []

    def checkLag(self, packet):
        """
        Store in publisher_buffer (as a '%.2f' string) how far the elapsed
        stream time has drifted from the elapsed wall-clock time since the
        first packet. The division by 1000 implies 'abstime' is in
        milliseconds.
        """
        if not self.data_start_time:
            self.data_start_time = time.time()
            self.stream_start_time = packet['abstime']
        real = time.time() - self.data_start_time
        fake = (packet['abstime'] - self.stream_start_time) / 1000.0
        delay = abs(fake - real)
        self.publisher_buffer = "%.2f" % delay
        
class Node(MetaObject):
    """
    A server in the cluster: its addresses, load/traffic counters and the
    names of the streams it carries.
    """

    def setup(self):
        self.metaKeys += [
            'public_ip',
            'private_ip',
            'rtmp_port',
            'streams',
            'up_time',
            'connections',
            'stream_count',
            'cpu_usage',
            'load',
            'mem_usage',
            'max_disk',
            'rtmp_publish_enabled',
            'rtmp_play_enabled',
            'rtmp_replication_enabled',
            'bytes_in',
            'bytes_out',
            'mbps_in',
            'mbps_out',
            'max_mbps',
            'active',
            'app_name',
            'location',
        ]
        # Counters pushed on every periodic update.
        self.periodicMetaKeys += [
            'connections',
            'stream_count',
            'cpu_usage',
            'load',
            'mem_usage',
            'max_disk',
            'bytes_in',
            'bytes_out',
            'mbps_in',
            'mbps_out',
            'active',
        ]
        # Switches that update() may change.
        self.updatableKeys += [
            'rtmp_play_enabled',
            'rtmp_publish_enabled',
            'rtmp_replication_enabled',
            'max_mbps',
        ]
        self.setDefaults(dict(
            up_time=int(time.time()),
            streams=[],
            public_ip='127.0.0.1',
            private_ip='127.0.0.1',
            rtmp_port=1935,
            connections=0,
            stream_count=0,
            spread_name='no_name',
            local=False,
            cpu_usage='0%',
            load=0.0,
            mem_usage='0B',
            max_disk='0%',
            rtmp_publish_enabled=True,
            rtmp_play_enabled=True,
            rtmp_replication_enabled=True,
            bytes_in=0,
            bytes_out=0,
            mbps_in=0,
            mbps_out=0,
            max_mbps=1,
            active=False,
            app_name='app',
            location='sanfrancisco',
        ))

    def setViewers(self, num):
        "Record `num` as this node's stream_count"
        self.stream_count = num

    def add_stream(self, stream):
        "Track `stream` by name; duplicates are ignored"
        name = stream.name
        if name in self.streams:
            return
        self.streams.append(name)

    def remove_stream(self, stream):
        "Stop tracking `stream`; unknown names are ignored"
        name = stream.name
        if name not in self.streams:
            return
        self.streams.remove(name)

    def score(self):
        "Return 1/load, with load floored at 0.2 (so the score never exceeds 5.0)"
        current_load = float(self.load)
        return 1.0 / max(0.2, current_load)
        
class Cluster(MetaObject):
    "Cluster-wide aggregate totals; every counter is reported in getMeta() and starts at 0"

    def setup(self):
        totals = [
            'nodes',
            'streams',
            'connections',
            'stream_count',
            'mbps_in',
            'mbps_out',
            'bytes_in',
            'bytes_out',
        ]
        self.metaKeys.extend(totals)
        # Each counter defaults to zero.
        self.setDefaults(dict.fromkeys(totals, 0))

class Datalink(MetaObject):
    """
    A data link with bandwidth counters and a list of IP prefixes.

    `ip in datalink` is answered by an utils.IPTrie built from `prefixes`.
    `nodes` is the set of senders that have pushed updates for this link.
    """

    def setup(self):
        self.metaKeys += ['location', 'mbps_out', 'mbps_in', 'max_mbps',
                          'prefixes', 'ip', 'iface']
        self.periodicMetaKeys += ['mbps_out', 'mbps_in']
        self.updatableKeys += ['prefixes']
        self.setDefaults(dict(
            location='sanfrancisco',
            mbps_out=0,
            mbps_in=0,
            max_mbps=10000,
            prefixes=['0.0.0.0/0'],
        ))
        self.indexes += ['ips']
        self.rebuild_ips()
        self.nodes = set()

    def update(self, o, sender):
        "Apply update `o`, remember `sender`, and rebuild the trie if prefixes changed"
        MetaObject.update(self, o, sender)
        self.nodes.add(sender)
        if 'prefixes' in o:
            self.rebuild_ips()

    def rebuild_ips(self):
        "Replace self.ips with a fresh IPTrie holding every prefix in self.prefixes"
        self.ips = trie = utils.IPTrie()
        for prefix in self.prefixes:
            trie.add(prefix)

    def __contains__(self, ip):
        return ip in self.ips
 
class Replication(MetaObject):
    """
    Push the local stream `stream` to the RTMP server `destination`.

    connect() opens an RTMP client connection, then create_stream -> publish
    -> subscribe chain via Deferred callbacks. Once subscribed to the local
    stream, process() re-encodes every packet batch onto the outgoing
    connection. A failure, or no progress within 60s, calls timeout(), which
    schedules reconnect(). After 3 retries the replication is reported down.
    """

    def setup(self):
        self.metaKeys.extend([
            'stream',
            'source',
            'destination',
            'state',
        ])
        self.periodicMetaKeys.extend([
            'state',
        ])
        self.setDefaults({
            'stream' : None,
            'source' : None,
            'destination' : None,
            'state' : 'pending',
            'core' : None
        })
        self.client = None
        self.objects = []        # object ids already sent with an absolute timecode on this connection
        self.streamId = 2
        self.connection = None
        self.replicated = False
        self.attempts = 0

    def process(self, stream, packets, ignoreId=False, stop=False):
        """
        Encode `packets` from `stream` and write them to the destination.

        Packets are dropped until the outgoing connection has an encoder. The
        first packet of each object id on this connection carries an absolute
        timecode; `ignoreId` forces absolute timecodes for every packet.
        `stop` is part of the subscriber interface and is ignored.
        """
        # Prevent processing before connection is completed
        if not hasattr(self.connection, 'encoder'): return
        for packet in packets:
            obj_id = packet['object']
            if packet['absflag'] or ignoreId:
                tc, absflag = packet['abstime'], True
            elif obj_id not in self.objects:
                tc, absflag = packet['abstime'], True
                log.msg("Forced absolute timestamp for objectId %s" % obj_id)
                self.objects.append(obj_id)
            else:
                tc, absflag = packet['timecode'], False
            data = self.connection.encoder.process(packet['data'], obj_id, self.streamId, packet['type'], tc, absflag)
            self.connection.write(data)

    def reconnect(self):
        "Schedule another connection attempt in 3s, or report the replication down after 3 retries"
        if self.attempts < 3:
            reactor.callLater(3.0, self.doconnect)
            self.attempts += 1
        else:
            log.msg('Failed to create replication %s after 3 attempts.' % self.name)
            # NOTE(review): `messages` is not imported in this module, so this
            # line raises NameError when reached. Import the module that
            # defines DOWN (not visible here).
            self.core.update(self, messages.DOWN)

    def doconnect(self):
        "Add the replication and try again"
        self.connect()

    def connect(self):
        "Make an RTMP connection to the destination"
        # Each attempt starts clean. Drop the previous (possibly dead)
        # connection so process() stops writing to it, and forget which
        # object ids were sent absolute: the new connection needs an
        # absolute timecode for each object id again.
        self.connection = None
        self.objects = []
        if self.destination in self.core.nodes:
            self.ip = self.core.nodes[self.destination].public_ip
            self.port = int(self.core.nodes[self.destination].rtmp_port)
            log.msg('Found replication destination in node list: using %s:%s' % (self.ip, self.port))
        else:
            # Accept 'rtmp://host:port', 'host:port' or 'host' (default port 1935).
            self.destination = self.destination.lower().replace('rtmp://', '')
            parts = self.destination.split(':')
            if len(parts) > 1:
                self.ip = ':'.join(parts[:-1])
                self.port = int(parts[-1])
            else:
                self.ip = self.destination
                self.port = 1935
            log.msg('Could not find replication destination in node list: using %s:%s' % (self.ip, self.port))
        log.msg('Replicating stream %s to %s:%s' % (self.stream, self.ip, self.port))
        self.client = flvtools.factories.rtmp.RTMPServer(self.core, client=True)
        self.client.connectDeferred.addCallback(self.create_stream)
        reactor.connectTCP(self.ip, self.port, self.client)
        self.timeout_deferred = reactor.callLater(60.0, self.timeout)
        self.state = 'connecting'

    def create_stream(self, result=None):
        "Connected: capture the connection and send createStream"
        try:
            self.connection = self.client.get_nc()
            log.msg("Sending createStream to %s" % self.destination)
            self.client.invoke('createStream', self.streamId)
            self.client.createStreamDeferred.addCallback(self.publish)
            self.state = 'sending_create_stream'
        except Exception:
            log.msg('Error with replication!')
            log.err()
            self.timeout()

    def publish(self, result=None):
        "Stream created: publish under the configured replication key prefix"
        try:
            stream_name = self.core.config['rtmp_replication_key'] + self.stream
            log.msg("Sending publish to %s (stream %s)" % (self.destination, self.stream))
            self.client.invoke('publish', 0.0, [stream_name, 'live'], streamId=self.streamId)
            self.client.publishDeferred.addCallback(self.subscribe)
            self.state = 'sending_publish'
        except Exception:
            log.msg('Error with replication!')
            log.err()
            self.timeout()

    def subscribe(self, result=None):
        "Publishing: start receiving packets from the local source stream"
        log.msg("Subscribing to source stream %s" % self.stream)
        try:
            source = self.core.streams[self.stream]
            # A reconnect runs this again. Registering twice would write every
            # packet batch to the destination twice.
            if self not in getattr(source, 'subscribers', ()):
                source.subscribe(self)
            self.state = 'streaming'
        except Exception:
            log.msg('Unable to subscribe to stream %s:' % self.stream)
            log.err()
            self.timeout()
        self._cancel_timeout()

    def timeout(self, result=None):
        "Mark the attempt failed, close the connection and schedule a reconnect"
        log.msg("Replication timed out at state %s: %s" % (self.state, self.name))
        self.state = 'failed with state %s' % self.state
        # Stop process() from writing to the connection being torn down.
        self.connection = None
        if self.client is not None:
            try:
                self.client.get_nc().shutdown()
            except Exception:
                # Best effort: the connection may never have come up.
                log.msg('Could not shut down replication connection for %s' % self.name)
        reactor.callLater(3.0, self.reconnect)
        self._cancel_timeout()

    def _cancel_timeout(self):
        "Cancel the pending connect timeout if it is still scheduled"
        call = getattr(self, 'timeout_deferred', None)
        if call is not None and call.active():
            call.cancel()
        
class Import(MetaObject):
    """
    Pull `source_stream` from a remote RTMP server (`source`) and publish it
    locally as `stream`.

    connect() -> create_stream() -> play() chain via Deferred callbacks.
    success() or failure() settles the state. timeout() fires 10s after
    connect() unless it is cancelled first.
    """

    def setup(self):
        self.metaKeys.extend([
            'stream',
            'source',
            'source_stream',
            'destination',
            'state',
        ])
        self.periodicMetaKeys.extend([
            'state',
        ])
        self.setDefaults({
            'stream' : None,
            'source' : None,
            'source_stream' : None,
            'destination' : None,
            'state' : 'pending',
            'core' : None
        })
        self.client = None
        self.objects = []        # object ids already sent with an absolute timecode
        self.myStreamId = 2
        self.streamId = None
        self.connection = None

    def process(self, stream, packets, ignoreId=False, stop=False):
        """
        Encode `packets` and write them to the connection.

        Packets are dropped until the connection has an encoder. The first
        packet of each object id is sent with an absolute timecode *and* the
        absolute flag. Previously the flag was left False there, which
        mislabeled an absolute time as a delta. `ignoreId` forces absolute
        timecodes, matching Replication.process. `stop` is ignored.
        """
        if not hasattr(self.connection, 'encoder'): return
        for packet in packets:
            obj_id = packet['object']
            if packet['absflag'] or ignoreId:
                tc, absflag = packet['abstime'], True
            elif obj_id not in self.objects:
                tc, absflag = packet['abstime'], True
                log.msg("Forced absolute timestamp for objectId %s" % obj_id)
                self.objects.append(obj_id)
            else:
                tc, absflag = packet['timecode'], False
            data = self.connection.encoder.process(packet['data'], obj_id, self.streamId, packet['type'], tc, absflag)
            self.connection.write(data)

    def connect(self):
        "Make an RTMP connection to the source server"
        # source looks like rtmp://host[:port]/app/.../streamname
        server = self.source.lower().replace('rtmp://', '').split('/')[0]
        parts = server.split(':')
        if len(parts) > 1:
            self.ip = ':'.join(parts[:-1])
            self.port = int(parts[-1])
        else:
            self.ip = server
            self.port = 1935
        log.msg('Importing stream %s from %s:%s to %s as %s' % (self.source_stream, self.ip, self.port, self.destination, self.stream))
        # app = path between host and the last segment; tcUrl = source minus its last segment
        app = '/'.join(self.source.replace('rtmp://', '').split('/')[1:-1])
        tcUrl = '/'.join(self.source.split('/')[:-1])
        self.client = flvtools.factories.rtmp.RTMPServer(self.core, client=True, app=app, tcUrl = tcUrl)
        self.client.connectDeferred.addCallback(self.create_stream)
        reactor.connectTCP(self.ip, self.port, self.client)
        self.timeout_deferred = reactor.callLater(10.0, self.timeout)
        self.state = 'connecting'

    def create_stream(self, result=None):
        "Connected: send server bandwidth, createStream and a ping"
        self.connection = self.client.get_nc()
        self.client.get_nc().server_bw('\x00\x13\x12\xd0')
        log.msg("Sending createStream to %s" % self.source)
        self.client.get_nc().invoke('createStream', self.myStreamId)
        self.client.createStreamDeferred.addCallback(self.play)
        self.client.get_nc().ping('\x00\x03\x00\x00\x00\x00\x00\x00\x00d')
        self.state = 'sending_create_stream'

    def play(self, result=None):
        "Stream created: play the remote stream and publish it locally as self.stream"
        # assumes result[3] carries the server-assigned stream id -- TODO confirm
        self.streamId = int(result[3])
        if self.streamId not in self.client.get_nc().streams:
            self.client.get_nc().create_netstream(self.streamId)
        log.msg("Sending play to %s (stream %s)" % (self.source, self.source_stream))
        self.client.get_nc().invoke('play', 0.0, [self.source_stream, -2000.0], streamId=self.streamId)
        self.client.get_nc().ping('\x00\x03\x00\x00\x00\x01\x00\x00\x00d')
        self.client.get_nc().streams[self.streamId].publish(self.stream)
        self.client.playDeferred.addCallbacks(self.success, self.failure)
        self.state = 'sending_play'

    def success(self, result=None):
        "Play accepted: mark streaming and disarm the timeout"
        log.msg('Reading stream %s from %s' % (self.source_stream, self.source))
        self.state = 'streaming'
        # Cancelling a DelayedCall that already fired raises; check first.
        self._cancel_timeout()

    def failure(self, result=None):
        "Play rejected: record stream_not_found and disarm the timeout so it doesn't overwrite that state"
        log.msg('Import %s failed (stream not found)' % self.name)
        self.state = 'stream_not_found'
        self._cancel_timeout()

    def timeout(self, result=None):
        "No success/failure within 10s: record the state the import was stuck in"
        log.msg("Import timed out at state %s: %s" % (self.state, self.name))
        self.state = 'failed with state %s' % self.state

    def _cancel_timeout(self):
        "Cancel the pending timeout if it is still scheduled"
        call = getattr(self, 'timeout_deferred', None)
        if call is not None and call.active():
            call.cancel()
        
