#include "ElementaryStream.hpp"
#include "TransportStream.hpp"
#include "debug/trace.hpp"
#include <algorithm>
#include <cassert>
#include <cstring>
#include <limits>

namespace twitch {
namespace media {
// Creates an elementary stream for transport-stream PID `pid` with stream type `type`.
// `ts` is stored as-is and used by emitFrame() to hand completed frames back to the TransportStream.
//
// m_prevDts / m_firstDts start as TSTimestampUndefined so that the first startPes() call establishes
// the stream's timeline.
//
// The continuity counter starts at -1: checkContinuityCounter() treats a negative stored counter as
// "no previous packet" and accepts whatever counter the first packet carries. Seeding it with 0 instead
// made the first packet compare against an expected counter of 1, causing a spurious
// discontinuity (log + flush() + reset()) whenever the first packet's counter was anything else.
// NOTE(review): assumes m_continuityCounter is a signed type (e.g. int8_t, matching the parameter of
// checkContinuityCounter()) — confirm in ElementaryStream.hpp.
ElementaryStream::ElementaryStream(uint16_t pid, uint8_t type, TransportStream* ts)
    : m_type(type)
    , m_pid(pid)
    , m_remaining(0)
    , m_ts(ts)
    , m_prevDts(TSTimestampUndefined)
    , m_firstDts(TSTimestampUndefined)
    , m_currentDts(0)
    , m_continuityCounter(-1)
{
}

// Validates the 4-bit continuity counter of the packet just received for this PID.
//
// The expected value is the previously stored counter plus one, modulo 16. If the stored counter is
// negative, there is no previous value to compare against, and the incoming counter is accepted as-is.
// The incoming counter is always recorded before any discontinuity handling runs.
//
// Returns true when the counter is the expected one. On a mismatch, the stream is flushed and reset
// (treated as a discontinuity) and false is returned.
bool ElementaryStream::checkContinuityCounter(int8_t continuityCounter)
{
    int expected = continuityCounter;
    if (m_continuityCounter >= 0) {
        expected = (m_continuityCounter + 1) & 0x0F;
    }

    // Record the latest counter first; reset() below may adjust it further.
    m_continuityCounter = continuityCounter;

    if (continuityCounter == expected) {
        return true;
    }

    TRACE_INFO("PID %d continuity counter mismatch %d != %d", m_pid, continuityCounter, expected);
    // treat as a discontinuity
    flush();
    reset();
    return false;
}

// Hands a completed frame from this elementary stream to the TransportStream this stream was
// constructed with, identifying itself as the source.
void ElementaryStream::emitFrame(const std::shared_ptr<MediaSampleBuffer>& frame)
{
    auto& transport = *m_ts;
    transport.emitFrame(*this, frame);
}

// Calculates different between transport stream timestamps accounting for 33bit rollover
// Calculates the difference (timestamp - previous) between two transport stream timestamps,
// accounting for 33-bit rollover.
//
// PES timestamps wrap modulo (TSTimestampMax + 1). A raw difference whose magnitude exceeds half
// that range is interpreted as a step across the wrap point, and is reduced to the nearby equivalent:
//  - timestamp wrapped to a small value while previous was near the top -> small positive delta;
//  - previous had already wrapped while timestamp had not (late/reordered) -> small negative delta.
// Previously the second case produced a value near 2 * TSTimestampMax, and the first case was off
// by one tick.
//
// NOTE(review): assumes TSTimestampMax is the largest 33-bit value (2^33 - 1), so the wrap modulus is
// TSTimestampMax + 1 — confirm against TransportStream.hpp.
int64_t ElementaryStream::timestampDelta(int64_t previous, int64_t timestamp)
{
    assert(TSTimestampMax >= previous);
    assert(TSTimestampMax >= timestamp);

    const int64_t modulus = TSTimestampMax + 1;
    const int64_t halfRange = TSTimestampMax >> 1;

    int64_t delta = timestamp - previous;
    if (delta < -halfRange) {
        // timestamp rolled over after previous: step forward through the wrap point
        delta += modulus;
    } else if (delta > halfRange) {
        // previous rolled over before timestamp: step backward through the wrap point
        delta -= modulus;
    } else {
        return delta;
    }

    // %lld expects long long; int64_t may be `long`, so cast explicitly for a portable format.
    TRACE_DEBUG("PID %d Timestamp rollover %lld => %lld, %lld", m_pid,
        static_cast<long long>(previous), static_cast<long long>(timestamp), static_cast<long long>(delta));
    return delta;
}

// Begins a new PES packet (and with it a new frame) on this elementary stream.
//
// pts/dts are the timestamps from the PES header (values that may roll over; see timestampDelta()).
// A value <= TSTimestampUndefined is treated as "not present". data_alignment is passed through to
// startFrame() unchanged.
//
// Steps:
//  1. Missing DTS: both PTS and DTS fall back to the previous DTS. Missing PTS: PTS falls back to DTS.
//  2. The composition offset cts = PTS - DTS (rollover-aware) is forced to 0 if negative or above maxCts().
//  3. If no previous DTS exists, this frame establishes m_firstDts and m_currentDts.
//     Otherwise, the rollover-aware DTS delta is added to m_currentDts, and the previous frame is
//     closed with finishFrame(m_currentDts - m_firstDts). Negative deltas add nothing; deltas above
//     maxDtsDelta() are replaced by DefaultFrameDuration.
//  4. If an out-of-range delta coincides with dts == TSTimestampMax, the function returns early. It
//     then neither finishes the previous frame nor starts a new one, and leaves all member state untouched.
//  5. m_prevDts is updated and startFrame() is called for the new frame.
void ElementaryStream::startPes(int64_t pts, int64_t dts, int data_alignment)
{
    int64_t cts = 0;

    if (TSTimestampUndefined >= dts) {
        // NOTE(review): if m_prevDts is also undefined here (no frame seen yet), pts/dts stay undefined
        // and below become m_firstDts/m_currentDts — confirm callers never hit this before a valid DTS.
        TRACE_WARN("Undefined DTS, Using previous: %lld", m_prevDts);
        pts = dts = m_prevDts;
    } else if (TSTimestampUndefined >= pts) {
        TRACE_WARN("Invalid PTS, Using DTS: %lld", dts);
        pts = dts;
    }

    // Composition offset; out-of-range values (including negative ones) are discarded as 0.
    cts = timestampDelta(dts, pts);
    if (0 > cts || maxCts() < cts) {
        TRACE_DEBUG("Invalid CTS: %lld - %lld = %lld", pts, dts, cts);
        cts = 0;
    }

    if (TSTimestampUndefined >= m_prevDts) {
        // First frame on this stream: it defines the origin of the timeline.
        m_firstDts = dts;
        m_currentDts = dts;
    } else {
        int64_t delta = timestampDelta(m_prevDts, dts);

        if (minDtsDelta() > delta) {
            // for small DTS, we will just nudge the timestamp when we calculate m_mediaDuration
            TRACE_WARN("Excessively small DTS delta. pid: %d type: 0x%02x", m_pid, m_type);
            TRACE_WARN("%lld - %lld = %lld", dts, m_prevDts, delta);
            if (TSTimestampMax == dts) {
                TRACE_WARN("Excessively DTS delta likely result of mute bug. Ignoring frame.");
                return;
            }
        } else if (maxDtsDelta() < delta) {
            TRACE_WARN("Excessively large DTS delta. pid: %d type: 0x%02x", m_pid, m_type);
            TRACE_WARN("%lld - %lld = %lld => %lld", dts, m_prevDts, delta, DefaultFrameDuration);
            delta = DefaultFrameDuration; // we want the same DEFAULT_FRAME_DURATION on all es
            if (TSTimestampMax == dts) {
                TRACE_WARN("Excessively DTS delta likely result of mute bug. Ignoring frame.");
                return;
            }
        }

        // m_currentDts only moves forward; negative deltas contribute nothing.
        m_currentDts += std::max(static_cast<int64_t>(0), delta);
        finishFrame(m_currentDts - m_firstDts);
    }

    m_prevDts = dts;
    // NOTE(review): startFrame() receives the raw `dts - m_firstDts`, while finishFrame() above receives
    // the accumulated `m_currentDts - m_firstDts`. The two diverge after a clamped large delta, and
    // after a 33-bit DTS rollover (the raw value then goes negative). Confirm which timeline
    // startFrame() is meant to use.
    startFrame(dts - m_firstDts, static_cast<int32_t>(cts), data_alignment);
}
}
}
