# from collections import OrderedDict import time
import datetime
import time
from _version import __version__
from classes.branches.branch import get_element_base_name

from threading import Timer
from classes import logger as log
from gi.repository import Gst

# Seconds between two QoS snapshots (used as the threading.Timer delay).
STATS_INTERVAL = 1

# Mutable module state shared by the public API and the timer thread.
_stats: dict = {}          # current stats; copied into every snapshot
_callbacks: list = []      # (interval, callable) pairs registered via add_cb()
_t: float = 0              # time.time()-based timestamp of the next snapshot
_running: bool = False     # cleared by stop()
_interval_timer = None     # pending threading.Timer, or None
_stats_queue: list = []    # snapshots waiting to be drained by get_stats()
_dirty_pipeline = None     # pipeline to re-scan for elements; set by mark_dirty()
_queues: dict = {}         # queue element name -> queue element
_videorates: dict = {}     # videorate element name -> videorate element


def is_other_pipeline(name):
    """Return whether element *name* belongs to a pipeline other than ours.

    Names containing a ``"."`` are treated as foreign. They are logged at
    debug level so that callers can skip them.

    :param name: GStreamer element name.
    :returns: ``True`` if *name* contains ``"."``, otherwise ``False``.
    """
    if "." in name:
        log.debug("%s is not in our pipeline, skipping", name)
        return True
    # The original fell through with a bare ``False`` expression and returned None.
    return False


def add_queue(queue, oldstats, old_queues):
    """Track *queue* and seed its level stats from *oldstats*.

    Elements from other pipelines are ignored. Each tracked queue gets
    ``<name>_current_level_buffers_nr`` and ``<name>_current_level_time_nr``
    entries in ``_stats``. The same keys are also created under the base name
    returned by ``get_element_base_name``, when it returns one. Each entry
    starts from the value in *oldstats*, or 0 if it is absent.

    :param queue: queue element to track.
    :param oldstats: stats dict from before the pipeline re-scan.
    :param old_queues: previously tracked queues (currently unused).
    """
    queue_name = queue.get_name()
    if is_other_pipeline(queue_name):
        return
    _queues[queue_name] = queue

    suffixes = ("current_level_buffers_nr", "current_level_time_nr")
    prefixes = [queue_name]
    base = get_element_base_name(queue_name)
    if base:
        prefixes.append(base)
    for key in ("%s_%s" % (prefix, suffix) for prefix in prefixes for suffix in suffixes):
        _stats[key] = oldstats.get(key, 0)


def add_videorate(videorate, oldstats, old_videorates):
    """Track *videorate* so ``_get_videorate_stats`` samples it.

    Elements from other pipelines are ignored. The element gets a ``_stats``
    attribute (an empty dict) if it does not already have one. That dict
    holds the previous counter sample used to compute rates.

    :param videorate: videorate element to track.
    :param oldstats: stats dict from before the pipeline re-scan (unused).
    :param old_videorates: previously tracked videorates (unused). Carrying
        ``_stats`` over from a same-named previous element was sketched in an
        earlier revision but never implemented.
    """
    element_name = videorate.get_name()
    if is_other_pipeline(element_name):
        return
    _videorates[element_name] = videorate
    try:
        videorate._stats
    except AttributeError:
        log.debug('{} did not have stats'.format(element_name))
        videorate._stats = {}


def _get_queue_stats():
    """Sample the fill level of every tracked queue into ``_stats``.

    Two values are recorded per queue:

    - the ``current-level-buffers`` property;
    - the ``current-level-time`` property integer-divided by 1,000,000
      (ns -> ms, assuming GStreamer's nanosecond clock time).

    They are written under ``<name>_current_level_buffers_nr`` and
    ``<name>_current_level_time_nr``. They are mirrored under the base name
    when ``get_element_base_name`` returns one.
    """
    for queue_name, queue in _queues.items():
        levels = {
            "current_level_buffers_nr": queue.get_property("current-level-buffers"),
            "current_level_time_nr": queue.get_property("current-level-time") // 1000000,
        }
        for suffix, value in levels.items():
            _stats["%s_%s" % (queue_name, suffix)] = value
        base = get_element_base_name(queue_name)
        if base:
            for suffix, value in levels.items():
                _stats["%s_%s" % (base, suffix)] = value


# videorate element counter properties that are turned into per-second rates.
VIDEORATE_STATS = 'in out duplicate drop'.split()


