//
//  vt_decoder.cpp
//  player-core
//
//  Created by Purushe, Nikhil on 11/1/16.
//
//

#include "VTDecoder.hpp"

#include <utility>

namespace twitch {
/**
 * Creates an unconfigured decoder. Until configure() has created a
 * decompression session, decode() drops incoming frames.
 *
 * @param log diagnostics sink; taken by value and moved into the member so only
 *            one reference-count increment is paid.
 */
VTDecoder::VTDecoder(std::shared_ptr<Log> log)
    : m_log(std::move(log))
    , m_formatDescription(nullptr)
    , m_outputDescription(nullptr)
    , m_decompressionSession(nullptr)
    , m_decodeStatus(noErr)
    , m_droppedFrames(0)
{
}

/**
 * Tears the decoder down: lets in-flight frames finish, invalidates and releases
 * the decompression session, then releases both cached format descriptions.
 */
VTDecoder::~VTDecoder()
{
    // Same rule as configure(): wait for asynchronous frames to complete before the
    // session is released, so no decompressionOutputCallback is still running against
    // this object while it is being destroyed. A destructor cannot report the status,
    // and teardown has to proceed regardless, so it is deliberately ignored.
    if (m_decompressionSession) {
        (void)VTDecompressionSessionWaitForAsynchronousFrames(m_decompressionSession);
    }

    releaseSession();
    releaseFormatDescription(m_formatDescription);
    releaseFormatDescription(m_outputDescription);
}

/**
 * Configures (or reconfigures) the decoder for an H.264 input format.
 *
 * A CMVideoFormatDescription is built from the SPS/PPS codec data when both are
 * present, otherwise from the bare width/height. The current decompression session
 * is reused when it accepts the new description; otherwise outstanding frames are
 * drained, the session is released and a new one is created.
 *
 * @param input  source format; reads Video_Width, Video_Height,
 *               Video_AVC_NAL_LengthSize and the optional Video_AVC_SPS /
 *               Video_AVC_PPS codec data
 * @param output receives the decoded type ("video" / "CVImageBuffer") and dimensions
 * @return MediaResult::Ok on success; MediaResult::ErrorNotSupported when a freshly
 *         created session rejects the format (the session is released); otherwise
 *         an error wrapping the failing OSStatus
 */
MediaResult VTDecoder::configure(const MediaFormat& input, MediaFormat& output)
{
    int width = input.getInt(MediaFormat::Video_Width);
    int height = input.getInt(MediaFormat::Video_Height);
    int nalLengthSize = input.getInt(MediaFormat::Video_AVC_NAL_LengthSize);

    releaseFormatDescription(m_formatDescription);
    OSStatus status = noErr;

    // configure with SPS/PPS
    if (input.hasCodecData(MediaFormat::Video_AVC_PPS)
        && input.hasCodecData(MediaFormat::Video_AVC_SPS)) {

        const auto& sps = input.getCodecData(MediaFormat::Video_AVC_SPS);
        const auto& pps = input.getCodecData(MediaFormat::Video_AVC_PPS);

        // SPS before PPS is the conventional (avcC) ordering. The API takes
        // `const uint8_t* const*`, so the parameter-set bytes never need a const_cast.
        constexpr size_t ParameterSets = 2;
        const uint8_t* pointers[ParameterSets] = { sps.data(), pps.data() };
        size_t sizes[ParameterSets] = { sps.size(), pps.size() };

        status = CMVideoFormatDescriptionCreateFromH264ParameterSets(
            kCFAllocatorDefault, ParameterSets, pointers, sizes, nalLengthSize, &m_formatDescription);

        if (status != noErr) {
            return check(status, "CMVideoFormatDescriptionCreateFromH264ParameterSets");
        }
    } else {
        status = CMVideoFormatDescriptionCreate(
            kCFAllocatorDefault, kCMVideoCodecType_H264, width, height, nullptr, &m_formatDescription);

        if (status != noErr) {
            return check(status, "CMVideoFormatDescriptionCreate");
        }
    }

    output.setType(MediaType("video", "CVImageBuffer"));
    output.setInt(MediaFormat::Video_Width, width);
    output.setInt(MediaFormat::Video_Height, height);

    // can we reuse the current session?
    if (m_decompressionSession) {
        if (VTDecompressionSessionCanAcceptFormatDescription(m_decompressionSession, m_formatDescription)) {
            m_log->info("VTDecompressionSessionCanAcceptFormatDescription accepted reusing session");
            return MediaResult::Ok;
        }

        // must wait for async frames to complete before releasing the session
        status = VTDecompressionSessionWaitForAsynchronousFrames(m_decompressionSession);
        if (status == kVTInvalidSessionErr) {
            m_log->info("VTDecompressionSessionWaitForAsynchronousFrames kVTInvalidSessionErr");
        } else if (status != noErr) {
            return check(status, "VTDecompressionSessionWaitForAsynchronousFrames");
        }
        releaseSession();
    }

    // otherwise (re)create the session; the callback's refcon is this decoder
    VTDecompressionOutputCallbackRecord callbackRecord {};
    callbackRecord.decompressionOutputCallback = &decompressionOutputCallback;
    callbackRecord.decompressionOutputRefCon = this;

    status = VTDecompressionSessionCreate(
        kCFAllocatorDefault, m_formatDescription, nullptr, nullptr, &callbackRecord, &m_decompressionSession);

    if (status != noErr) {
        return check(status, "VTDecompressionSessionCreate");
    }

    // decode() refuses to run while m_decodeStatus holds an error, and only output
    // callbacks (which need decode() to run) would ever overwrite it. A brand-new
    // session therefore starts from a clean status, or an old error would block it forever.
    m_decodeStatus = noErr;

    if (!VTDecompressionSessionCanAcceptFormatDescription(m_decompressionSession, m_formatDescription)) {
        m_log->info("VTDecompressionSessionCanAcceptFormatDescription not accepted after creation");
        // don't leave decode() feeding a session that cannot handle the format
        releaseSession();
        return MediaResult::ErrorNotSupported;
    }

    return MediaResult::Ok;
}

/**
 * Submits one compressed sample to the decompression session.
 *
 * Frames arriving before configure() has created a session are counted as dropped
 * and reported as success. An error previously recorded by the output callback is
 * returned without submitting anything.
 *
 * @param input compressed sample bytes plus timing; isDecodeOnly frames are decoded
 *              with kVTDecodeFrame_DoNotOutputFrame
 * @return MediaResult::Ok, or an error wrapping the failing OSStatus. On
 *         kVTInvalidSessionErr the session is released so configure() can rebuild it.
 */
MediaResult VTDecoder::decode(const MediaSampleBuffer& input)
{
    if (!m_decompressionSession) {
        m_log->warn("Invalid VTDecompressionSession dropping frame, waiting for configure");
        m_droppedFrames++;
        return MediaResult::Ok;
    }

    // surface an error recorded asynchronously by decompressionOutput()
    if (m_decodeStatus != noErr) {
        return check(m_decodeStatus, "Decode error in decompressionOutputCallback");
    }

    const uint8_t* buffer = input.buffer.data();
    size_t bufferSize = input.buffer.size();

    // Wrap the sample bytes without copying (kCFAllocatorNull: CoreMedia never frees them).
    // NOTE(review): with kVTDecodeFrame_EnableAsynchronousDecompression the decoder may
    // still read these bytes after this call returns; confirm the caller keeps
    // input.buffer alive long enough, or copy the data into the block buffer.
    CMBlockBufferRef blockBuffer = nullptr;
    OSStatus status = CMBlockBufferCreateWithMemoryBlock(
        kCFAllocatorDefault, const_cast<uint8_t*>(buffer), bufferSize, kCFAllocatorNull,
        nullptr, 0, bufferSize, 0, &blockBuffer);

    MediaResult result = check(status, "CMBlockBufferCreateWithMemoryBlock");
    if (result != MediaResult::Ok) {
        return result;
    }

    CMSampleTimingInfo timingInfo {};
    MediaTime duration = input.duration;
    MediaTime decodeTime = input.decodeTime;
    MediaTime presentationTime = input.presentationTime;
    timingInfo.duration = CMTimeMake(duration.count(), duration.timebase());
    timingInfo.decodeTimeStamp = CMTimeMake(decodeTime.count(), decodeTime.timebase());
    timingInfo.presentationTimeStamp = CMTimeMake(presentationTime.count(), presentationTime.timebase());

    CMSampleBufferRef sampleBuffer = nullptr;
    status = CMSampleBufferCreate(
        kCFAllocatorDefault, blockBuffer, true, nullptr, nullptr, m_formatDescription,
        1, 1, &timingInfo, 1, &bufferSize, &sampleBuffer);

    // The sample buffer retains its data buffer, so our reference is no longer needed.
    // Releasing it before the error check also stops it leaking when creation fails.
    if (blockBuffer) {
        CFRelease(blockBuffer);
        blockBuffer = nullptr;
    }

    if ((result = check(status, "CMSampleBufferCreate")) != MediaResult::Ok) {
        if (sampleBuffer) {
            CFRelease(sampleBuffer);
        }
        return result;
    }

    VTDecodeFrameFlags decodeFlags = kVTDecodeFrame_EnableAsynchronousDecompression | kVTDecodeFrame_EnableTemporalProcessing;

    // add don't output frame flag for decode only frames
    if (input.isDecodeOnly) {
        decodeFlags |= kVTDecodeFrame_DoNotOutputFrame;
    }

    VTDecodeInfoFlags outputFlags = 0;
    status = VTDecompressionSessionDecodeFrame(m_decompressionSession, sampleBuffer, decodeFlags, nullptr, &outputFlags);

    if (sampleBuffer) {
        CMSampleBufferInvalidate(sampleBuffer);
        CFRelease(sampleBuffer);
    }

    if (status == kVTInvalidSessionErr) {
        m_log->warn("kVTInvalidSessionErr releasing session");
        releaseSession();
    }

    return check(status, "VTDecompressionSessionDecodeFrame");
}

/**
 * Reports whether at least one decoded frame is waiting to be collected.
 *
 * @param hasOutput set to true when the output queue is non-empty
 * @return always MediaResult::Ok
 */
MediaResult VTDecoder::hasOutput(bool& hasOutput)
{
    // the queue is filled from decompressionOutput(), so inspect it under the output mutex
    std::lock_guard<std::mutex> guard(m_outputMutex);
    hasOutput = !m_outputBuffers.empty();
    return MediaResult::Ok;
}

/**
 * Pops the oldest decoded frame from the output queue.
 *
 * @param output receives the frame; left untouched when the queue is empty
 * @return MediaResult::Ok when a frame was returned, MediaResult::Error when none is queued
 */
MediaResult VTDecoder::getOutput(std::shared_ptr<MediaSample>& output)
{
    std::lock_guard<std::mutex> guard(m_outputMutex);
    if (m_outputBuffers.empty()) {
        return MediaResult::Error;
    }

    // The element is popped right away, so move it out instead of copying. This
    // avoids a redundant atomic reference-count increment/decrement pair.
    output = std::move(m_outputBuffers.front());
    m_outputBuffers.pop();
    return MediaResult::Ok;
}

/**
 * Drains the decoder: every delayed or in-flight frame is emitted through the
 * output callback before this returns. The dropped-frame counter is also reset.
 *
 * @return MediaResult::ErrorInvalidState when there is no session, otherwise
 *         MediaResult::Ok or an error wrapping the failing OSStatus
 */
MediaResult VTDecoder::flush()
{
    m_droppedFrames = 0;

    if (!m_decompressionSession) {
        return MediaResult::ErrorInvalidState;
    }

    // VTDecompressionSessionFinishDelayedFrames may return before the delayed frames
    // have been emitted. WaitForAsynchronousFrames also finishes delayed frames, and it
    // blocks until all outstanding frames are done. Every frame is therefore in the
    // output queue by the time flush() returns.
    OSStatus status = VTDecompressionSessionWaitForAsynchronousFrames(m_decompressionSession);
    return check(status, "VTDecompressionSessionWaitForAsynchronousFrames");
}

/**
 * Returns the decoder to a clean state: releases the session, discards every
 * queued output frame and clears any error recorded by the output callback.
 *
 * @return MediaResult::Ok, or an error wrapping an unexpected OSStatus from
 *         draining the session (kVTInvalidSessionErr is tolerated). On that error
 *         the session, queue and status are left untouched.
 */
MediaResult VTDecoder::reset()
{
    m_droppedFrames = 0;

    if (m_decompressionSession) {
        // FinishDelayedFrames may return before delayed frames are emitted, which would
        // let pre-reset frames land in the queue after it is cleared. Waiting ensures
        // every outstanding frame has been delivered first.
        OSStatus status = VTDecompressionSessionWaitForAsynchronousFrames(m_decompressionSession);
        if (status != noErr && status != kVTInvalidSessionErr) {
            return check(status, "VTDecompressionSessionWaitForAsynchronousFrames");
        }

        // need to release the session here otherwise the decoder may error even if the session is reusable
        releaseSession();
    }

    // Discard pre-reset output even when no session existed, e.g. when decode()
    // already released it after kVTInvalidSessionErr.
    {
        std::unique_lock<std::mutex> lock(m_outputMutex);
        m_outputBuffers = std::queue<std::shared_ptr<CMSampleBufferSample>>();
    }

    // decode() fails while this holds an error; without clearing it here, a decoder
    // that hit one callback error could never recover through reset() + configure().
    m_decodeStatus = noErr;

    return MediaResult::Ok;
}

/**
 * Converts an OSStatus into a MediaResult, logging failures.
 *
 * @param status  result code from a CoreMedia / VideoToolbox call
 * @param message label (typically the API name) prefixed to the log line
 * @return MediaResult::Ok for noErr; otherwise MediaResult::Error carrying the raw status
 */
MediaResult VTDecoder::check(OSStatus status, const char* message)
{
    if (status == noErr) {
        return MediaResult::Ok;
    }

    m_log->error("%s error %d", message, status);
    return MediaResult(MediaResult::Error, status);
}

/**
 * Handles one decoder output (invoked through decompressionOutputCallback).
 *
 * Records the status for decode() to report and counts dropped frames. A
 * successfully decoded image is wrapped in a CMSampleBuffer, stamped with its
 * presentation time, and pushed onto the output queue. A format-description or
 * sample-buffer creation failure is stored in m_decodeStatus instead.
 *
 * NOTE(review): with asynchronous decompression this presumably runs off the
 * thread calling decode(); m_decodeStatus and m_droppedFrames are touched from
 * both sides without m_outputMutex, so confirm their declared types tolerate that.
 */
void VTDecoder::decompressionOutput(OSStatus status, VTDecodeInfoFlags infoFlags,
    CVImageBufferRef imageBuffer,
    CMTime presentationTimeStamp,
    CMTime presentationDuration)
{
    m_decodeStatus = status;
    check(status, "decompressionOutputCallback");

    const bool frameDropped = (infoFlags & kVTDecodeInfo_FrameDropped) != 0;
    if (frameDropped) {
        m_droppedFrames++;
    }

    // only a successful decode that actually produced an image is queued
    const bool haveImage = (status == noErr) && (imageBuffer != nullptr);
    if (!haveImage) {
        return;
    }

    // the decoder's output format changed: forget the cached description
    if (m_outputDescription != nullptr
        && !CMVideoFormatDescriptionMatchesImageBuffer(m_outputDescription, imageBuffer)) {
        releaseFormatDescription(m_outputDescription);
    }

    if (m_outputDescription == nullptr) {
        const OSStatus describeStatus = CMVideoFormatDescriptionCreateForImageBuffer(
            kCFAllocatorDefault, imageBuffer, &m_outputDescription);
        if (describeStatus != noErr) {
            check(describeStatus, "CMVideoFormatDescriptionCreateForImageBuffer failed");
            m_decodeStatus = describeStatus;
            return;
        }
    }

    // Timing carries the PTS only; decode timestamp and duration are left invalid.
    CMSampleTimingInfo timingInfo;
    timingInfo.duration = kCMTimeInvalid;
    timingInfo.presentationTimeStamp = presentationTimeStamp;
    timingInfo.decodeTimeStamp = kCMTimeInvalid;

    CMSampleBufferRef sampleBuffer = nullptr;
    const OSStatus wrapStatus = CMSampleBufferCreateReadyWithImageBuffer(
        kCFAllocatorDefault, imageBuffer, m_outputDescription, &timingInfo, &sampleBuffer);
    if (wrapStatus != noErr) {
        check(wrapStatus, "CMSampleBufferCreateReadyWithImageBuffer");
        m_decodeStatus = wrapStatus;
        return;
    }

    // NOTE(review): sampleBuffer is a +1 reference from a Create call and is not
    // released here; confirm CMSampleBufferSample adopts it rather than retaining it
    // again, otherwise every output frame leaks.
    auto sample = std::make_shared<CMSampleBufferSample>(sampleBuffer, presentationTimeStamp, presentationDuration);

    std::unique_lock<std::mutex> lock(m_outputMutex);
    m_outputBuffers.push(sample);
}

/**
 * Releases a format description reference (if any) and nulls the caller's
 * variable, so a second call on the same member is a harmless no-op.
 *
 * @param formatDescription reference to release; set to null afterwards
 */
void VTDecoder::releaseFormatDescription(CMVideoFormatDescriptionRef& formatDescription)
{
    if (formatDescription == nullptr) {
        return;
    }

    CFRelease(formatDescription);
    formatDescription = nullptr;
}

/**
 * Invalidates and releases the current decompression session, if there is one,
 * and clears the member so later calls see "no session".
 */
void VTDecoder::releaseSession()
{
    if (m_decompressionSession == nullptr) {
        return;
    }

    // explicit invalidation gives a deterministic teardown before dropping our reference
    VTDecompressionSessionInvalidate(m_decompressionSession);
    CFRelease(m_decompressionSession);
    m_decompressionSession = nullptr;
}

/**
 * Static C callback registered with VTDecompressionSessionCreate in configure().
 * The session-level refcon is the owning VTDecoder; the call is forwarded to its
 * decompressionOutput() member. A null refcon is ignored.
 */
void VTDecoder::decompressionOutputCallback(void* decompressionOutputRefCon,
    void* sourceFrameRefCon,
    OSStatus status,
    VTDecodeInfoFlags infoFlags,
    CVImageBufferRef imageBuffer,
    CMTime presentationTimeStamp,
    CMTime presentationDuration)
{
    // decode() never supplies a per-frame refcon, so it carries nothing useful
    (void)sourceFrameRefCon;

    VTDecoder* decoder = static_cast<VTDecoder*>(decompressionOutputRefCon);
    if (decoder != nullptr) {
        decoder->decompressionOutput(status, infoFlags, imageBuffer, presentationTimeStamp, presentationDuration);
    }
}
}
