#include "Mp4Reader.hpp"
#include "../avc/avcutil.hpp"
#include "../id3/id3.hpp"
#include "media/MemoryStream.hpp"
#include "media/SourceFormat.hpp"
#include "playercore/SecureSampleBuffer.hpp"
#include <algorithm>

namespace twitch {
namespace media {
// Creates a reader that reports tracks, samples and errors to `listener`.
// `path` is kept and attached to every track format built by createTrackFormat().
// The parser and the reader share the log obtained from `platform`.
Mp4Reader::Mp4Reader(Platform& platform, MediaReader::Listener& listener, const std::string& path)
    : m_listener(listener)
    , m_path(path)
    , m_parser(platform.getLog())
    , m_platform(platform)
    , m_log(platform.getLog())
    , m_nalLengthSize(4) // default until createAVCFormat() overwrites it from the avcC extradata
{
    // start from a clean state: fresh parser, empty memory stream and a new caption parser
    reset();
}

// Parses the track headers of the current stream, publishes the track formats and
// applies the pending seek position. A missing or empty stream is reported as end of
// stream; a stream in which the parser finds no tracks is reported as an error.
void Mp4Reader::load()
{
    const bool hasData = m_stream && m_stream->length();
    if (!hasData) {
        // an empty buffer means end of stream
        m_listener.onMediaEndOfStream();
        return;
    }

    m_parser.setStream(m_stream.get());
    m_parser.readTracks();

    // verify the parser found something before any samples get queued
    if (m_parser.getTracks().empty()) {
        m_listener.onMediaError(Error(ErrorSource::File, MediaResult::ErrorInvalidData, "Failed loading mp4"));
        return;
    }

    m_haveHeader = true;

    // track formats are refreshed for plain mp4 files and for initialization fragments only
    if (!m_parser.isFragmented() || m_parser.isInitializationFragment()) {
        initializeTracks();
    }

    // a seek requested before the metadata was available is applied now
    seekTo(m_seekTime);
}

void Mp4Reader::initializeTracks()
{
    m_formats.clear();
    m_trackSampleCounts.clear();
    m_selectedTracks.clear();

    for (const auto& track : m_parser.getTracks()) {
        auto format = createTrackFormat(*track);
        m_trackSampleCounts[track->getId()] = 0;

        if (format) {
            TrackId id = getStableTrackId(*track);
            m_formats[id] = format;
            m_listener.onMediaTrack(id, format);
            m_selectedTracks.push_back(track);
        }
    }

    if (m_selectedTracks.empty()) {
        m_listener.onMediaError(Error(ErrorSource::File, MediaResult::ErrorNotSupported, "No tracks supported"));
    }
}

// Reads up to `duration` worth of samples from the selected tracks and forwards them via
// handleTrackData(). When samples were produced the listener is flushed; when none were
// produced and a non-fragmented file has ended, end of stream is signalled once.
void Mp4Reader::readSamples(MediaTime duration)
{
    if (!m_parser.canReadSamples(m_selectedTracks, duration)) {
        return;
    }

    using namespace std::placeholders;
    m_samplesRead = 0;
    auto onSample = std::bind(&Mp4Reader::handleTrackData, this, _1, _2);
    const MediaResult result = m_parser.readSamples(m_selectedTracks, onSample, duration);
    if (result != MediaResult::Ok) {
        m_log->warn("Error reading MP4");
        return;
    }

    if (m_samplesRead != 0) {
        m_listener.onMediaFlush();
        return;
    }

    // nothing was read: report the end of a regular (non-fragmented) file exactly once
    const bool reachedEnd = !m_streamEnded && !m_parser.isFragmented() && m_parser.isEnded();
    if (reachedEnd) {
        m_listener.onMediaEndOfStream();
        m_streamEnded = true;
    }
}

// Moves the read position to `time`. Before the header has been parsed the position is
// only remembered in m_seekTime; load() passes it back in once the header is available.
// Per-track sample counters and the end-of-stream flag are cleared in either case.
void Mp4Reader::seekTo(MediaTime time)
{
    if (m_haveHeader) {
        const MediaResult seekResult = m_parser.seekTo(time);
        const bool seekFailed = seekResult != MediaResult::Ok;
        if (seekFailed) {
            m_listener.onMediaError(Error(ErrorSource::Source, seekResult, "Error seeking MP4"));
        }
    }

    m_streamEnded = false;
    m_trackSampleCounts.clear();
    m_seekTime = time;
}

void Mp4Reader::reset()
{
    m_trackSampleCounts.clear();
    m_streamEnded = false;
    m_samplesRead = 0;
    m_parser = Mp4Parser(m_log);
    resetParserStream();
    m_ceaCaptions.reset(new CEACaptions([this](const std::shared_ptr<MediaSampleBuffer>& sample) {
        createMetadataTrack();
        m_listener.onMediaSample(MediaReader::MetaTrackId, sample);
    }));
}

// Replaces the current stream with an empty MemoryStream, points the parser at it and
// marks the header as not yet parsed.
void Mp4Reader::resetParserStream()
{
    m_stream.reset(new MemoryStream());
    m_parser.setStream(m_stream.get());
    m_haveHeader = false;
}

void Mp4Reader::setStream(std::unique_ptr<Stream> stream)
{
    m_stream = std::move(stream);
    m_parser.setStream(m_stream.get());
    if (m_stream && m_stream->length()) {
        load();
    }
}

// Returns the format registered for track `id`, or a null pointer when no such track is
// known.
//
// The lookup uses find() rather than operator[]: operator[] would insert a null entry for
// an unknown id, and createMetadataTrack() decides whether the metadata track exists by
// testing for the presence of its key, so a stray null entry would stop that track from
// ever being created and announced.
std::shared_ptr<const MediaFormat> Mp4Reader::getTrackFormat(TrackId id)
{
    const auto it = m_formats.find(id);
    if (it == m_formats.end()) {
        return nullptr;
    }
    return it->second;
}

// Total duration: the largest base decode time across the tracks plus the duration
// reported by the parser.
MediaTime Mp4Reader::getDuration() const
{
    const MediaTime baseTime = getBaseDecodeTime();
    const auto parsedDuration = m_parser.getDuration();
    return baseTime + parsedDuration;
}

// Appends `size` bytes from `data` to the end of the current stream and parses whatever
// has become readable.
//
// - Fragmented mp4: tracks are re-read while the parser can; an initialization fragment
//   re-initializes the track formats, a media fragment is emitted (emsgs + samples) once
//   the start of the following fragment is inside the buffered data, after which the
//   consumed bytes are removed from the stream.
// - Non-fragmented mp4: load() is called once the track headers can be read.
// - `endOfStream` on a fragmented stream emits the remaining samples of a media fragment
//   and swaps in a fresh, empty stream for the next fragment.
//
// NOTE(review): m_stream is dereferenced here without a null check, while load() and
// setStream() both treat a null stream as possible — confirm callers never push data
// after handing setStream() a null pointer.
void Mp4Reader::addData(const uint8_t* data, size_t size, bool endOfStream)
{
    if (size) {
        // always append at the end of the buffered data
        m_stream->seek(static_cast<size_t>(m_stream->length()));
        m_stream->write(data, size);

        if (m_stream->error()) {
            // the error is reported to the listener; parsing below still runs
            handleStreamError("Stream write failed");
        }
    }

    if (m_parser.isFragmented()) {
        while (m_parser.canReadTracks()) {
            m_parser.readTracks();

            if (m_parser.isInitializationFragment()) {
                initializeTracks();
                // don't need to incrementally parse init fragment
                break;
            }

            // a fragment is complete once the offset of the next one lies inside the buffered data
            auto nextFragmentOffset = static_cast<int64_t>(m_parser.getNextFragmentOffset());
            bool hasNextFragment = nextFragmentOffset > 0 && nextFragmentOffset < m_stream->length();

            if (hasNextFragment) {
                readEmsgs();
                readSamples(MediaTime::max());
                // remove the data we've read
                // NOTE(review): unchecked downcast — assumes m_stream is a MemoryStream, which
                // holds for streams created by resetParserStream(); verify for setStream() callers
                auto memoryStream = static_cast<MemoryStream*>(m_stream.get());
                memoryStream->remove(nextFragmentOffset);
            } else {
                // don't have enough data
                break;
            }
        }
    } else if (!m_haveHeader && m_stream->length() > 0) {
        if (m_parser.canReadTracks()) {
            // creates the media format from the parsed track data
            load();
        }
    }

    // if fragmented mp4 operate as a push source, read out all samples if the fragment is complete
    if (m_parser.isFragmented() && endOfStream) {
        if (!m_parser.isInitializationFragment()) {
            readEmsgs();
            readSamples(MediaTime::max());
        }
        // prepare for reading a new fragment
        resetParserStream();
    }
}

// Discontinuity notifications are accepted but intentionally ignored by this reader.
void Mp4Reader::onDiscontinuity(uint32_t /*flags*/)
{
}

void Mp4Reader::handleStreamError(const std::string& message)
{
    Error error(ErrorSource::Source, MediaResult(MediaResult::ErrorInvalidState, m_stream->error()), message);
    m_listener.onMediaError(error);
}

// Per-sample callback used while reading: applies codec specific post-processing (caption
// extraction and AnnexB conversion / IDR detection for AVC, box-to-text conversion for
// WebVTT), marks the first sample of each track as a discontinuity and forwards the sample
// to the listener under the track's stable id.
void Mp4Reader::handleTrackData(const Mp4Track& track, const std::shared_ptr<MediaSampleBuffer>& sample)
{
    const int trackId = track.getId();
    const auto codec = track.getCodecBoxType();

    if (codec == MP4_avc1 || codec == MP4_encv) {
        if (m_ceaCaptions) {
            m_ceaCaptions->fromMediaSampleBuffer(sample);
        }

        // TODO Keep frame as is and convert to annexB in platform when necessary
        const bool wantsAnnexB = m_platform.getCapabilities().avcFormat == AVCFormatType::AnnexB;
        if (wantsAnnexB) {
            avcConvertToAnnexB(*m_formats[VideoTrackId], *sample);
        } else {
            sample->isSyncSample = avcContainsIDRSlice(sample->buffer);
        }

        const bool firstOfTrack = m_trackSampleCounts[trackId] == 0;
        if (firstOfTrack && !sample->isSyncSample) {
            m_log->warn("Fragment started on non-IDR frame");
        }
    } else if (codec == MP4_wvtt) {
        // replace the boxed payload with the plain cue text
        std::string text;
        createVTTSample(sample->buffer, text);
        const uint8_t* bytes = reinterpret_cast<const uint8_t*>(text.c_str());
        sample->buffer.assign(bytes, bytes + text.length());
        sample->type = MP4_wvtt;
    }

    // the first sample delivered for a track starts a new continuous run
    sample->isDiscontinuity = m_trackSampleCounts[trackId] == 0;
    m_trackSampleCounts[trackId]++;
    m_samplesRead++;

    m_listener.onMediaSample(getStableTrackId(track), sample);
}

// Reads a big-endian 16 bit value at `data[offset]` and advances `offset` by two.
// The caller must guarantee that two bytes are available.
uint16_t Mp4Reader::readUint16(const uint8_t* data, size_t& offset)
{
    const uint8_t high = data[offset];
    const uint8_t low = data[offset + 1];
    offset += 2;
    return static_cast<uint16_t>((high << 8) | low);
}

// Reads a big-endian 32 bit value at `data[offset]` and advances `offset` by four.
// The caller must guarantee that four bytes are available.
//
// Each byte is widened to uint32_t before shifting. Shifting the promoted (signed) int
// left by 24 would move a set top bit into the sign bit, which is not a portable
// operation on signed integers.
uint32_t Mp4Reader::readUint32(const uint8_t* data, size_t& offset)
{
    const uint8_t* bytes = data + offset;
    const uint32_t value = (static_cast<uint32_t>(bytes[0]) << 24)
        | (static_cast<uint32_t>(bytes[1]) << 16)
        | (static_cast<uint32_t>(bytes[2]) << 8)
        | static_cast<uint32_t>(bytes[3]);
    offset += 4;
    return value;
}

// Creates the media format matching the track's codec box, or returns null (after logging
// a warning) when the codec is not handled. A created format is tagged with the reader's
// path and, for protected tracks, with the PSSH bytes found by the parser.
std::shared_ptr<MediaFormat> Mp4Reader::createTrackFormat(const Mp4Track& track)
{
    std::shared_ptr<MediaFormat> format;
    const auto codec = track.getCodecBoxType();

    switch (codec) {
    case MP4_avc1:
    case MP4_encv:
        format = createAVCFormat(track);
        break;

    case MP4_vp09:
        format = createVP9Format(track);
        break;

    case MP4_mp4a:
    case MP4_enca:
        format = createAACFormat(track);
        break;

    case MP4_wvtt:
        format = std::make_shared<SourceFormat>(MediaType::Text_VTT);
        break;

    default: {
        // log the four bytes of the unhandled box type exactly as they are laid out in memory
        const uint32_t boxType = codec;
        const std::string name(reinterpret_cast<const char*>(&boxType), sizeof(boxType));
        m_log->warn("no format for track: (%s)", name.c_str());
    } break;
    }

    if (!format) {
        return format;
    }

    format->setPath(m_path);
    if (!m_parser.getPsshBytes().empty() && track.isProtected()) {
        format->setProtectionData(m_parser.getPsshBytes());
    }
    return format;
}

// Returns true when any NAL unit in the length-prefixed `buffer` is an IDR slice.
bool Mp4Reader::avcContainsIDRSlice(const std::vector<uint8_t>& buffer)
{
    bool foundIdr = false;
    for (auto nalu : NalIterator(buffer, m_nalLengthSize)) {
        if (nalu.type == AVCParser::NalTypeIDR) {
            foundIdr = true;
            break;
        }
    }
    return foundIdr;
}

// Returns the largest base media decode time among the parser's tracks, each expressed in
// its own track timescale; a default constructed MediaTime when there are no tracks.
MediaTime Mp4Reader::getBaseDecodeTime() const
{
    MediaTime latest;
    for (const auto& track : m_parser.getTracks()) {
        const MediaTime trackBase(track->getBaseMediaDecodeTime(), track->getTimescale());
        if (latest < trackBase) {
            latest = trackBase;
        }
    }
    return latest;
}

void Mp4Reader::readEmsgs()
{
    // twitch specific metadata
    if (!m_parser.getEmsgs().empty()) {
        MediaTime baseMediaTime = getBaseDecodeTime();
        for (const auto& emsg : m_parser.getEmsgs()) {
            if (emsg.scheme_id_uri == "urn:twitch:id3") {
                // content is ID3 binary like TS
                MediaTime timestamp = baseMediaTime + MediaTime(emsg.presentation_time_delta, emsg.timescale);
                auto frames = Id3::parseFrames(emsg.data, timestamp);
                createMetadataTrack();
                for (const auto& frame : frames) {
                    m_listener.onMediaSample(MediaReader::MetaTrackId, frame);
                }
            }
        }
    }
}

// Registers and announces the JSON metadata track the first time it is needed; later
// calls are no-ops while the track's entry is present in the format table.
void Mp4Reader::createMetadataTrack()
{
    const bool alreadyRegistered = m_formats.count(MediaReader::MetaTrackId) != 0;
    if (alreadyRegistered) {
        return;
    }

    auto metadataFormat = std::make_shared<SourceFormat>(MediaType::Text_Json);
    m_formats[MediaReader::MetaTrackId] = metadataFormat;
    m_listener.onMediaTrack(MediaReader::MetaTrackId, metadataFormat);
}

// Converts a WebVTT sample (a sequence of vttc/vtte/vttx boxes) into text by appending the
// payload of every iden/sttg/payl child box to `text`, in the order the boxes appear.
//
// `data` is the raw sample; `text` receives the concatenated payloads (it is appended to,
// not cleared).
//
// The payload buffer is created with its final size: reserve() alone leaves the vector
// empty, so reading into data() would write outside the vector's valid range. Child boxes
// whose declared size is smaller than a box header, or whose payload could not fit inside
// the sample, are skipped instead of producing a wrapped or oversized length.
void Mp4Reader::createVTTSample(const std::vector<uint8_t>& data, std::string& text)
{
    MemoryStream stream;
    stream.write(data.data(), data.size());
    Mp4Parser parser(m_log);
    parser.setStream(&stream);

    const size_t sampleSize = data.size();

    // convert the box into VTT sample lines e.g.
    // 1
    // 00:01.000 --> 00:04.000
    // - Caption
    parser.readBoxes(0, data.size(), [&text, &parser, &stream, sampleSize](mp4box& box) {
        switch (box.type) {
        case MP4_vttc:
        case MP4_vtte:
        case MP4_vttx:
            parser.readBoxes(box, [&](mp4box& inner) {
                switch (inner.type) {
                case MP4_iden:
                case MP4_sttg:
                case MP4_payl: {
                    const size_t boxHeaderSize = 8; // 32 bit size + 32 bit type
                    const size_t boxSize = static_cast<size_t>(inner.size);
                    if (boxSize < boxHeaderSize || boxSize - boxHeaderSize > sampleSize) {
                        // malformed box: the payload length would wrap or exceed the sample
                        break;
                    }

                    const size_t length = boxSize - boxHeaderSize;
                    if (length == 0) {
                        break;
                    }

                    std::vector<uint8_t> buffer(length);
                    stream.read(buffer.data(), length);
                    text.append(reinterpret_cast<const char*>(buffer.data()), length);
                } break;

                default:
                    break;
                }
                return true;
            });
            break;

        default:
            break;
        }
        return true;
    });
}

// Builds the AVC video format for `track` from the avcC box held in its codec data.
//
// Returns null and reports an invalid-data error to the listener when the codec data is
// missing, does not start with an avcC box, or yields no SPS/PPS. On success the NAL
// length size found in the extradata is also stored in m_nalLengthSize.
//
// The box header is only read once at least eight bytes are known to exist, and the
// declared box size is clamped to the bytes actually available, so a corrupt size field
// can no longer cause a read outside the codec data or an inverted copy range.
std::shared_ptr<MediaFormat> Mp4Reader::createAVCFormat(const Mp4Track& track)
{
    const auto& codecData = track.getCodecData();
    if (codecData.empty()) {
        m_listener.onMediaError(Error(ErrorSource::File, MediaResult::ErrorInvalidData, "Missing avc codec data"));
        return nullptr;
    }

    const size_t boxHeaderSize = 8; // 32 bit size + 32 bit type
    if (codecData.size() < boxHeaderSize) {
        m_listener.onMediaError(Error(ErrorSource::File, MediaResult::ErrorInvalidData, "No avcC data"));
        return nullptr;
    }

    // read the avcC box inside the codec data
    const uint8_t* buffer = codecData.data();
    size_t offset = 0;
    const uint32_t size = readUint32(buffer, offset);
    const uint32_t type = readUint32(buffer, offset);

    if (type != MP4_avcC) {
        m_listener.onMediaError(Error(ErrorSource::File, MediaResult::ErrorInvalidData, "No avcC data"));
        return nullptr;
    }

    // never trust the declared size beyond what is really there
    const size_t boxEnd = std::min<size_t>(size, codecData.size());
    if (boxEnd < boxHeaderSize) {
        m_listener.onMediaError(Error(ErrorSource::File, MediaResult::ErrorInvalidData, "Invalid avc codec data"));
        return nullptr;
    }

    std::vector<uint8_t> extradata(buffer + boxHeaderSize, buffer + boxEnd);
    auto ed = AVCParser::parseExtradata(extradata);
    if (ed.sps.empty() || ed.pps.empty()) {
        m_listener.onMediaError(Error(ErrorSource::File, MediaResult::ErrorInvalidData, "Invalid avc codec data"));
        return nullptr;
    }

    const auto& video = track.getVideoInfo();
    std::shared_ptr<SourceFormat> format = SourceFormat::createVideoFormat(
        MediaType::Video_AVC, video.width, video.height);

    m_nalLengthSize = ed.lengthSize;
    format->setInt(MediaFormat::Video_AVC_NAL_LengthSize, ed.lengthSize);
    format->setInt(MediaFormat::Video_AVC_Profile, ed.profile);
    format->setInt(MediaFormat::Video_AVC_Level, ed.level);
    // only the first parameter set of each kind is published
    format->setCodecData(MediaFormat::Video_AVC_SPS, ed.sps[0]);
    format->setCodecData(MediaFormat::Video_AVC_PPS, ed.pps[0]);
    format->setCodecData(MediaFormat::Video_AVC_AVCC, extradata);

    return format;
}

std::shared_ptr<MediaFormat> Mp4Reader::createVP9Format(const Mp4Track& track)
{
    const auto& video = track.getVideoInfo();
    std::shared_ptr<SourceFormat> format = SourceFormat::createVideoFormat(
        MediaType::Video_VP9, video.width, video.height);
    // so far current decoders don't need the VP9 configuration record, it's configured from the stream
    return format;
}

// Builds the AAC audio format for `track` from the esds box held in its codec data.
//
// The descriptor chain inside the box is walked as tag 0x03 -> tag 0x04 -> tag 0x05. The
// bytes of the 0x05 descriptor are stored on the format as Audio_AAC_ESDS, and for
// AAC LC / LTP / Main object types its channel configuration replaces the channel count
// from the sample entry. When the chain does not match, the format is still created with
// the sample entry values and empty ESDS data.
//
// Returns null and reports an invalid-data error when the codec data is shorter than a
// box header or does not start with an esds box.
//
// Every read is checked against the bytes that are actually present (the declared box
// size clamped to the codec data size), so truncated or corrupt descriptors end parsing
// with a warning instead of reading past the buffer.
std::shared_ptr<MediaFormat> Mp4Reader::createAACFormat(const Mp4Track& track)
{
    const auto& codecData = track.getCodecData();
    const size_t boxHeaderSize = 8; // 32 bit size + 32 bit type

    if (codecData.size() < boxHeaderSize) {
        m_listener.onMediaError(Error(ErrorSource::File, MediaResult::ErrorInvalidData, "Invalid aac codec data"));
        return nullptr;
    }

    const mp4audio& audio = track.getAudioInfo();

    // read the esds box inside the codec data
    const uint8_t* buffer = codecData.data();
    size_t offset = 0;
    const uint32_t size = readUint32(buffer, offset);
    const uint32_t type = readUint32(buffer, offset);

    if (type != MP4_esds) {
        m_listener.onMediaError(Error(ErrorSource::File, MediaResult::ErrorInvalidData, "Missing esds"));
        return nullptr;
    }

    const uint8_t DescriptorTagStart = 0x80;
    const size_t limit = std::min<size_t>(size, codecData.size());

    int channelcount = audio.channelcount;
    std::vector<uint8_t> esds;

    // true when `count` more bytes can be read at the current offset
    auto available = [&](size_t count) -> bool {
        return offset <= limit && limit - offset >= count;
    };

    // skips the optional three byte extended descriptor type tag, then reads the one byte length
    auto readDescriptorLength = [&](uint8_t& length) -> bool {
        if (!available(1)) {
            return false;
        }
        if (buffer[offset] == DescriptorTagStart) {
            if (!available(4)) {
                return false;
            }
            offset += 3;
        }
        length = buffer[offset++];
        return true;
    };

    // returns false only when the data ends before a descriptor that was started is complete
    auto parseDescriptors = [&]() -> bool {
        // version & flags (4 bytes) followed by the first descriptor tag
        if (!available(5)) {
            return false;
        }
        offset += 4;
        uint8_t tag = buffer[offset++];
        uint8_t length = 0;
        if (!readDescriptorLength(length)) {
            return false;
        }
        if (tag != 0x03 || offset >= limit) {
            return true;
        }

        // esid (2 bytes), stream priority (1 byte), next descriptor tag (1 byte)
        if (!available(4)) {
            return false;
        }
        offset += 3;
        tag = buffer[offset++];
        if (tag != 0x04 || offset >= limit) {
            return true;
        }

        if (!readDescriptorLength(length)) {
            return false;
        }
        // object id (1), stream type (1), buffer size (3), max bitrate (4), avg bitrate (4), next tag (1)
        if (!available(14)) {
            return false;
        }
        const uint8_t objectId = buffer[offset++];
        const uint8_t streamType = buffer[offset++];
        offset += 3; // buffer size
        const uint32_t maxBitrate = readUint32(buffer, offset);
        const uint32_t avgBitrate = readUint32(buffer, offset);
        tag = buffer[offset++];
        m_log->info("ES header id %d type %d, max br %d avg br %d",
            objectId, streamType, maxBitrate, avgBitrate);

        if (tag != 0x05) {
            return true;
        }

        uint8_t subLength = 0;
        if (!readDescriptorLength(subLength) || !available(subLength)) {
            return false;
        }
        esds.assign(buffer + offset, buffer + offset + subLength);

        // object type, frequency index and channel configuration span the first two bytes
        if (esds.size() < 2) {
            return false;
        }
        const uint8_t objectType = static_cast<uint8_t>((esds[0] & 0xF8) >> 3);
        const uint8_t frequencyIndex = static_cast<uint8_t>(((esds[0] & 0x07) << 1) | ((esds[1] & 0x80) >> 7));
        const uint8_t channelConfig = static_cast<uint8_t>((esds[1] & 0x78) >> 3);
        if (objectType == ESDSObjectType::AAC_LC
            || objectType == ESDSObjectType::AAC_LTP
            || objectType == ESDSObjectType::AAC_Main) {
            m_log->debug("objectType %d frequencyIndex %d channelConfig %d", objectType,
                frequencyIndex, channelConfig);
            channelcount = channelConfig;
        }
        return true;
    };

    if (!parseDescriptors()) {
        m_log->warn("Truncated esds descriptor data");
    }

    std::shared_ptr<SourceFormat> format = SourceFormat::createAudioFormat(
        MediaType::Audio_AAC,
        channelcount,
        audio.samplerate,
        audio.samplesize);
    format->setCodecData(MediaFormat::Audio_AAC_ESDS, esds);
    return format;
}

// Maps a track onto a fixed id chosen by its handler type, so demuxed files or files with
// varying track ids always present the same ids. Any handler other than video, text or
// meta (including sound) maps to the audio id.
MediaReader::TrackId Mp4Reader::getStableTrackId(const Mp4Track& track)
{
    const auto handler = track.getHandlerType();

    if (handler == MP4_vide) {
        return MediaReader::VideoTrackId;
    }
    if (handler == MP4_text) {
        return MediaReader::TextTrackId;
    }
    if (handler == MP4_meta) {
        return MediaReader::MetaTrackId;
    }

    // MP4_soun and every unrecognised handler
    return MediaReader::AudioTrackId;
}

void Mp4Reader::avcConvertToAnnexB(const MediaFormat& format, MediaSampleBuffer& sample)
{
    if (m_nalLengthSize != 4) {
        m_listener.onMediaError(Error(ErrorSource::File, MediaResult::ErrorNotSupported,
            "Unsupported nal length size " + std::to_string(m_nalLengthSize)));
        return;
    }

    uint8_t* data = sample.buffer.data();
    size_t offset = 0;
    size_t idrOffset = 0;
    size_t index = 0;
    int spsCount = 0;
    int ppsCount = 0;
    bool hasIDR = false;
    while (offset < sample.buffer.size()) {
        uint32_t nalSize = readUint32(data, offset);
        offset -= 4;
        data[offset++] = 0x00;
        data[offset++] = 0x00;
        data[offset++] = 0x00;
        data[offset++] = 0x01;
        int nalType = (data[offset] & 0x1f);
        if (nalType == AVCParser::NalTypeIDR && !hasIDR) {
            idrOffset = offset;
            hasIDR = true;
        } else if (nalType == AVCParser::NalTypeSPS) {
            spsCount++;
        } else if (nalType == AVCParser::NalTypePPS) {
            ppsCount++;
        }

        offset += nalSize;
        index++;
    }

    // add sps/pps before idr (required for adaptive configuration changes)
    if (hasIDR && spsCount == 0 && ppsCount == 0) {
        const auto& sps = format.getCodecData(MediaFormat::Video_AVC_SPS);
        const auto& pps = format.getCodecData(MediaFormat::Video_AVC_PPS);
        std::vector<uint8_t> parameterSets;
        sample.buffer.reserve(sample.buffer.size() + m_nalLengthSize * 2 + sps.size() + pps.size());
        std::vector<uint8_t> prefix(m_nalLengthSize, 0);
        prefix[prefix.size() - 1] = 0x1;
        parameterSets.insert(parameterSets.end(), prefix.begin(), prefix.end());
        parameterSets.insert(parameterSets.end(), sps.begin(), sps.end());
        parameterSets.insert(parameterSets.end(), prefix.begin(), prefix.end());
        parameterSets.insert(parameterSets.end(), pps.begin(), pps.end());
        sample.buffer.insert(sample.buffer.begin() + idrOffset - m_nalLengthSize, parameterSets.begin(), parameterSets.end());

        if (sample.type == MediaSample::Type::EncryptedMemoryBuffer) {
            // this assumes the location the SPS/PPS and adjusts the bytes of clear data
            SecureSampleBuffer& secureSample = static_cast<SecureSampleBuffer&>(sample);
            if (!secureSample.subsampleRange.empty()) {
                secureSample.subsampleRange[0].first += static_cast<uint16_t>(parameterSets.size());
            }
        }
    }

    sample.isSyncSample = hasIDR;
}
}
}
