#include "ElementaryStreamAac.hpp"
#include "debug/trace.hpp"
#include <cmath>

namespace twitch {
namespace media {
/**
 * Creates an AAC elementary stream for the given PID.
 *
 * @param pid PID this elementary stream is carried on.
 * @param ts  Owning transport stream (passed through to the base class).
 */
ElementaryStreamAac::ElementaryStreamAac(uint16_t pid, TransportStream* ts)
    : ElementaryStream(pid, TypeAAC, ts)
    , m_bufferDts(TSTimestampUndefined)
    , m_sampleCount(0)
    , m_frequency(0)
    , m_element_instance_tag(-1) // -1: no element_instance_tag seen yet
{
    // Put the ADTS header state into the same initial state reset() uses, so that
    // extradata() / addSilentFrames() never read an uninitialised header (harmless
    // if the member is already value-initialised in the class declaration).
    aac_init_adts(&m_adts);
}

/**
 * Destroys the stream after returning all parser state to its initial values.
 * The qualified call makes explicit that this class's own reset() runs (during
 * destruction no derived override could be selected anyway).
 */
ElementaryStreamAac::~ElementaryStreamAac()
{
    ElementaryStreamAac::reset();
}

/**
 * Returns the stream to its freshly-constructed state: ADTS header re-initialised,
 * sample counter and sample rate cleared, timestamp undefined and any buffered
 * payload discarded.
 */
void ElementaryStreamAac::reset()
{
    aac_init_adts(&m_adts);
    m_frequency = 0;
    m_sampleCount = 0;
    // -1 is the "no tag seen yet" sentinel used by the constructor; flush() only
    // reports a tag mismatch when both tags are non-negative, so resetting to 0
    // would make any later non-zero tag look like a mismatch.
    m_element_instance_tag = -1;
    m_bufferDts = TSTimestampUndefined;
    m_buffer.clear();
}

/**
 * Called at the start of a new PES payload.
 *
 * @param dts            Decode timestamp of the PES packet; stored and used to time the frames parsed from it.
 * @param cts            Composition offset; unused for AAC.
 * @param data_alignment Non-zero when the packet is flagged as aligned; any previously
 *                       buffered data is flushed first in that case.
 */
void ElementaryStreamAac::startFrame(int64_t dts, int32_t cts, int data_alignment)
{
    static_cast<void>(cts);

    // An aligned packet starts fresh, so emit whatever the previous one left behind.
    if (data_alignment != 0 && !m_buffer.empty()) {
        flush();
    }

    // Still holding data here means the previous packet was never completed.
    if (!m_buffer.empty()) {
        TRACE_WARN("ElementaryStream_Aac::startFrame() called with non empty buffer");
    }

    m_bufferDts = dts;

    // Grow the buffer ahead of time (with 50% headroom) so addData() appends do not reallocate.
    const auto expectedBytes = remaining() + m_buffer.size();
    if (expectedBytes > m_buffer.capacity()) {
        m_buffer.reserve(static_cast<size_t>(1.5 * expectedBytes));
    }
}

/**
 * Emits `count` silent AAC frames at the current sample position to fill a gap in
 * the timeline. Each frame is built from the last parsed ADTS header's channel
 * configuration, frequency index and element_instance_tag.
 *
 * @param count Number of frames to insert; zero or negative inserts nothing.
 */
void ElementaryStreamAac::addSilentFrames(int count)
{
    if (count <= 0) {
        return;
    }

    TRACE_WARN("+++ Inserting %d aac frames", count);

    for (int framesLeft = count; framesLeft > 0; --framesLeft) {
        auto silence = std::make_shared<MediaSampleBuffer>();
        silence->decodeTime = MediaTime(m_sampleCount, m_frequency);
        silence->presentationTime = silence->decodeTime;
        silence->duration = MediaTime(AAC_SAMPLES_PER_FRAME, m_frequency);
        silence->isSyncSample = true;
        silence->buffer = aac_silent_frame(m_adts.channelConfig, m_adts.frequencyIndex, m_element_instance_tag);

        // The sample counter moves past the inserted frame before it is emitted.
        m_sampleCount += AAC_SAMPLES_PER_FRAME;
        emitFrame(silence);
    }
}

/**
 * Called when the current PES payload is complete; parses and emits everything
 * buffered so far.
 *
 * @param dts Unused; frames are timed from the DTS stored by startFrame().
 */
void ElementaryStreamAac::finishFrame(int64_t dts)
{
    static_cast<void>(dts);
    flush();
}

/**
 * Appends `size` bytes of PES payload to the internal buffer.
 *
 * Data is dropped (with a warning) while no valid timestamp has been set by
 * startFrame(). When remaining() reports zero outstanding bytes afterwards, the
 * buffer is flushed immediately.
 *
 * @param data Pointer to the payload bytes.
 * @param size Number of bytes to append.
 */
void ElementaryStreamAac::addData(const uint8_t* data, size_t size)
{
    const bool hasTimestamp = m_bufferDts > TSTimestampUndefined;
    if (!hasTimestamp) {
        TRACE_WARN("Audio data of unknown PTS %lld", m_bufferDts);
        return;
    }

    const uint8_t* const payloadEnd = data + size;
    m_buffer.insert(m_buffer.end(), data, payloadEnd);

    // Presumably remaining() counts the PES bytes still expected; once it hits
    // zero the packet is complete and can be parsed right away.
    if (remaining() == 0) {
        flush();
    }
}

void ElementaryStreamAac::flush()
{
    // make s signed so negative wrap arounds can be detected
    int32_t s = static_cast<int32_t>(m_buffer.size());
    uint8_t* d = m_buffer.data();

    while (AAC_DEFAULT_HEADER_SIZE <= s) {
        // test for ADTS sync word
        aac_parse_adts(&m_adts, d, s);

        if (AAC_SYNC_WORD != m_adts.syncWord) {
            TRACE_WARN("ElementaryStream_Aac() invalid sync word");
            d += 1;
            s -= 1; // move forward by one byte in an attempt to resync
        } else if (0 > aac_sanity(&m_adts)) {
            TRACE_WARN("ElementaryStream_Aac() aac fails basic sanity checks");
            TRACE_WARN("ADTS error: %d %d %d %d %d %d", aac_sanity(&m_adts), aac_header_size(&m_adts), m_adts.frameSize, aac_frequency(&m_adts), aac_channels(&m_adts), m_adts.frameCount);
            d += m_adts.frameSize;
            s -= m_adts.frameSize;
        } else if (s < static_cast<int32_t>(m_adts.frameSize)) { // Check if payload is at least as large as this frame
            TRACE_WARN("ElementaryStream_Aac() payload size (%d) less than frame_size (%u)", s, m_adts.frameSize);
            s = 0;
        } else {
            unsigned headerSize = aac_header_size(&m_adts);

            // Validate element_instance_tag
            auto element_instance_tag = aac_element_instance_tag(d + headerSize);

            if (0 <= m_element_instance_tag && 0 <= element_instance_tag && m_element_instance_tag != element_instance_tag) {
                TRACE_DEBUG("Mismatch element_instance_tag");
            }

            if (0 < m_frequency && m_frequency != aac_frequency(&m_adts)) {
                m_sampleCount = 0;
                TRACE_WARN("Audio frequency change");
            }

            m_frequency = aac_frequency(&m_adts);
            m_element_instance_tag = element_instance_tag;

            // Retimestamp
            int64_t sampleCountHint = (m_bufferDts * m_frequency) / TSTimeScale;

            if (0 >= m_sampleCount) {
                m_sampleCount = sampleCountHint;
            }

            int64_t delta = sampleCountHint - m_sampleCount;
            static const int AAC_DRIFT_MAX = static_cast<int>(2.0 * AAC_SAMPLES_PER_FRAME);

            // The human brain is less sensitive to trailing audio than leading audio
            if (-delta > AAC_DRIFT_MAX) {
                // Drop audio frame if delta is too old
                TRACE_DEBUG("--- Dropping aac frame %lld < %lld : %lld (pts: %lld)", sampleCountHint, m_sampleCount, delta, m_bufferDts);
            } else {
                if (delta > AAC_DRIFT_MAX) {
                    addSilentFrames(static_cast<int>(delta / AAC_SAMPLES_PER_FRAME));
                }

                {
                    auto frame = std::make_shared<MediaSampleBuffer>();
                    frame->decodeTime = MediaTime(m_sampleCount, m_frequency);
                    frame->presentationTime = frame->decodeTime;
                    frame->duration = MediaTime(m_adts.sampleCount, m_frequency);
                    frame->isSyncSample = true;
                    frame->buffer.assign(d + headerSize, d + m_adts.frameSize);
                    m_sampleCount += m_adts.sampleCount;
                    emitFrame(frame);
                }
            }

            m_bufferDts += (AAC_SAMPLES_PER_FRAME * TSTimeScale) / m_frequency;
            d += m_adts.frameSize;
            s -= m_adts.frameSize;
        }
    }

    m_buffer.clear();

    if (0 < s) {
        TRACE_WARN("ElementaryStream_Aac() Skipped %u bytes", s);
    }
}

/**
 * Returns a two-byte codec configuration blob rendered from the most recently
 * parsed ADTS header by aac_render_extradata() (presumably the AAC
 * AudioSpecificConfig; confirm against aac_render_extradata).
 */
std::vector<uint8_t> ElementaryStreamAac::extradata() const
{
    std::vector<uint8_t> config(2, 0);
    aac_render_extradata(&m_adts, config.data(), config.size());
    return config;
}
}
}
