import traceback
import sys
import random
import time

import gi
gi.require_version('Gst', '1.0')

from gi.repository import Gst, GObject, GLib
GObject.threads_init()
Gst.init(None)

from sagemaker.pytorch import PyTorch

# Use the GPU pad-added handler and the nvdec-preferring autoplug-sort hook.
TEST_GPU: bool = True
# Number of Pipeline instances started by main().
TEST_COUNT: int = 1
# Seconds main() sleeps after starting each pipeline.
TEST_SLEEP_DELAY: float = 1
# When True, decoded video is fed to pngenc instead of the handoff fakesink.
TEST_ENCODE_PNG: bool = False
# GPU branch only: values >= 1 insert a capsfilter pinning this framerate;
# values < 1 use a pass-through identity instead.
TEST_FRAMERATE: int = 0


def link_many(*args):
    """Link every element in *args* to the element that follows it.

    Fewer than two elements is a no-op.

    Raises:
        TypeError: naming the first (upstream, downstream) pair that
            ``Gst.Element.link`` refuses to connect.
    """
    for upstream, downstream in zip(args, args[1:]):
        if not Gst.Element.link(upstream, downstream):
            raise TypeError('failed to link {}, {}'.format(upstream, downstream))


def make_queue(queue_name="", **kwargs):
    """Create a ``queue`` element, by default with every size limit disabled.

    Keyword arguments are applied as GStreamer properties. Underscores in
    their names become dashes, so ``max_size_buffers=2`` sets
    ``max-size-buffers``. Unless overridden, ``max-size-bytes``,
    ``max-size-buffers`` and ``max-size-time`` are set to 0 (0 disables that
    limit), and ``silent`` is set to True.

    Args:
        queue_name: Optional prefix. The element is named
            ``"<queue_name>_queue"``. When empty, GStreamer assigns a unique
            name, so several unnamed queues can live in one bin.

    Returns:
        The configured queue element.

    Raises:
        TypeError: if the ``queue`` element cannot be created.
    """
    props = {
        "max-size-bytes": 0,
        "max-size-buffers": 0,
        "max-size-time": 0,
        "silent": True,
    }
    # Caller values override the defaults. dict.update keeps the defaults'
    # positions, so properties are still set defaults-first, then extras in
    # the order given.
    props.update({k.replace("_", "-"): v for k, v in kwargs.items()})

    # Previously the prefix got a trailing "_" and was then formatted into
    # "%s_queue", producing e.g. "queue__queue". An empty prefix produced the
    # fixed name "_queue", which collides when used twice in one bin.
    name = "{}_queue".format(queue_name) if queue_name else None
    queue = Gst.ElementFactory.make("queue", name)
    if not queue:
        raise TypeError('Failed to create element: queue, {} {}'.format(name, props))
    for k, v in props.items():
        queue.set_property(k, v)
    return queue

def make(element_name, name=None, **kwargs):
    """Create a GStreamer element and apply keyword properties to it.

    Each keyword name has its underscores turned into dashes before being
    passed to ``set_property``. For example, ``signal_handoffs`` sets
    ``signal-handoffs``.

    Args:
        element_name: Factory name, e.g. ``"videorate"``.
        name: Element name, or None to let GStreamer choose one.

    Returns:
        The new element with every property applied.

    Raises:
        TypeError: if the factory cannot create the element.
    """
    properties = dict(
        (key.replace("_", "-"), value) for key, value in kwargs.items()
    )
    element = Gst.ElementFactory.make(element_name, name)
    if element:
        for prop_name, prop_value in properties.items():
            element.set_property(prop_name, prop_value)
        return element
    raise TypeError('Failed to create element: {}, {} {}'.format(element_name, name, properties))

def make_caps(caps_str):
    """Create a ``capsfilter`` constrained to the caps described by *caps_str*.

    Args:
        caps_str: Caps in GStreamer string syntax, e.g.
            ``"video/x-raw,framerate=1/1,format=BGR"``.

    Returns:
        The capsfilter element.

    Raises:
        ValueError: if *caps_str* cannot be parsed as caps.
        TypeError: if the capsfilter element cannot be created.
    """
    caps = Gst.Caps.from_string(caps_str)
    if caps is None:
        # from_string yields None on a parse error. A capsfilter whose caps
        # are NULL accepts ANY caps, so the intended constraint would vanish
        # silently.
        raise ValueError("invalid caps string: {!r}".format(caps_str))
    return make("capsfilter", caps=caps)



