#include "TransportStream.hpp"
#include "ElementaryStreamAac.hpp"
#include "ElementaryStreamAvc.hpp"
#include "ElementaryStreamId3.hpp"
#include "debug/trace.hpp"
#include <algorithm>
#include <cassert>
#include <cstring>
#include <limits>

namespace twitch {
namespace media {
// Every MPEG-TS packet handled here is a fixed PacketSize bytes long and must begin with PacketSyncByte.
static constexpr uint8_t PacketSize = 188;
static constexpr uint8_t PacketSyncByte = 0x47;

/// Decodes a 33-bit PES timestamp (PTS or DTS) from its 5-byte encoded form.
///
/// Byte layout (T = timestamp bit, m = marker bit, x = ignored):
///   xxxxTTTm  TTTTTTTT  TTTTTTTm  TTTTTTTT  TTTTTTTm
/// i.e. bits 32..30, then bits 29..15, then bits 14..0, each group followed by a marker.
///
/// @param data  Points at the first of the 5 timestamp bytes; all 5 must be readable.
/// @return      The timestamp value in the low 33 bits.
static uint64_t parsePTS(const uint8_t* data)
{
    const uint64_t top3 = (data[0] >> 1) & 0x07; // bits 32..30
    const uint64_t mid15 = (static_cast<uint64_t>(data[1]) << 7) | (data[2] >> 1); // bits 29..15
    const uint64_t low15 = (static_cast<uint64_t>(data[3]) << 7) | (data[4] >> 1); // bits 14..0
    return (top3 << 30) | (mid15 << 15) | low15;
}

TransportStream::TransportStream(Listener& listener)
    : m_listener(listener)
    , m_pcr(0)
    , m_pcrext(0)
    , m_pmtpid(0)
    , m_pcrpid(0)
    , m_metapid(0)
    , m_videopid(0)
    , m_audiopid(0)
    , m_timeOffset(0)
    , m_mediaDuration(0)
{
}

/// Releases every elementary stream and clears all per-stream parser state.
TransportStream::~TransportStream()
{
    // Qualified call: during destruction only this class's reset() can run anyway,
    // even if reset() were virtual, so spell that out.
    TransportStream::reset();
}

/// Ends the current segment: flushes every elementary stream, carries the accumulated
/// media duration forward as the timestamp offset for the next segment, and resets
/// all parser state.
void TransportStream::finish()
{
    // Leftover bytes are the start of a packet that never completed; reset() below
    // discards them.
    const bool hasPartialPacket = !m_buffer.empty();
    if (hasPartialPacket) {
        TRACE_WARN("TransportStream::finish() called with partial ts packet buffered");
    }

    // Flush before capturing the offset: a flush may emit pending frames through
    // emitFrame(), which advances m_mediaDuration (presumably — confirm in
    // ElementaryStream::flush()).
    for (auto it = m_map.begin(); it != m_map.end(); ++it) {
        it->second->flush();
    }

    // Save the base media decode time so the next segment's timestamps (shifted by
    // m_timeOffset in emitFrame) continue seamlessly where this one ended.
    m_timeOffset = m_mediaDuration;
    reset();
}

/// Drops all per-stream state: elementary streams, the partial-packet buffer, the last
/// PCR, and every PID learned from the PAT/PMT. These are rediscovered from the next PAT.
/// The timeline (m_timeOffset, m_mediaDuration) is deliberately left untouched;
/// finish() and seek() manage it.
void TransportStream::reset()
{
    // Let each elementary stream clear its own state before the map releases it.
    for (const auto& entry : m_map) {
        entry.second->reset();
    }
    m_map.clear();

    // Any partial packet belongs to the stream being abandoned.
    m_buffer.clear();

    // Program clock reference.
    m_pcr = 0;
    m_pcrext = 0;

    // PIDs from the PAT/PMT; 0 means "not known yet".
    m_pmtpid = 0;
    m_pcrpid = 0;
    m_metapid = 0;
    m_videopid = 0;
    m_audiopid = 0;
}

/// Discards all parser state and restarts the output timeline at `timestamp`.
/// The value is in TSTimeScale ticks: emitFrame() wraps m_timeOffset in
/// MediaTime(m_timeOffset, TSTimeScale).
void TransportStream::seek(int64_t timestamp)
{
    TransportStream::reset();

    // Both the offset applied to emitted frames and the running duration start at the
    // seek target, just as finish() leaves them for the following segment.
    m_mediaDuration = timestamp;
    m_timeOffset = timestamp;
}

/// Appends `size` bytes of transport stream data and parses every complete packet that
/// is now available. Bytes that do not yet form a full packet stay buffered for the
/// next call.
///
/// @param data  Incoming bytes; may be null when `size` is 0.
/// @param size  Number of bytes at `data`.
void TransportStream::addData(const uint8_t* data, size_t size)
{
    m_buffer.insert(m_buffer.end(), data, data + size);

    size_t offset = 0;
    while (offset + PacketSize <= m_buffer.size()) {
        // parsePacket() returns false when the bytes at `offset` are not a usable packet
        // (e.g. no sync byte). Advance a single byte in that case to resynchronise.
        offset += parsePacket(m_buffer.data() + offset) ? PacketSize : 1;
    }

    // Keep only the unconsumed tail. erase() shifts it to the front in place and keeps
    // the buffer's capacity. The old code built a fresh vector on every call, so the
    // next insert() had to reallocate again.
    m_buffer.erase(m_buffer.begin(), m_buffer.begin() + offset);
}

void TransportStream::emitFrame(const ElementaryStream& stream, const std::shared_ptr<MediaSampleBuffer>& frame)
{
    if (m_timeOffset) {
        MediaTime offset(m_timeOffset, TSTimeScale);
        frame->decodeTime += offset;
        frame->presentationTime += offset;
    }
    // use audio duration for media duration if present
    if (stream.pid() == m_audiopid) {
        m_mediaDuration += frame->duration.scaleTo(TSTimeScale).count();
    }
    // use video if no audio present
    else if ((m_audiopid == 0 && stream.pid() == videoPid()) || (m_audiopid == 0 && m_videopid == 0)) {
        m_mediaDuration += frame->duration.scaleTo(TSTimeScale).count();
    }

    m_listener.onElementarySample(stream.type(), frame);
}

bool TransportStream::parsePacket(const uint8_t* data)
{
    assert(data);

    if (PacketSyncByte != data[0]) {
        return false;
    }

    // assert(data[ 0] == 0x47);
    bool pusi = !!(data[1] & 0x40); // Payload Unit Start Indicator
    int16_t pid = ((data[1] & 0x1F) << 8) | data[2]; // PID

    // assert((data[i + 3] & 0xc0) == 0);      // Scrambling control.

    bool adaption_present = !!(data[3] & 0x20); // Adaptation field exist
    bool payload_present = !!(data[3] & 0x10); // Contains payload
    uint8_t continuityCounter = (data[3] & 0x0F);

    size_t i = 4;

    if (adaption_present) {
        uint8_t adaption_length = data[i + 0]; // adaption field length

        if (pid == m_pcrpid && 7 <= adaption_length && (0x10 & data[i + 1])) {
            // 11111111 11111111 11111111 11111111 1XXXXXX1 11111111
            m_pcr = 0;
            m_pcrext = 0;
            m_pcr |= (int64_t)(data[i + 2] & 0xFF) << (33 - 8);
            m_pcr |= (int64_t)(data[i + 3] & 0xFF) << (33 - 16);
            m_pcr |= (int64_t)(data[i + 4] & 0xFF) << (33 - 24);
            m_pcr |= (int64_t)(data[i + 5] & 0xFF) << (33 - 32);
            m_pcr |= (int64_t)(data[i + 6] & 0x80) >> 7;
            m_pcrext |= (int16_t)(data[i + 6] & 0x01) << 8;
            m_pcrext |= (int16_t)(data[i + 7] & 0xFF) << 0;
        }

        i += 1 + adaption_length;
    }

    if (pid == 0) {
        if (payload_present) {
            // Skip the payload.
            if (i >= PacketSize) {
                return false;
            }
            i += data[i] + 1;
        }

        // PAT
        // assert(data[i + 0] == 0);    // table id must be 0
        // assert((data[i + 1] & 0x80) == 0x80);      // section syntax indicator
        // assert((data[i + 1] & 0x40) == 0);      // private
        // assert((data[i + 1] & 0x30) == 0x30);      // reserved

        //uint16_t section_length = ((data[i + 1] & 0x0F) << 8) | data[i + 2];
        // assert(section_length == 13);    // we only support 1 program!

        // assert((data[i + 4] & 0xC0) == 0);      // reserved
        // assert((data[i + 5] & 0x01) == 0x01);      // section is current

        if (i + 11 >= PacketSize) {
            return false;
        }
        m_pmtpid = ((data[i + 10] & 0x1F) << 8) | data[i + 11];
    } else if (pid == m_pmtpid) {
        // PMT
        if (payload_present) {
            // Skip the payload.
            if (i >= PacketSize) {
                return false;
            }
            i += data[i] + 1;
        }

        // assert(data[i + 0] == 2);    // table id must be 2
        // assert((data[i + 1] & 0x80) == 0x80);      // section syntax indicator
        // assert((data[i + 1] & 0x40) == 0);      // private
        // assert((data[i + 1] & 0x30) == 0x30);      // reserved

        if (i + 11 >= PacketSize) {
            return false;
        }
        uint16_t section_length = ((data[i + 1] & 0x0F) << 8) | data[i + 2];
        bool current = data[i + 5] & 0x01;

        // assert((data[i + 4] & 0xC0) == 0);      // reserved
        // assert((data[i + 5] & 0x01) == 0x01);      // section is current
        // assert((data[i + 8] & 0xE0) == 0xE0);      // reserved

        m_pcrpid = ((data[i + 8] & 0x1F) << 8) | data[i + 9];
        // assert((data[i + 10] & 0xF0) == 0xF0);      // reserved

        int16_t program_info_length = ((data[i + 10] & 0x0F) << 8) | data[i + 11];
        int16_t descriptor_loop_length = section_length - (9 + program_info_length + 4); // 4 for the crc

        i += 12 + program_info_length;

        if (current) {
            while (descriptor_loop_length >= 5) {
                if (i + 4 >= PacketSize) {
                    return false;
                }
                uint8_t stream_type = data[i];
                int16_t elementary_pid = ((data[i + 1] & 0x1F) << 8) | data[i + 2];
                int16_t esinfo_length = ((data[i + 3] & 0x0F) << 8) | data[i + 4];

                switch (stream_type) {
                case ElementaryStream::TypeID3:
                    if (m_metapid && m_metapid != elementary_pid) {
                        TRACE_WARN("TransportStream() metadata pid changed (%d -> %d)", m_metapid, elementary_pid);
                    }

                    if (!m_metapid) {
                        TRACE_DEBUG("TransportStream() Found ID3 stream at PID %d", elementary_pid);
                        m_map[elementary_pid].reset(new ElementaryStreamId3(elementary_pid, this));
                        m_metapid = elementary_pid;
                    }
                    break;

                case ElementaryStream::TypeAVC:
                    if (m_videopid && m_videopid != elementary_pid) {
                        TRACE_WARN("TransportStream() video pid changed (%d -> %d)", m_metapid, elementary_pid);
                    }

                    if (!m_videopid) {
                        TRACE_DEBUG("TransportStream() Found AVC stream at PID %d", elementary_pid);
                        m_map[elementary_pid].reset(new ElementaryStreamAvc(elementary_pid, this));
                        m_videopid = elementary_pid;
                    }
                    break;

                case ElementaryStream::TypeAAC:
                    if (m_audiopid && m_audiopid != elementary_pid) {
                        TRACE_WARN("TransportStream() audio pid changed (%d -> %d)", m_metapid, elementary_pid);
                    }

                    if (!m_audiopid) {
                        TRACE_DEBUG("TransportStream() Found AAC stream at PID %d", elementary_pid);
                        m_map[elementary_pid].reset(new ElementaryStreamAac(elementary_pid, this));
                        m_audiopid = elementary_pid;
                    }
                    break;

                default:
                    TRACE_DEBUG("Unknown stream type %d (pid: %d)\n", stream_type, elementary_pid);
                    break;
                }

                i += 5 + esinfo_length;
                descriptor_loop_length -= 5 + esinfo_length;
            }
        }
        assert(descriptor_loop_length == 0);
    } else {
        const auto es = m_map.find(pid);

        if (es != m_map.end()) {
            if (payload_present) {
                if (!es->second->checkContinuityCounter(continuityCounter)) {
                    m_listener.onElementaryDiscontinuity(es->second->type());
                }
                if (pusi) {
                    int64_t pts = TSTimestampUndefined;
                    int64_t dts = TSTimestampUndefined;
                    // assert(data[i + 0] == 0);
                    // assert(data[i + 1] == 0);
                    // assert(data[i + 2] == 1);

                    if (i + 8 >= PacketSize) {
                        return false;
                    }
                    size_t pktlen = (data[i + 4] << 8) | data[i + 5];
                    // assert((data[i + 6] & 0xC0) == 0x80);      // marker bits
                    // assert((data[i + 6] & 0x30) == 0);      // scrambling control

                    bool data_alignment = !!(data[i + 6] & 0x04);
                    uint8_t pts_dts_indicator = (data[i + 7] & 0xC0) >> 6;
                    uint8_t header_length = data[i + 8];
                    i += 9;

                    switch (pts_dts_indicator) {
                    case 3:
                        if (i + 4 + 5 >= PacketSize) {
                            return false;
                        }
                        pts = parsePTS(data + i);
                        dts = parsePTS(data + i + 5);
                        break;
                    case 2:
                        if (i + 4 >= PacketSize) {
                            return false;
                        }
                        pts = parsePTS(data + i);
                        dts = pts;
                        break;
                    default:
                        TRACE_DEBUG("Frame timestamps unavailable. pts_dts_indicator %d", pts_dts_indicator);
                        break;
                    }

                    i += header_length;
                    pktlen -= (0 < pktlen) ? 3 + header_length : 0;

                    if (i >= PacketSize) {
                        return false;
                    }
                    es->second->setRemaining(pktlen);
                    es->second->startPes(pts, dts, data_alignment);
                }

                // whatever is left is payload
                if (ElementaryStream::TypeAVC == es->second->type()) {
                    // AVC does not track size
                    es->second->addData(data + i, PacketSize - i);
                } else {
                    size_t size = std::min(PacketSize - i, es->second->remaining());
                    es->second->setRemaining(es->second->remaining() - size);
                    es->second->addData(data + i, size);
                }
            }
        } else {
            TRACE_WARN("Unknown PID %d", pid);
        }
    }

    return true;
}
}
}
