#include "fmp4track.hpp"
#include "../avc/avcutil.hpp"
#include "fmp4boxes.hpp"
#include "util/Base64.hpp"
#include <algorithm>
#include <cassert>

namespace twitch {
namespace media {
// Builds an AAC ("mp4a" / "soun") track and fills in the audio sample entry
// fields from the codec extradata.
//
// When the extradata cannot be parsed, channel count and sample rate are left
// at 0; the sample size is 16 in both cases.
fmp4track_aac::fmp4track_aac(uint32_t trackId, uint32_t timescale, uint32_t defaultSampleSize,
    uint32_t defaultSampleFlags,
    uint32_t defaultSampleDuration,
    int32_t elstMediaTime,
    const std::vector<uint8_t>& extradata,
    const std::vector<EncryptionInfo>& encryptionInfo)
    : Mp4Track(trackId, MP4_mp4a, timescale, MP4_soun,
        defaultSampleSize, defaultSampleFlags, defaultSampleDuration,
        elstMediaTime, "SoundHandler", extradata, encryptionInfo)
{
    // Defaults used when the extradata is missing or malformed.
    m_audio.samplesize = 16;
    m_audio.channelcount = 0;
    m_audio.samplerate = 0;

    adts_t parsed {};
    const bool extradataOk = aac_parse_extradata(&parsed, extradata.data(), extradata.size());

    if (extradataOk) {
        m_audio.channelcount = static_cast<int16_t>(parsed.channelConfig);
        m_audio.samplerate = aac_frequency(&parsed);
    }

#ifdef DRM_TESTING
    // Test builds only: when the caller supplied no keys, install the clear-key
    // test configuration and print the key so it can be fed to a player.
    if (encryptionInfo.empty()) {
        auto testKeys = fmp4drmFactory::clearKeyTestInfo();
        fprintf(stderr, "DRM_TESTING key: %s\n", twitch::Base64::encode(testKeys[0].key.data(), testKeys[0].key.size()).c_str());
        setEncryptionInfo(testKeys);
    }
#endif
}

// Adds one AAC sample to the track.
//
// Unprotected tracks forward the sample unchanged to Mp4Track::addSample.
// Protected tracks encrypt the payload one AES block at a time with the
// currently selected EncryptionInfo; for the "cenc" scheme the trailing
// partial block is encrypted as well, otherwise it is copied in the clear.
//
// Returns false when no DRM encoder could be configured, otherwise the result
// of Mp4Track::addSample.
bool fmp4track_aac::addSample(int64_t dts, int32_t cts, uint32_t duration, uint32_t flags, const uint8_t* data, uint32_t size)
{
    if (!isProtected()) {
        return Mp4Track::addSample(dts, cts, duration, flags, data, size);
    }

    auto& keyInfo = m_encryptionInfo[m_currentEncryptionIndex];

    auto drm = Mp4DrmFactory::configure(keyInfo.scheme, keyInfo.key, keyInfo.iv);
    if (!drm) {
        return false;
    }

    // Describe the sample; the IV is captured before it is advanced below.
    mp4sample sample = mp4sample();
    sample.size = size;
    sample.flags = flags;
    sample.duration = duration;
    sample.compositionTimeOffset = cts;
    sample.decodeTime = dts;
    sample.sampleGroup = static_cast<uint32_t>(m_currentEncryptionIndex + 1);
    sample.initializationVector = keyInfo.iv;

    std::vector<uint8_t> sampleData;
    sampleData.reserve(size);

    const uint8_t* cursor = data;
    uint32_t remaining = size;

    // Whole AES blocks are always encrypted.
    for (; AesBlockSize <= remaining; cursor += AesBlockSize, remaining -= AesBlockSize) {
        auto encrypted = drm->encode(cursor, AesBlockSize);
        sampleData.insert(sampleData.end(), encrypted.begin(), encrypted.end());
    }

    // CTR can encrypt partial blocks
    if (0 < remaining && fourcc("cenc") == keyInfo.scheme) {
        auto encrypted = drm->encode(cursor, remaining);
        sampleData.insert(sampleData.end(), encrypted.begin(), encrypted.end());
        cursor += remaining;
        remaining = 0;
    }

    keyInfo.incrementIv();

    // Anything left is in the clear
    sampleData.insert(sampleData.end(), cursor, cursor + remaining);

    assert(sample.size == sampleData.size());
    sample.size = static_cast<uint32_t>(sampleData.size());
    return Mp4Track::addSample(sample, sampleData.data());
}

////////////////////////////////////////////////////////////////////////////////
// Builds an AVC ("avc1" / "vide") track.
//
// The extradata is parsed for SPS/PPS units. When at least one SPS is present
// the first one supplies the video dimensions and the extradata supplies the
// NAL length-prefix size; otherwise the dimensions are 0 and the length size
// stays at its default of 4. Every SPS and PPS is then fed to the track's own
// AVC parser.
fmp4track_avc::fmp4track_avc(uint32_t trackId, uint32_t timescale,
    uint32_t defaultSampleSize, uint32_t defaultSampleFlags, uint32_t defaultSampleDuration,
    int32_t elstMediaTime,
    const std::vector<uint8_t>& extradata,
    const std::vector<EncryptionInfo>& encryptionInfo)
    : Mp4Track(trackId, MP4_avc1, timescale, MP4_vide,
        defaultSampleSize, defaultSampleFlags, defaultSampleDuration,
        elstMediaTime, "VideoHandler", extradata, encryptionInfo)
    , m_lengthSize(4)
{
    clearSample();

    auto config = AVCParser::parseExtradata(extradata);

    // Dimensions when no SPS is available.
    m_video.width = 0;
    m_video.height = 0;

    if (!config.sps.empty()) {
        auto firstSps = AVCParser::parseSps(config.sps[0]);
        auto dimensions = firstSps.resolution();
        m_lengthSize = config.lengthSize;
        m_video.width = static_cast<uint16_t>(dimensions.first);
        m_video.height = static_cast<uint16_t>(dimensions.second);
    }

    // Feed the parameter sets to the parser used by addNalu (SPS first, then PPS).
    for (const auto& spsUnit : config.sps) {
        m_avcParser.parseNalu(spsUnit);
    }
    for (const auto& ppsUnit : config.pps) {
        m_avcParser.parseNalu(ppsUnit);
    }

#ifdef DRM_TESTING
    // Test builds only: when the caller supplied no keys, install the clear-key
    // test configuration and print the key so it can be fed to a player.
    if (encryptionInfo.empty()) {
        auto testKeys = fmp4drmFactory::clearKeyTestInfo();
        fprintf(stderr, "DRM_TESTING key: %s\n", twitch::Base64::encode(testKeys[0].key.data(), testKeys[0].key.size()).c_str());
        setEncryptionInfo(testKeys);
    }
#endif
}

// Resets all per-sample assembly state: the buffered bytes, the sample
// descriptor and the running count of not-yet-reported clear bytes.
void fmp4track_avc::clearSample()
{
    m_sampleData.clear();
    m_sample = mp4sample();
    m_bytesOfClearData = 0;
}

// Begins assembling a new protected sample.
//
// Discards any partially built sample, records the timing/flags and the
// current IV and sample group on the new one, and configures the DRM encoder
// (m_drm) from the currently selected EncryptionInfo. m_drm is whatever
// Mp4DrmFactory::configure returns; addNalu checks it for null.
void fmp4track_avc::startSample(int64_t dts, int32_t cts, uint32_t duration, uint32_t flags)
{
    clearSample();

    auto& activeKey = m_encryptionInfo[m_currentEncryptionIndex];

    m_sample = mp4sample();
    m_sample.size = 0; // grows as NAL units are added; set for real in finishSample
    m_sample.flags = flags;
    m_sample.duration = duration;
    m_sample.compositionTimeOffset = cts;
    m_sample.decodeTime = dts;
    m_sample.sampleGroup = static_cast<uint32_t>(m_currentEncryptionIndex + 1);
    m_sample.initializationVector = activeKey.iv;

    m_drm = Mp4DrmFactory::configure(activeKey.scheme, activeKey.key, activeKey.iv);
}

// Flushes the pending clear-byte count, together with the given number of
// protected bytes, into the sample's subsample list.
//
// The clear count of one entry is stored as a uint16_t, so a larger run of
// clear bytes is split into (65535, 0) entries followed by one entry holding
// the remainder and the protected byte count. Nothing is emitted when both
// counts are zero. The pending clear-byte count is zero afterwards.
void fmp4track_avc::updateSubsampleInfo(size_t bytesOfProtectedData)
{
    if (!m_bytesOfClearData && !bytesOfProtectedData) {
        return;
    }

    // protection against large paddings
    const auto clearLimit = std::numeric_limits<uint16_t>::max();
    for (; static_cast<size_t>(clearLimit) < m_bytesOfClearData; m_bytesOfClearData -= clearLimit) {
        m_sample.subsampleRange.emplace_back(static_cast<uint16_t>(clearLimit), 0);
    }

    m_sample.subsampleRange.emplace_back(static_cast<uint16_t>(m_bytesOfClearData), static_cast<uint32_t>(bytesOfProtectedData));
    m_bytesOfClearData = 0;
}

// Appends one NAL unit to the sample being assembled.
//
// The unit is written with a 4-byte big-endian length prefix. For IDR and
// non-IDR slice units the slice header (as reported by the AVC parser) is kept
// in the clear and the remaining payload is run through the DRM encoder in
// AES-block-sized pieces; a subsample entry is recorded for each such unit.
// Every other NAL type is copied through unencrypted and only adds to the
// pending clear-byte count.
//
// Does nothing when no DRM encoder is configured (see startSample) or when
// data is null while size is non-zero.
//
// data: NAL unit bytes, without start code or length prefix.
// size: number of bytes at data.
void fmp4track_avc::addNalu(const uint8_t* data, size_t size)
{
    if (!m_drm || (!data && 0 != size)) {
        return;
    }

    // An empty NAL unit has no header byte to inspect. Treat it as type 0 (never
    // a slice) instead of reading one byte past the end of the buffer.
    const int8_t nalu_type = (0 < size) ? static_cast<int8_t>(data[0] & 0x1F) : 0;
    auto& encInfo = m_encryptionInfo[m_currentEncryptionIndex];

    // The prefix below is always written as 4 bytes.
    assert(4 == m_lengthSize);
    m_sampleData.reserve(m_sampleData.size() + m_lengthSize + size);
    m_sampleData.push_back(static_cast<uint8_t>(size >> 24));
    m_sampleData.push_back(static_cast<uint8_t>(size >> 16));
    m_sampleData.push_back(static_cast<uint8_t>(size >> 8));
    m_sampleData.push_back(static_cast<uint8_t>(size));
    m_bytesOfClearData += m_lengthSize;

    if (AVCParser::NalTypeIDR == nalu_type || AVCParser::NalTypeSlice == nalu_type) {
        m_drm->startSubSample();
        size_t headerSize = m_avcParser.parseNalu(data, size);
        assert(16 > headerSize); // 16 byte header is likely bad data

        // The assert above disappears in release builds. Never let a bogus header
        // length exceed the NAL unit, or the size arithmetic below would wrap
        // around and the copies would run past the end of the input.
        headerSize = std::min(headerSize, size);

        if (fourcc("cenc") == encInfo.scheme) {
            // For CENC round to AesBlockSize
            // TODO CTR can encode partial blocks, but would need more complicated loop logic below
            const size_t blocksOfProtectedData = (size - headerSize) / AesBlockSize;
            headerSize = size - (blocksOfProtectedData * AesBlockSize);
        }

        // Clear part: slice header (plus, for cenc, the bytes that do not fill a
        // whole AES block).
        m_sampleData.insert(m_sampleData.end(), data, data + headerSize);
        m_bytesOfClearData += headerSize;
        data += headerSize;
        size -= headerSize;

        size_t bytesOfProtectedData = 0;
        bool patternEncryption = !!encInfo.patternEncryption(); // patternEncryption is hardcoded to 0x00 or 0x19
        uint8_t cryptByteBlock = encInfo.cryptByteBlock();
        uint8_t skipByteBlockMod = cryptByteBlock + encInfo.skipByteBlock();

        // A zero-length pattern would make the modulus below divide by zero; in
        // that case every whole block is encrypted.
        const bool usePattern = patternEncryption && 0 != skipByteBlockMod;

        // TODO make separate functions for cenc/cbcs?
        for (size_t block = 0; 0 < size; ++block) {
            const bool skipBlock = usePattern && (block % skipByteBlockMod) >= cryptByteBlock;

            if (AesBlockSize > size || skipBlock) {
                // skip_byte_block // (for cens)
                // Copied unencrypted, but still counted as part of the protected range.
                const size_t blockSize = std::min(AesBlockSize, size);
                m_sampleData.insert(m_sampleData.end(), data, data + blockSize);
                bytesOfProtectedData += blockSize;
                data += blockSize;
                size -= blockSize;
            } else {
                // crypt_byte_block
                auto protectedData = m_drm->encode(data, AesBlockSize);
                m_sampleData.insert(m_sampleData.end(), protectedData.begin(), protectedData.end());
                bytesOfProtectedData += AesBlockSize;
                data += AesBlockSize;
                size -= AesBlockSize;
            }
        }

        m_drm->endSubSample();
        updateSubsampleInfo(bytesOfProtectedData);
    }

    // Non-slice NAL units are copied whole; for slices nothing is left here.
    m_sampleData.insert(m_sampleData.end(), data, data + size);
    m_bytesOfClearData += size;
}

// Completes the sample started with startSample and hands it to the base
// track.
//
// Releases the DRM encoder, advances the IV of the current EncryptionInfo,
// flushes any clear bytes not yet covered by a subsample entry and records the
// final sample size. Returns the result of Mp4Track::addSample.
bool fmp4track_avc::finishSample()
{
    m_drm.reset();

    m_encryptionInfo[m_currentEncryptionIndex].incrementIv();

    // Trailing clear bytes (e.g. non-slice NAL units after the last slice).
    updateSubsampleInfo(0);

    const auto totalBytes = m_sampleData.size();
    m_sample.size = static_cast<uint32_t>(totalBytes);

    return Mp4Track::addSample(m_sample, m_sampleData.data());
}

bool fmp4track_avc::addSample(int64_t dts, int32_t cts, uint32_t duration, uint32_t flags, const uint8_t* data, uint32_t size)
{
    if (!data || !size) {
        return false;
    }

    if (!isProtected()) {
        return Mp4Track::addSample(dts, cts, duration, flags, data, size);
    }

    startSample(dts, cts, duration, flags);
    for (const auto& nal : NalIterator(data, size)) {
        addNalu(nal.data, nal.size);
    }

    return finishSample();
}

// Builds an unencrypted VP9 ("vp09" / "vide") track.
//
// Serializes the configuration record into m_codecData as eight bytes:
// profile, level, one packed format byte, colour primaries, transfer
// characteristics, matrix coefficients and a two-byte initialization data
// size of 0. The record's width/height become the video dimensions and a copy
// of the record is kept on the track.
fmp4track_vp9::fmp4track_vp9(uint32_t trackId, uint32_t timescale,
    uint32_t defaultSampleSize, uint32_t defaultSampleFlags, uint32_t defaultSampleDuration,
    int32_t elstMediaTime,
    const VP9ConfigurationRecord& configurationRecord)
    : Mp4Track(trackId, MP4_vp09, timescale, MP4_vide,
        defaultSampleSize, defaultSampleFlags, defaultSampleDuration,
        elstMediaTime, "VideoHandler", std::vector<uint8_t>(),
        std::vector<EncryptionInfo>())
{
    // 4 bits bitDepth, 3 bits chromaSubsampling, 1 bit videoFullRangeFlag
    const uint8_t packedFormat = (configurationRecord.bitDepth << 4)
        | (configurationRecord.chromaSubsampling << 1)
        | (configurationRecord.videoFullRangeFlag);

    const uint8_t serialized[] = {
        static_cast<uint8_t>(configurationRecord.profile),
        static_cast<uint8_t>(configurationRecord.level),
        packedFormat,
        static_cast<uint8_t>(configurationRecord.colourPrimaries),
        static_cast<uint8_t>(configurationRecord.transferCharacteristics),
        static_cast<uint8_t>(configurationRecord.matrixCoefficients),
        // codecInitializationDataSize 0 for VP9 2 bytes = 0
        static_cast<uint8_t>(0),
        static_cast<uint8_t>(0),
    };

    m_codecData.reserve(8);
    for (const uint8_t byte : serialized) {
        m_codecData.push_back(byte);
    }

    m_configurationRecord = configurationRecord;
    m_video.width = static_cast<uint16_t>(configurationRecord.width);
    m_video.height = static_cast<uint16_t>(configurationRecord.height);
}

fmp4track_webvtt::fmp4track_webvtt(uint32_t trackId, uint32_t timescale,
    uint32_t defaultSampleSize, uint32_t defaultSampleFlags, uint32_t defaultSampleDuration,
    int32_t elstMediaTime)
    : Mp4Track(trackId, MP4_wvtt, timescale, MP4_text,
        defaultSampleSize, defaultSampleFlags, defaultSampleDuration,
        elstMediaTime, "TextHandler", std::vector<uint8_t>(),
        std::vector<EncryptionInfo>())
{
}

// Builds an unencrypted Opus ("opus" / "soun") track.
//
// The audio sample entry reports 16-bit samples and, when the configuration
// has at least one output channel, that channel count at 48000 Hz; otherwise
// channel count and sample rate are 0.
//
// The configuration is serialized big-endian into m_codecData as eleven
// bytes: version, output channel count, pre-skip (2), input sample rate (4),
// output gain (2) and channel mapping family. A copy of the configuration is
// kept on the track.
fmp4track_opus::fmp4track_opus(uint32_t trackId, uint32_t timescale,
    uint32_t defaultSampleSize, uint32_t defaultSampleFlags, uint32_t defaultSampleDuration,
    int32_t elstMediaTime,
    const OpusConfiguration& opusConfig)
    : Mp4Track(trackId, MP4_opus, timescale, MP4_soun,
        defaultSampleSize, defaultSampleFlags, defaultSampleDuration,
        elstMediaTime, "SoundHandler", std::vector<uint8_t>(),
        std::vector<EncryptionInfo>())
{
    // Values used when the configuration reports no output channels.
    m_audio.samplesize = 16;
    m_audio.channelcount = 0;
    m_audio.samplerate = 0;

    if (opusConfig.outputChannelCount > 0) {
        m_audio.channelcount = opusConfig.outputChannelCount;
        m_audio.samplerate = 48000; // opus output is always decodable at 48khz
    }

    const uint8_t serialized[] = {
        static_cast<uint8_t>(opusConfig.version),
        static_cast<uint8_t>(opusConfig.outputChannelCount),
        // pre-skip, most significant byte first
        static_cast<uint8_t>((opusConfig.preSkip >> 8) & 0xFF),
        static_cast<uint8_t>(opusConfig.preSkip & 0xFF),
        // input sample rate, most significant byte first
        static_cast<uint8_t>((opusConfig.inputSampleRate >> 24) & 0xFF),
        static_cast<uint8_t>((opusConfig.inputSampleRate >> 16) & 0xFF),
        static_cast<uint8_t>((opusConfig.inputSampleRate >> 8) & 0xFF),
        static_cast<uint8_t>((opusConfig.inputSampleRate >> 0) & 0xFF),
        // output gain, most significant byte first
        static_cast<uint8_t>((opusConfig.outputGain >> 8) & 0xFF),
        static_cast<uint8_t>(opusConfig.outputGain & 0xFF),
        static_cast<uint8_t>(opusConfig.channelMappingFamily),
    };

    m_codecData.reserve(11);
    for (const uint8_t byte : serialized) {
        m_codecData.push_back(byte);
    }

    m_opusConfig = opusConfig;
}
}
}
