#include "ElementaryStreamAvc.hpp"
#include "debug/trace.hpp"

namespace twitch {
namespace media {
/*
 * Quoted from ISO/IEC 13818-1/2000/FDAM-3 (N5771):
 * "2.14 Carriage of ITU-T Rec. H.264 | ISO/IEC 14496-10 Video
 * The ITU-T Rec. H.264 | ISO/IEC 14496-10 coded data shall comply with the byte stream format defined
 *  in Annex B of ITU-T Rec. H.264 | ISO/IEC 14496-10, with the following constraints:
 * • Each AVC access unit shall contain an access_unit_delimiter NAL Unit.
 * • Each byte stream NAL Unit containing the access unit delimiter contains one zero_byte preceding
 *   the start_code_prefix_one_3bytes, as required by ITU-T Rec. H.264 | ISO/IEC 14496-10.
 * • Sequence and Picture Parameter Sets (SPS and PPS) shall be present within each AVC video stream
 *   carried in  Transport and Program streams (Note that ITU-T Rec. H.264 | ISO/IEC 14496-10 allows the
 *   delivery of SPS and PPS by external means).
 * • To provide display specific information such as aspect_ratio, it is strongly recommended that the
 *   each AVC video stream carries VUI messages with sufficient information to ensure that the decoded
 *   AVC video stream can be displayed correctly by receivers."
*/
// Constructs an AVC elementary stream handler.
// pid and ts are forwarded unchanged to the ElementaryStream base together
// with the TypeAVC stream type.
ElementaryStreamAvc::ElementaryStreamAvc(uint16_t pid, TransportStream* ts)
    : ElementaryStream(pid, TypeAVC, ts)
    // Invalid until finishFrame() emits a frame and records the expected next DTS.
    , m_outputDts(MediaTime::invalid())
    // No frame durations have been dropped yet.
    , m_droppedFrameDuration(0)
{
}

// Discards any buffered bytes, pending frames and extradata via reset().
// Nothing is emitted here; frames still pending at destruction are dropped.
ElementaryStreamAvc::~ElementaryStreamAvc()
{
    reset();
}

void ElementaryStreamAvc::reset()
{
    m_buffer.clear();
    m_frameA.reset();
    m_frameB.reset();
    m_frameBuf.clear();
    m_extradata.clear();
    m_droppedFrameDuration = MediaTime(0, TSTimeScale);
    m_outputDts = MediaTime::invalid();
}

// Begins a new pending frame (frame B) with the given timestamps.
//   dts            - decode timestamp, in TSTimeScale units
//   cts            - offset added to dts to obtain the presentation timestamp
//   data_alignment - non-zero when the buffered bytes are known to complete
//                    the previous frame, so they are flushed first
void ElementaryStreamAvc::startFrame(int64_t dts, int32_t cts, int data_alignment)
{
    if (data_alignment != 0) {
        flushFrame();
    }

    if (m_frameB != nullptr) {
        // A previous pending frame was never promoted; it is replaced below.
        TRACE_WARN("ElementaryStream_Avc() PES before AUD");
        //Do we need to do anything else here?
    }

    auto pending = std::make_shared<MediaSampleBuffer>();
    pending->decodeTime = MediaTime(dts, TSTimeScale);
    pending->presentationTime = MediaTime(dts + cts, TSTimeScale);
    // Placeholder duration; finishFrame(dts) overwrites it once the frame has become frame A.
    pending->duration = MediaTime(DefaultFrameDuration, TSTimeScale);
    m_frameB = std::move(pending);
}

// Sets the duration of frame A to the distance between its decode time and
// the supplied dts (in TSTimeScale units). Logs an error and does nothing
// when there is no frame A.
void ElementaryStreamAvc::finishFrame(int64_t dts)
{
    if (!m_frameA) {
        TRACE_ERROR("ElementaryStream_Avc::finishFrame called without active frame");
        return;
    }

    auto elapsed = MediaTime(dts, TSTimeScale) - m_frameA->decodeTime;

    // An oversized gap is only reported; the duration is still applied as-is.
    if (elapsed.scaleTo(TSTimeScale).count() > maxDtsDelta()) {
        TRACE_DEBUG("ElementaryStream_Avc::finishFrame() TS_MAX_DTS_DELTA");
    }

    m_frameA->duration = elapsed;
}

// Drains the stream: buffered bytes are handed over as a final NALU, then
// any frames still in flight are pushed through finishFrame().
void ElementaryStreamAvc::flush()
{
    flushFrame();

    // With a frame B pending, this call finishes frame A (if any) and promotes B into A.
    if (m_frameB != nullptr) {
        finishFrame();
    }

    // Finish whatever now sits in frame A (possibly the frame promoted above).
    if (m_frameA != nullptr) {
        finishFrame();
    }
}

// Buffer data until we have a complete nalu, or flushFrame() is called.
// It is assumed the data belongs to frame A.
//
// Every complete NALU found in the buffer (delimited by Annex B start codes)
// is passed to addNalu(); the trailing, still incomplete bytes stay in
// m_buffer for the next call.
//   data - pointer to the bytes to append
//   size - number of bytes at data
static const size_t NALU_MAX_SIZE = 4 * 1024 * 1024; // 4 MiB
void ElementaryStreamAvc::addData(const uint8_t* data, size_t size)
{
    size_t scpos = 0;
    size_t sclen = 0;
    // Resume the start code search where the previously buffered data ended.
    size_t scprev = m_buffer.size();
    m_buffer.insert(m_buffer.end(), data, data + size);

    if (NALU_MAX_SIZE < m_buffer.size()) {
        TRACE_WARN("ElementaryStream_Avc() nalu over %zu bytes. clearing buffer.", NALU_MAX_SIZE);
        m_buffer.clear();
        // The buffer is empty now, so the old resume offset would point past its end.
        scprev = 0;
    }

    for (;;) {
        scpos = AVCParser::findStartCodeIncremental(m_buffer.data(), m_buffer.size(), scprev, &sclen);

        if (std::numeric_limits<size_t>::max() == scpos) {
            // This is a hack to peek the next NALU looking for an AUD
            // This is required to fix a bug in the current optimize ts handling
            if (2 <= m_buffer.size() && AVCParser::NalTypeAUD == (m_buffer[0] & 0x1f)) {
                // Split off the first two bytes as the AUD without waiting for the next start code.
                scpos = 2;
                sclen = 0;
            } else {
                break;
            }
        }

        assert(scpos + sclen <= m_buffer.size());
        // Copy the tail of the buffer (everything after the start code) to a new buffer.
        // This is expected to be smaller than 184 bytes.
        std::vector<uint8_t> nalu(m_buffer.begin() + scpos + sclen, m_buffer.end());
        // After the swap m_buffer holds the tail and nalu holds the old contents;
        // trimming to scpos leaves exactly the bytes preceding the start code.
        m_buffer.swap(nalu);
        nalu.resize(scpos);
        addNalu(nalu); // addNalu takes a const reference, so std::move would have no effect here
        scprev = 0; // Start search from beginning of data
    }
}

// Routes one NAL unit (without start code) to the current frame.
// An access unit delimiter finishes frame A and promotes a pending frame B;
// every other NALU type is appended to frame A's buffer. Empty input is ignored.
void ElementaryStreamAvc::addNalu(const std::vector<uint8_t>& nalu)
{
    if (nalu.empty()) {
        return;
    }

    // nal_unit_type is the low five bits of the first byte.
    const int naluType = nalu[0] & 0x1F;
    const bool isAud = (AVCParser::NalTypeAUD == naluType);

    if (isAud && m_frameB) {
        finishFrame();
    }

    if (!m_frameA) {
        // No frame to attach the data to; the NALU is dropped.
        TRACE_WARN("ElementaryStream_Avc() Expected AUD (9) Received %d", naluType);
        return;
    }

    // The delimiter itself is not stored in the frame.
    if (!isAud) {
        m_frameBuf.addNalu(nalu);
    }
}

// call flushFrame() when we know the buffer completes the frame, but haven't seen the next AUD yet
// This happens when the data alignment indicator is set, or flush is called
//
// The entire content of m_buffer is treated as one final NALU and passed to
// addNalu() (which ignores it when empty); the buffer is then emptied.
void ElementaryStreamAvc::flushFrame()
{
    addNalu(m_buffer);
    m_buffer.clear();
}

// Completes frame A using the NALUs collected in m_frameBuf, then promotes
// frame B to frame A.
//
// Frame A is emitted only when it holds VCL data and either is a sync sample
// or extradata has already been captured. Frames that are dropped because
// they are empty or precede the first sync frame have their duration
// accumulated in m_droppedFrameDuration, which is folded into the next
// emitted frame. A frame without any VCL NALU only has its buffer cleared:
// frame A is kept and frame B is not promoted.
void ElementaryStreamAvc::finishFrame()
{
    if (m_frameA) {
        if (m_frameBuf.empty()) {
            TRACE_WARN("ElementaryStream_Avc Frame contains no nalus");
            m_droppedFrameDuration += m_frameA->duration;
        } else if (!m_frameBuf.isVideoCodingLayer()) {
            m_frameBuf.clear();
            TRACE_WARN("Frame contains no VCL NALUs");
            return; // return so that m_frameB is not promoted
        } else if (!m_frameBuf.isSyncSample() && m_extradata.empty()) {
            // No extradata captured yet and this frame is not a sync sample: drop it.
            TRACE_WARN("ElementaryStream_Avc Waiting for sync frame");
            m_droppedFrameDuration += m_frameA->duration;
        } else {
            if (MediaTime::zero() < m_droppedFrameDuration) {
                // Stretch this frame backwards so it also covers the dropped frames.
                m_frameA->decodeTime -= m_droppedFrameDuration;
                m_frameA->duration += m_droppedFrameDuration;
                TRACE_WARN("ElementaryStream_Avc Adjusting timestamps due to errant frames %lldms dts: %lld",
                    static_cast<long long>(m_droppedFrameDuration.milliseconds().count()),
                    static_cast<long long>(m_frameA->decodeTime.milliseconds().count()));
                m_droppedFrameDuration = MediaTime::zero();
            }

            // m_outputDts is where the previously emitted frame ended; a mismatch is only reported.
            if (m_outputDts.valid() && m_outputDts != m_frameA->decodeTime) {
                TRACE_WARN("ElementaryStream_Avc m_outputDts(%lld) != m_frameA->dts(%lld)",
                    static_cast<long long>(m_outputDts.milliseconds().count()),
                    static_cast<long long>(m_frameA->decodeTime.milliseconds().count()));
            }

            if (m_extradata.empty()) {
                // First emitted frame: capture the extradata from it and
                // present it at its decode time.
                m_frameA->presentationTime = m_frameA->decodeTime;
                m_extradata = AVCParser::getExtradataFromFrame(m_frameBuf);
            }
            m_frameA->isSyncSample = m_frameBuf.isIDRSample();
            // Hand the collected NALUs to the sample; m_frameBuf receives the
            // sample's previous buffer and is cleared below.
            m_frameA->buffer.swap(m_frameBuf);
            m_outputDts = m_frameA->decodeTime + m_frameA->duration;
            emitFrame(m_frameA);
        }

        m_frameA.reset();
    }

    // Promote frame B. m_frameA is guaranteed to be empty
    m_frameA.swap(m_frameB);
    m_frameBuf.clear();
}
}
}