def _get_videorate_stats():
    """Publish per-second rates of each tracked videorate's counters.

    For every property in ``VIDEORATE_STATS``, the current counter value is
    read and compared with the previous sample stored in the element's
    ``_stats`` dict. The difference is divided by the elapsed
    ``time.perf_counter()`` seconds, rounded to 2 decimals and stored as
    ``<name>_<prop>_nr``. It is also stored under the base name when there is
    one. The first call for an element only records the baseline sample.
    """
    now = time.perf_counter()
    for element_name, rate in _videorates.items():
        memo = rate._stats
        previous_time = memo.get("time", 0)
        memo["time"] = now
        elapsed = now - previous_time
        base = get_element_base_name(element_name)

        for prop in VIDEORATE_STATS:
            current = rate.get_property(prop)
            previous = memo.get(prop, 0)
            memo[prop] = current
            if previous_time:
                per_second = round((current - previous) / elapsed, 2)
                _stats["%s_%s_nr" % (element_name, prop)] = per_second
                if base:
                    _stats["%s_%s_nr" % (base, prop)] = per_second


def _element_stat_keys(queues, videorates):
    """Return the ``_stats`` keys written for the given element names.

    The key formats match those used for queues
    (``<name>_current_level_buffers_nr`` / ``<name>_current_level_time_nr``)
    and for videorates (``<name>_<prop>_nr`` for each prop in
    ``VIDEORATE_STATS``). Base-name variants from ``get_element_base_name``
    are included.
    """
    queue_suffixes = ("current_level_buffers_nr", "current_level_time_nr")
    videorate_suffixes = ["%s_nr" % prop for prop in VIDEORATE_STATS]
    keys = set()
    for names, suffixes in ((queues, queue_suffixes), (videorates, videorate_suffixes)):
        for name in names:
            prefixes = [name]
            base_name = get_element_base_name(name)
            if base_name:
                prefixes.append(base_name)
            keys.update("%s_%s" % (prefix, suffix)
                        for prefix in prefixes for suffix in suffixes)
    return keys


def _update_elements():
    """Re-scan the pipeline given to ``mark_dirty`` for queues and videorates.

    Does nothing unless a pipeline is marked dirty. Otherwise the tracked
    queue/videorate maps are rebuilt from ``iterate_recurse()``.

    Only the per-element stats of the previously tracked elements are
    removed from ``_stats``. ``add_queue`` re-seeds the queue entries from the
    old values. Every other entry is kept, including the static report fields
    set by ``start()`` and values merged by callbacks. Previously the whole
    dict was reset, which permanently lost fields such as ``routing_key``.
    """
    global _queues, _videorates, _dirty_pipeline, _stats
    # Snapshot: mark_dirty() may be called from another thread during the scan.
    pipeline = _dirty_pipeline
    if not pipeline:
        return
    queue_type = Gst.ElementFactory.find("queue").get_element_type()
    videorate_type = Gst.ElementFactory.find("videorate").get_element_type()

    # Compute before rebinding any state, so a failure leaves everything intact.
    stale_keys = _element_stat_keys(_queues, _videorates)

    old_videorates = _videorates
    _videorates = {}
    old_queues = _queues
    _queues = {}
    oldstats = _stats
    # dict() copies in one C-level call, so concurrent add_stats() writes
    # cannot break the copy mid-iteration.
    kept = dict(oldstats)
    for key in stale_keys:
        kept.pop(key, None)
    _stats = kept

    for element in pipeline.iterate_recurse():
        try:
            if videorate_type.is_a(element):
                add_videorate(element, oldstats, old_videorates)
            elif queue_type.is_a(element):
                add_queue(element, oldstats, old_queues)
        except Exception as e:
            log.error("QOS element iteration %s", e, exc_info=True)

    # Do not lose a newer mark_dirty() that arrived while we were scanning.
    if _dirty_pipeline is pipeline:
        _dirty_pipeline = None


def _run_callbacks(tick):
    """Invoke every registered callback whose interval divides *tick*.

    Truthy ``dict`` results are merged into the stats via ``add_stats``.
    Falsy results are logged as a warning, and non-dict results as an error.
    Exceptions from a callback (including a bad interval) are logged and never
    propagate, so one bad callback cannot stop the reporting timer.

    :param tick: integer wall-clock second of the current snapshot.
    """
    for (interval, cb) in _callbacks:
        try:
            # Inside the try: a zero interval would otherwise raise
            # ZeroDivisionError out of the timer thread and end reporting.
            if tick % interval != 0:
                continue
            val = cb()
            if not val:
                log.warning('QoS callback returned nothing {}'.format(cb))
            elif not isinstance(val, dict):
                log.error('QoS callback {} returned {} instead of a dict'.format(
                    cb, type(val).__name__))
            else:
                add_stats(val)
        except Exception as e:
            try:
                log.error('Unhandled Exception: {}'.format(e), exc_info=True)
            except Exception:
                # Rough native libraries like NVML may not support hashing
                # on their native exception wrappers.
                log.error('Unable to log qos exception.', exc_info=True)