class Pipeline:
    """One uridecodebin-based decode pipeline that reports frame intervals.

    ``start()`` builds a ``Gst.Pipeline`` around a ``uridecodebin``. The
    branch behind each decoded video pad is built on the fly by
    ``on_decode_pad_added_gpu`` or ``on_decode_pad_added``, depending on
    ``TEST_GPU``.
    """

    # Fortnite VOD (Twitch HLS playlist), used when start() gets no URI.
    DEFAULT_URI = (
        "https://vod-secure.twitch.tv/"
        "83d1874d663a0a6c38a2_tfue_34991173360_1255618395/chunked/index-dvr.m3u8"
    )

    def __init__(self, i):
        # Identifier, kept as text for printing and for the pipeline name.
        self.id = str(i)
        # The Gst.Pipeline; created by start().
        self.pipeline = None
        # Clock time (ns) of the previous handoff; 0 until the first frame.
        self.lt = 0

    def fakesink_handoff(self, _, buf, pad):
        """Handle fakesink's ``handoff`` signal by printing the frame interval.

        Args:
            _: The emitting fakesink (unused).
            buf: The Gst.Buffer being handed off.
            pad: The fakesink's sink pad (unused).
        """
        now = self.pipeline.get_clock().get_time()
        # The first frame has no predecessor. Report 0 rather than the
        # absolute clock time.
        elapsed_ms = (now - self.lt) / 1e6 if self.lt else 0.0
        print("pipeline id:", self.id, "time:", elapsed_ms)
        self.lt = now
        # The frame bytes are not inspected, but the READ map is kept.
        # NOTE(review): presumably this measures the cost of making each
        # frame CPU-readable; confirm.
        mapped, mapinfo = buf.map(Gst.MapFlags.READ)
        if mapped:
            # mapinfo.data holds the frame bytes while mapped.
            buf.unmap(mapinfo)

    @staticmethod
    def _pad_caps_string(pad):
        """Return *pad*'s caps as a string, querying them if none are set."""
        caps = pad.get_current_caps()
        if caps is None:
            # get_current_caps() returns None when no caps are negotiated yet.
            caps = pad.query_caps(None)
        return caps.to_string()

    def _make_sink_chain(self, **fakesink_props):
        """Return the terminal element(s) of a video branch.

        In PNG mode this is ``pngenc`` followed by a plain ``fakesink``.
        Otherwise it is a ``fakesink`` emitting ``handoff`` (with any extra
        *fakesink_props*) wired to ``fakesink_handoff``.
        """
        if TEST_ENCODE_PNG:
            # pngenc has a src pad. Terminate it so encoded buffers are not
            # pushed into an unlinked pad.
            return [make("pngenc"), make("fakesink")]
        sink = make("fakesink", signal_handoffs=True, **fakesink_props)
        sink.connect("handoff", self.fakesink_handoff)
        return [sink]

    def _attach_branch(self, pad, elements):
        """Add *elements* to the pipeline, chain them, and feed them from *pad*.

        Raises:
            TypeError: if two consecutive elements cannot be linked.
            RuntimeError: if *pad* cannot be linked to the first element.
        """
        for element in elements:
            self.pipeline.add(element)
        link_many(*elements)
        # Bring the new elements to the pipeline's state, downstream first.
        for element in reversed(elements):
            element.sync_state_with_parent()
        # Link the decoder pad last, so buffers never reach an element that
        # is still in the NULL state.
        result = pad.link(elements[0].get_static_pad("sink"))
        if result != Gst.PadLinkReturn.OK:
            raise RuntimeError("failed to link decoder pad to {}: {}".format(
                elements[0].get_name(), result))

    def on_decode_pad_added_gpu(self, element, pad):
        """Handle uridecodebin's ``pad-added`` signal (GPU variant).

        A video pad is fed into: videorate -> identity -> (identity, or a
        CUDA-memory framerate capsfilter when TEST_FRAMERATE >= 1) ->
        identity -> identity -> sink chain. Other pads, such as audio, are
        left unlinked.
        """
        caps_string = self._pad_caps_string(pad)
        if "video" not in caps_string:
            return
        print(self.id, "on_decode_pad_added_gpu", caps_string)
        # The identity elements are pass-through placeholders.
        # NOTE(review): the names suggest GL convert/download stages and a
        # queue were meant here; confirm.
        videorate = make("videorate")
        glvideoconvert = make("identity")
        if TEST_FRAMERATE < 1:
            glcapsfilter = make("identity")
        else:
            glcapsfilter = make_caps(
                "video/x-raw(memory:CUDAMemory),framerate={}/1".format(TEST_FRAMERATE))
        gldownload = make("identity")
        # Alternative tried: make_queue("queue", max_size_buffers=2, leaky=2)
        fakesink_queue = make("identity")
        elements = [videorate, glvideoconvert, glcapsfilter,
                    gldownload, fakesink_queue]
        elements += self._make_sink_chain(sync=True, silent=False)
        self._attach_branch(pad, elements)

    def on_decode_pad_added(self, element, pad):
        """Handle uridecodebin's ``pad-added`` signal (CPU variant).

        A video pad is fed into: videorate -> videoconvert -> capsfilter
        (1 fps, BGR) -> queue -> sink chain. Other pads, such as audio, are
        left unlinked.
        """
        caps_string = self._pad_caps_string(pad)
        if "video" not in caps_string:
            return
        print(self.id, "on_decode_pad_added", caps_string)
        elements = [
            make("videorate"),
            make("videoconvert"),
            make_caps("video/x-raw,framerate=1/1,format=BGR"),
            # NOTE(review): leaky=2 (downstream) only drops buffers once a
            # size limit is hit, and make_queue disables all limits by
            # default. Confirm whether a max_size_* was intended.
            make_queue("queue", leaky=2),
        ]
        elements += self._make_sink_chain()
        self._attach_branch(pad, elements)

    def decodebin_ap_sort(self, bin, pad, caps, factories):
        """Handle uridecodebin's ``autoplug-sort`` signal, preferring nvdec.

        If a factory named ``nvdec`` sits after position 1, it is moved to
        position 1. The factories it jumps over keep their relative order,
        and the first entry keeps its place. This matches the earlier swap of
        entries 1 and 2 when nvdec was third, but no longer demotes nvdec when
        it is already second, and no longer fails on short lists.

        Returns:
            A reordered copy of *factories*.
        """
        ordered = list(factories)
        names = [f.get_name() for f in ordered]
        if "nvdec" in names:
            index = names.index("nvdec")
            if index > 1:
                ordered.insert(1, ordered.pop(index))
        return ordered

    def start(self, uri=None):
        """Build the pipeline around a uridecodebin and set it to PLAYING.

        Args:
            uri: Media URI to decode. Defaults to ``DEFAULT_URI``.

        Raises:
            RuntimeError: if the pipeline reports FAILURE for the state change.
        """
        self.pipeline = Gst.Pipeline.new("mirage_" + self.id)
        source = make("uridecodebin", uri=uri or self.DEFAULT_URI)

        if TEST_GPU:
            source.connect("autoplug-sort", self.decodebin_ap_sort)
            source.connect("pad-added", self.on_decode_pad_added_gpu)
        else:
            source.connect("pad-added", self.on_decode_pad_added)

        self.pipeline.add(source)

        if self.pipeline.set_state(Gst.State.PLAYING) == Gst.StateChangeReturn.FAILURE:
            raise RuntimeError("pipeline {} failed to start".format(self.id))
        print(self.id, "pipeline set to playing")




def main():
    """Start TEST_COUNT pipelines, sleeping TEST_SLEEP_DELAY s after each.

    Returns:
        The started Pipeline objects. They are returned so the caller keeps
        them referenced while the GLib main loop runs and can stop them
        afterwards.
    """
    print("Starting", TEST_COUNT, "pipelines. GPU:", TEST_GPU)

    pipelines = []
    for i in range(TEST_COUNT):
        p = Pipeline(i)
        p.start()
        pipelines.append(p)
        time.sleep(TEST_SLEEP_DELAY)

    print("all streams running")
    return pipelines


if __name__ == '__main__':
    pipelines = []
    try:
        pipelines = main() or []

        loop = GLib.MainLoop()
        loop.run()
    except KeyboardInterrupt:
        print("stopping")
    except Exception:
        traceback.print_exc(file=sys.stdout)
        print("sh*t blows")
        sys.exit(1)
    finally:
        # Runs on Ctrl-C and on sys.exit(): shut down started pipelines
        # so their streaming threads and resources are released.
        for p in pipelines:
            if p.pipeline is not None:
                p.pipeline.set_state(Gst.State.NULL)