def _format_utc(ts):
    """Format POSIX timestamp *ts* as ``YYYY-MM-DDTHH:MM:SSZ`` (UTC, whole seconds)."""
    # Replaces datetime.utcfromtimestamp (deprecated since Python 3.12) with the
    # same naive-UTC output.
    moment = datetime.datetime.fromtimestamp(ts, datetime.timezone.utc)
    return moment.replace(microsecond=0, tzinfo=None).isoformat() + "Z"


def _send_stats():
    """Timer callback: build one QoS snapshot, queue it and schedule the next.

    Refreshes element stats, runs the due callbacks, then stamps the snapshot
    with the timestamp and version fields. The timestamp comes from ``_t``,
    which advances by ``STATS_INTERVAL`` per tick. The next run is scheduled
    for the next tick but never sooner than half an interval. If ``stop()``
    ran while this tick was in flight, the snapshot is discarded and nothing
    is rescheduled.
    """
    global _t, _interval_timer
    try:
        _update_elements()
        _get_queue_stats()
        _get_videorate_stats()
    except Exception as e:
        log.error('Unhandled Exception: {}'.format(e), exc_info=True)

    _run_callbacks(int(_t))

    ts = _t
    _t += STATS_INTERVAL
    delay = max(STATS_INTERVAL * 0.5, _t - time.time())
    stamp = _format_utc(ts)
    _stats['client_time_dttm'] = stamp
    _stats['event_dttm'] = stamp
    _stats['version_tx'] = __version__

    if not _running:
        # stop() was called while this tick was running; let the loop end here.
        return
    _stats_queue.append(dict(_stats))
    _interval_timer = Timer(delay, _send_stats)
    _interval_timer.start()


def start():
    """(Re)start periodic QoS reporting.

    Resets the stats to the static report fields and clears queued
    snapshots. It cancels any pending timer, marks the loop as running and
    schedules the first snapshot after ``STATS_INTERVAL`` seconds.
    """
    global _stats, _t, _interval_timer, _running, _stats_queue
    _stats = {"action_tx": "report",
              "category_tx": "qos",
              "label_tx": "mercy",
              "routing_key": "qos.report.mercy"}
    _stats_queue = []
    _t = time.time()
    if _interval_timer is not None:
        _interval_timer.cancel()
    # Must be True so _send_stats keeps rescheduling until stop() clears it.
    _running = True
    _interval_timer = Timer(STATS_INTERVAL, _send_stats)
    _interval_timer.start()


def mark_dirty(pipeline):
    """Request that *pipeline* be re-scanned for queues/videorates on the next tick."""
    global _dirty_pipeline
    _dirty_pipeline = pipeline


def stop():
    """Stop periodic reporting and discard queued snapshots.

    ``_running`` is cleared first so that a tick already in progress neither
    queues a snapshot nor schedules another timer.
    """
    global _running, _stats_queue
    _running = False
    _stats_queue = []
    if _interval_timer is not None:
        _interval_timer.cancel()


def get_stats():
    """Drain the queued snapshots and return them in the order they were queued."""
    global _stats_queue
    drained, _stats_queue = _stats_queue, []
    return drained


def add_stats(stats):
    """Merge *stats* into the current stats; keys already present are overwritten."""
    # In-place update of the shared dict; no rebinding, so no ``global`` needed.
    _stats.update(stats)


def remove_stats(keys):
    """Delete each of *keys* from the current stats, ignoring keys that are absent."""
    for key in keys:
        _stats.pop(key, None)


def add_cb(cb, interval=1):
    """Register *cb* to be polled by the stats timer.

    The timer calls *cb* on ticks whose integer wall-clock second is a
    multiple of *interval*, i.e. roughly every *interval* seconds.

    Falsy callbacks are rejected with a warning. Non-positive intervals are
    also rejected with a warning: an interval of 0 would raise
    ZeroDivisionError inside the timer's modulo check and end reporting.

    :param cb: zero-argument callable returning a dict of stats.
    :param interval: polling period in seconds (must be > 0).
    """
    if not cb:
        log.warning('Cannot add null callback.', exc_info=True)
        return
    if interval <= 0:
        log.warning('Cannot add QoS callback {} with non-positive interval {}'.format(
            cb, interval))
        return
    _callbacks.append((interval, cb))


def remove_cb(cb):
    """Unregister every registration of *cb*, matched by identity.

    The module's callback list is replaced with a new list rather than
    mutated in place.
    """
    global _callbacks
    _callbacks = list(filter(lambda entry: entry[1] is not cb, _callbacks))



# test_i = 0


# def test_stats():
#     log.info("test_stats")
#     global test_i
#     test_i += 1
#     add_stats(dict(foo_current_buffer_time_nr=123, i=test_i))


# add_cb(interval=1, cb=test_stats)
