#include "pch.h"
#include "VideoDecoder.hpp"
#include "media/avc/avcutil.hpp"
#include "playercore/platform/windows/WindowsPlatform.hpp"
#include "VideoSample.hpp"
#include "debug/trace.hpp"
#include "checkhr.hpp"

using namespace twitch;
using namespace twitch::windows;
using namespace Microsoft;
using namespace Microsoft::WRL;

// Rounds numToRound up (towards +infinity) to the nearest multiple of |multiple|.
//
// @param numToRound value to round; may be negative
// @param multiple   rounding step; its sign is ignored, and 0 means "no rounding"
// @return the smallest multiple of |multiple| that is >= numToRound,
//         or numToRound itself when multiple is 0
int roundUp(int numToRound, int multiple)
{
    if (multiple == 0)
        return numToRound;

    const int step = (multiple < 0) ? -multiple : multiple;
    const int remainder = numToRound % step;
    if (remainder == 0)
        return numToRound;

    // operator% truncates towards zero, so a negative value yields a negative
    // remainder; dropping that remainder already moves towards +infinity.
    if (remainder < 0)
        return numToRound - remainder;

    return numToRound + (step - remainder);
}

// Creates the H.264 decoder MFT and tries to prepare it for D3D11 hardware decoding.
//
// Every step of the hardware setup that fails switches the instance to software
// mode (m_softwareDecoder = true) and aborts the rest of the hardware setup.
//
// @param device           D3D11 device; asserted to be non-null
// @param dxgiManager      DXGI device manager handed to the MFT; asserted to be non-null
// @param isRunningOnXBox  when false, software decoding is forced
VideoDecoder::VideoDecoder(Microsoft::WRL::ComPtr<ID3D11Device> device, Microsoft::WRL::ComPtr<IMFDXGIDeviceManager> dxgiManager, bool isRunningOnXBox)
    : m_d3dDevice(device)
    , m_dxgiManager(dxgiManager)
    , m_isRunningOnXBox(isRunningOnXBox)
{
    assert(m_d3dDevice.Get());
    assert(m_dxgiManager.Get());

    UINT32 directX11supported = 0;

    if (SUCCEEDED(CoCreateInstance(CLSID_MSH264DecoderMFT, nullptr, CLSCTX_INPROC_SERVER, IID_IMFTransform, (void**)&m_decoderTransform))) {
        // The attribute store is only touched when GetAttributes actually handed one out;
        // the ComPtr releases it when it goes out of scope.
        ComPtr<IMFAttributes> attributes;
        if (SUCCEEDED(m_decoderTransform->GetAttributes(attributes.GetAddressOf())) && attributes.Get() != nullptr) {
            attributes->GetUINT32(MF_SA_D3D11_AWARE, &directX11supported);
            attributes->SetUINT32(MF_SA_MINIMUM_OUTPUT_SAMPLE_COUNT_PROGRESSIVE, MaxWorkerSamples);
            attributes->SetUINT32(MF_SA_MINIMUM_OUTPUT_SAMPLE_COUNT, MaxWorkerSamples);
        }
    }

    m_softwareDecoder = (directX11supported == 0);

    // Force software decoding on Windows desktop since the behavior behind is different and it runs fine on this platform.
    if (!m_isRunningOnXBox) {
        m_softwareDecoder = true;
    }

    ComPtr<ID3D11DeviceContext> immediateContext;
    m_d3dDevice->GetImmediateContext(&immediateContext);
    immediateContext.CopyTo(m_videoContext.GetAddressOf());

    // Everything below talks to the transform; bail out instead of dereferencing
    // a null pointer when the MFT could not be created.
    if (!m_decoderTransform) {
        TRACE_ERROR("CoCreateInstance(CLSID_MSH264DecoderMFT) failed - no decoder transform available");
        m_softwareDecoder = true;
        return;
    }

    HRESULT hr = m_decoderTransform->ProcessMessage(MFT_MESSAGE_SET_D3D_MANAGER, ULONG_PTR(m_dxgiManager.Get()));
    if (fallbackInSoftwareModeIfFailed(hr, "MFT_MESSAGE_SET_D3D_MANAGER")) {
        return;
    }

    HANDLE handle = nullptr;
    hr = m_dxgiManager->OpenDeviceHandle(&handle);
    if (fallbackInSoftwareModeIfFailed(hr, "OpenDeviceHandle")) {
        return;
    }

    hr = m_dxgiManager->GetVideoService(handle, IID_ID3D11VideoDevice, (void**)&m_videoService);
    if (fallbackInSoftwareModeIfFailed(hr, "GetVideoService")) {
        m_dxgiManager->CloseDeviceHandle(handle);
        return;
    }

    static const GUID DXVA2_ModeH264_E = {
        0x1b81be68, 0xa0c7, 0x11d3,{ 0xb9, 0x84, 0x00, 0xc0, 0x4f, 0x2e, 0x73, 0xc5 }
    };

    static const GUID DXVA2_Intel_ModeH264_E = {
        0x604F8E68, 0x4951, 0x4c54,{ 0x88, 0xFE, 0xAB, 0xD2, 0x5C, 0x15, 0xB3, 0xD6 }
    };

    // Pick the first decoder profile that matches one of the accepted H.264 profiles.
    bool found = false;
    UINT profileCount = m_videoService->GetVideoDecoderProfileCount();
    for (UINT i = 0; i < profileCount; i++) {
        GUID id;
        hr = m_videoService->GetVideoDecoderProfile(i, &id);
        if (FAILED(hr)) {
            continue;
        }

        if (id == D3D11_DECODER_PROFILE_H264_VLD_NOFGT || id == DXVA2_ModeH264_E || id == DXVA2_Intel_ModeH264_E) {
            m_decoderGUID = id;
            found = true;
            break;
        }
    }

    if (!found) {
        // No usable profile: stop here rather than probing with a GUID that was never selected.
        TRACE_ERROR("Profile not found");
        m_softwareDecoder = true;
        m_dxgiManager->CloseDeviceHandle(handle);
        return;
    }

    BOOL nv12Support = false;
    hr = m_videoService->CheckVideoDecoderFormat(&m_decoderGUID, DXGI_FORMAT_NV12, &nv12Support);
    if (fallbackInSoftwareModeIfFailed(hr, "CheckVideoDecoderFormat")) {
        m_dxgiManager->CloseDeviceHandle(handle);
        return;
    }

    if (!nv12Support) {
        // The rest of the setup requests NV12 output, so it cannot work without it.
        TRACE_ERROR("NV12 output is not supported by the selected decoder profile");
        m_softwareDecoder = true;
        m_dxgiManager->CloseDeviceHandle(handle);
        return;
    }

    hr = m_dxgiManager->CloseDeviceHandle(handle);
    if (fallbackInSoftwareModeIfFailed(hr, "CloseDeviceHandle")) {
        return;
    }

    D3D11_VIDEO_DECODER_DESC desc = {};
    desc.Guid = m_decoderGUID;
    desc.SampleWidth = MaxDecodeWidth;
    desc.SampleHeight = MaxDecodeHeight;
    desc.OutputFormat = DXGI_FORMAT_NV12;

    UINT configCount = 0;
    hr = m_videoService->GetVideoDecoderConfigCount(&desc, &configCount);
    if (fallbackInSoftwareModeIfFailed(hr, "GetVideoDecoderConfigCount")) {
        return;
    }

    Microsoft::WRL::ComPtr<ID3D11VideoDecoder> decoder;

    // Use the first configuration for which a decoder can actually be created.
    for (UINT i = 0; i < configCount; i++) {
        D3D11_VIDEO_DECODER_CONFIG config;
        hr = m_videoService->GetVideoDecoderConfig(&desc, i, &config);
        if (SUCCEEDED(hr)) {
            hr = m_videoService->CreateVideoDecoder(&desc, &config, &decoder);
            if (SUCCEEDED(hr) && decoder) {
                break;
            }
        }
    }

    // A missing decoder must be reported as a failing HRESULT: converting the boolean
    // "decoder == nullptr" gives 1 (S_FALSE), which FAILED() treats as success.
    const HRESULT createDecoderResult = decoder ? S_OK : (FAILED(hr) ? hr : E_FAIL);
    if (fallbackInSoftwareModeIfFailed(createDecoderResult, "CreateVideoDecoder")) {
        return;
    }

    m_decoder = decoder;

    Microsoft::WRL::ComPtr<IMFMediaType> mediaType;
    hr = MFCreateMediaType(mediaType.GetAddressOf());
    if (fallbackInSoftwareModeIfFailed(hr, "MFCreateMediaType")) {
        return;
    }

    hr = mediaType->SetGUID(MF_MT_MAJOR_TYPE, MFMediaType_Video);
    if (fallbackInSoftwareModeIfFailed(hr, "MF_MT_MAJOR_TYPE")) {
        return;
    }

    hr = mediaType->SetGUID(MF_MT_SUBTYPE, MFVideoFormat_H264);
    if (fallbackInSoftwareModeIfFailed(hr, "MF_MT_SUBTYPE")) {
        return;
    }

    hr = mediaType->SetUINT32(MF_MT_INTERLACE_MODE, MFVideoInterlace_MixedInterlaceOrProgressive);
    if (fallbackInSoftwareModeIfFailed(hr, "MF_MT_INTERLACE_MODE")) {
        return;
    }

    hr = m_decoderTransform->SetInputType(0, mediaType.Get(), 0);
    if (fallbackInSoftwareModeIfFailed(hr, "SetInputType")) {
        return;
    }

    // Walk the available output types and select the first NV12 one.
    Microsoft::WRL::ComPtr<IMFMediaType> outputMediaType;
    for (uint32_t i = 0; SUCCEEDED(m_decoderTransform->GetOutputAvailableType(0, i, outputMediaType.GetAddressOf())); ++i) {
        GUID outSubtype = { 0 };
        hr = outputMediaType->GetGUID(MF_MT_SUBTYPE, &outSubtype);
        if (fallbackInSoftwareModeIfFailed(hr, "MF_MT_SUBTYPE")) {
            return;
        }

        if (outSubtype == MFVideoFormat_NV12) {
            hr = m_decoderTransform->SetOutputType(0, outputMediaType.Get(), 0);
            if (fallbackInSoftwareModeIfFailed(hr, "SetOutputType")) {
                return;
            }
            break;
        }
        // Release the rejected type before GetAddressOf() is reused in the loop condition.
        outputMediaType.Reset();
    }

    Microsoft::WRL::ComPtr<IMFAttributes> outAttributes;
    hr = m_decoderTransform->GetOutputStreamAttributes(0, outAttributes.GetAddressOf());
    if (fallbackInSoftwareModeIfFailed(hr, "GetOutputStreamAttributes")) {
        return;
    }

    // The result of this call is not checked.
    outAttributes->SetUINT32(MF_SA_D3D11_BINDFLAGS, D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_DECODER);

    hr = m_decoderTransform->GetInputStreamInfo(0, &m_inputStreamInfo);
    if (fallbackInSoftwareModeIfFailed(hr, "GetInputStreamInfo")) {
        return;
    }

    hr = m_decoderTransform->GetOutputStreamInfo(0, &m_outputStreamInfo);
    if (fallbackInSoftwareModeIfFailed(hr, "GetOutputStreamInfo")) {
        return;
    }
}

// Validates the input format, fills in the output format and configures the decoder
// for the active mode (software or hardware).
//
// @param input   source format; must carry Video_Width and Video_Height
// @param output  receives the AVC level/profile and the (unpadded) input resolution
// @return ErrorInvalidParameter when width or height is missing, otherwise the result
//         of configureSoftware() / configureHardware()
MediaResult VideoDecoder::configure(const MediaFormat& input, MediaFormat& output)
{
    TRACE_DEBUG("VideoDecoder::configure - %s", m_softwareDecoder ? "software mode" : "hardware mode");

    const bool dimensionsKnown = input.hasInt(MediaFormat::Video_Width) && input.hasInt(MediaFormat::Video_Height);
    if (!dimensionsKnown) {
        TRACE_ERROR("VideoDecoder::configure - Couldn't figure out output format");
        return MediaResult::ErrorInvalidParameter;
    }

    // Level and profile are passed through untouched.
    const int level = input.getInt(MediaFormat::Video_AVC_Level);
    const int profile = input.getInt(MediaFormat::Video_AVC_Profile);
    output.setInt(MediaFormat::Video_AVC_Level, level);
    output.setInt(MediaFormat::Video_AVC_Profile, profile);

    const int sourceWidth = input.getInt(MediaFormat::Video_Width);
    const int sourceHeight = input.getInt(MediaFormat::Video_Height);

    // Sample dimensions are padded up to a multiple of TexturePaddingSize.
    const int paddedWidth = roundUp(sourceWidth, TexturePaddingSize);
    const int paddedHeight = roundUp(sourceHeight, TexturePaddingSize);

    const bool sameSize = (sourceWidth == paddedWidth) && (sourceHeight == paddedHeight);
    if (!sameSize) {
        TRACE_INFO("VideoDecoder::configure - Output resolution will be different from Input. %dx%d ==> %dx%d", sourceWidth, sourceHeight, paddedWidth, paddedHeight);
    }

    // The reported output format keeps the original resolution; only the
    // internal sample size uses the padded values.
    output.setInt(MediaFormat::Video_Width, sourceWidth);
    output.setInt(MediaFormat::Video_Height, sourceHeight);

    m_currentSampleWidth = paddedWidth;
    m_currentSampleHeight = paddedHeight;

    if (m_softwareDecoder) {
        return configureSoftware(input);
    }

    return configureHardware();
}

// Releases the event collections that an MFT might allocate in IMFTransform::ProcessOutput().
//
// @param cOutputBuffers number of entries in pBuffers to visit
// @param pBuffers       output buffer array; each non-null pEvents is released and set to nullptr
void releaseEventCollection(DWORD cOutputBuffers, MFT_OUTPUT_DATA_BUFFER* pBuffers)
{
    DWORD index = 0;
    while (index < cOutputBuffers) {
        MFT_OUTPUT_DATA_BUFFER& entry = pBuffers[index++];
        if (entry.pEvents == nullptr) {
            continue;
        }

        entry.pEvents->Release();
        entry.pEvents = nullptr;
    }
}

// Makes the frame buffer at the current index the current sample, then advances
// the index, wrapping back to 0 once it reaches FrameBuffersCount.
void VideoDecoder::requestNewSample()
{
    const auto& nextBuffer = m_frameBuffers[m_currentFrameBufferIndex];
    m_currentSample = nextBuffer.sample;

    ++m_currentFrameBufferIndex;
    if (m_currentFrameBufferIndex >= FrameBuffersCount) {
        m_currentFrameBufferIndex = 0;
    }
}

// Wraps the given sample in a VideoSample, queues it in m_outputSamples and
// moves on to the next frame buffer.
//
// @param sample sample to wrap and queue
void VideoDecoder::associateOutput(IMFSample* sample)
{
    auto wrappedSample = std::make_shared<VideoSample>(sample);
    m_outputSamples.push_back(wrappedSample);

    requestNewSample();
}

// Configures the decoder MFT for software decoding.
//
// Returns early when an input media type already exists with the same frame size.
// Otherwise rebuilds the NV12 frame buffers, detaches the D3D manager from the MFT,
// sets fresh input/output media types and restarts streaming.
//
// @param input source format providing Video_Width and Video_Height
// @return MediaResult::Ok on success, ErrorNotSupported when the MFT does not accept
//         data; CHECK_HR handles any failing HRESULT
MediaResult VideoDecoder::configureSoftware(const MediaFormat& input)
{
    const int requestedWidth = input.getInt(MediaFormat::Video_Width);
    const int requestedHeight = input.getInt(MediaFormat::Video_Height);

    // An existing input media type with the same frame size means there is nothing to redo
    // (input and output media types are assumed to share one frame size).
    if (m_inputMediaType) {
        UINT32 configuredWidth = 0;
        UINT32 configuredHeight = 0;
        CHECK_HR(MFGetAttributeSize(m_inputMediaType.Get(), MF_MT_FRAME_SIZE, &configuredWidth, &configuredHeight), "VideoDecoder::configure(): Failed to get frame size of input media type");

        const bool sameHeight = (requestedHeight == static_cast<int>(configuredHeight));
        const bool sameWidth = (requestedWidth == static_cast<int>(configuredWidth));
        if (sameHeight && sameWidth) {
            TRACE_INFO("VideoDecoder::configure - Output resolution has not changed. No need to reconfigure.");
            return MediaResult::Ok;
        }
    }

    // Every frame buffer gets a texture with the same description, sized to the padded sample size.
    D3D11_TEXTURE2D_DESC textureDesc = {};
    textureDesc.Width = m_currentSampleWidth;
    textureDesc.Height = m_currentSampleHeight;
    textureDesc.MipLevels = 1;
    textureDesc.ArraySize = 1;
    textureDesc.Format = DXGI_FORMAT_NV12;
    textureDesc.SampleDesc.Count = 1;
    textureDesc.SampleDesc.Quality = 0;
    textureDesc.Usage = D3D11_USAGE_DEFAULT;
    textureDesc.BindFlags = D3D11_BIND_DECODER;
    textureDesc.CPUAccessFlags = 0;
    textureDesc.MiscFlags = 0;

    for (int bufferIndex = 0; bufferIndex < FrameBuffersCount; ++bufferIndex) {
        auto& frameBuffer = m_frameBuffers[bufferIndex];

        // Reuse an existing sample (after emptying it) or create a new one.
        if (frameBuffer.sample == nullptr) {
            CHECK_HR(MFCreateSample(&frameBuffer.sample), "VideoDecoder - Failed to create MF sample.");
        } else {
            frameBuffer.sample->RemoveAllBuffers();
        }

        Microsoft::WRL::ComPtr<ID3D11Texture2D> nv12Texture;
        CHECK_HR(m_d3dDevice->CreateTexture2D(&textureDesc, nullptr, &nv12Texture), "VideoDecoder - Failed to create texture.");

        Microsoft::WRL::ComPtr<IMFMediaBuffer> surfaceBuffer;
        CHECK_HR(MFCreateDXGISurfaceBuffer(__uuidof(ID3D11Texture2D), nv12Texture.Get(), 0, false, &surfaceBuffer), "VideoDecoder - Failed to create output memory buffer.");

        CHECK_HR(frameBuffer.sample->AddBuffer(surfaceBuffer.Get()), "VideoDecoder - Failed to add sample to buffer.");
    }

    requestNewSample();

    // Passing 0 as the manager detaches the D3D manager from the MFT.
    CHECK_HR(m_decoderTransform->ProcessMessage(MFT_MESSAGE_SET_D3D_MANAGER, 0), "MFT_MESSAGE_SET_D3D_MANAGER failed");

    const UINT32 frameWidth = static_cast<UINT32>(requestedWidth);
    const UINT32 frameHeight = static_cast<UINT32>(requestedHeight);

    // Input type: H.264 video at the requested frame size.
    CHECK_HR(MFCreateMediaType(&m_inputMediaType), "Failed to create the input MediaType");
    CHECK_HR(m_inputMediaType->SetGUID(MF_MT_MAJOR_TYPE, MFMediaType_Video), "Failed to set MF_MT_MAJOR_TYPE input type.");
    CHECK_HR(m_inputMediaType->SetGUID(MF_MT_SUBTYPE, MFVideoFormat_H264), "Failed to set MF_MT_SUBTYPE input type.");
    CHECK_HR(m_inputMediaType->SetUINT32(MF_MT_ALL_SAMPLES_INDEPENDENT, true), "Failed to set MF_MT_ALL_SAMPLES_INDEPENDENT input type");
    CHECK_HR(m_inputMediaType->SetUINT32(MF_MT_INTERLACE_MODE, MFVideoInterlace_MixedInterlaceOrProgressive), "Failed to set MF_MT_INTERLACE_MODE input type.");
    CHECK_HR(MFSetAttributeSize(m_inputMediaType.Get(), MF_MT_FRAME_SIZE, frameWidth, frameHeight), "Failed to set frame size on H.264 MFT input type.");
    // NOTE(review): the pixel aspect ratio is set to width:height of the frame here;
    // confirm this is intended rather than a per-pixel ratio such as 1:1.
    CHECK_HR(MFSetAttributeRatio(m_inputMediaType.Get(), MF_MT_PIXEL_ASPECT_RATIO, frameWidth, frameHeight), "Failed to set aspect ratio on H.264 MFT input type.");
    CHECK_HR(m_decoderTransform->SetInputType(0, m_inputMediaType.Get(), 0), "Failed to set input media type on H.264 decoder MFT.");

    // Output type: a copy of the input type with the subtype switched to NV12.
    CHECK_HR(MFCreateMediaType(&m_outputMediaType), "Failed to create the output MediaType");
    CHECK_HR(m_inputMediaType->CopyAllItems(m_outputMediaType.Get()), "Failed to copy the input MediaType to output MediaType");
    CHECK_HR(m_outputMediaType->SetGUID(MF_MT_MAJOR_TYPE, MFMediaType_Video), "Failed to set MF_MT_MAJOR_TYPE output type.");
    CHECK_HR(m_outputMediaType->SetGUID(MF_MT_SUBTYPE, MFVideoFormat_NV12), "Failed to set MF_MT_SUBTYPE out output.");

    // NOTE: Even though we know that the output sample size can change, we
    //      can't manually set the MF_MT_FRAME_SIZE attributes here because
    //      IMFTransform::SetOutputType() fails.
    CHECK_HR(m_decoderTransform->SetOutputType(0, m_outputMediaType.Get(), 0), "Failed to set output media type on H.264 decoder MFT.");

    DWORD inputStatus = 0;
    CHECK_HR(m_decoderTransform->GetInputStatus(0, &inputStatus), "Failed to get input status from H.264 decoder MFT.");

    if (inputStatus != MFT_INPUT_STATUS_ACCEPT_DATA) {
        TRACE_ERROR("H.264 decoder MFT is not accepting data.");
        return MediaResult(MediaResult::ErrorNotSupported, inputStatus);
    }

    // Drop anything still queued in the MFT, then announce the new stream.
    CHECK_HR(m_decoderTransform->ProcessMessage(MFT_MESSAGE_COMMAND_FLUSH, 0), "Failed to process FLUSH command on H.264 decoder MFT.");
    CHECK_HR(m_decoderTransform->ProcessMessage(MFT_MESSAGE_NOTIFY_BEGIN_STREAMING, 0), "Failed to process BEGIN_STREAMING command on H.264 decoder MFT.");
    CHECK_HR(m_decoderTransform->ProcessMessage(MFT_MESSAGE_NOTIFY_START_OF_STREAM, 0), "Failed to process START_OF_STREAM command on H.264 decoder MFT.");

    return MediaResult::Ok;
}

// Copies one encoded buffer into a new IMFSample and submits it to the decoder MFT
// (software path).
//
// @param input encoded sample: payload bytes, presentation time, duration, decode-only flag
// @return MediaResult::Ok when the sample was accepted or the MFT reported
//         MF_E_NOTACCEPTING; ErrorInvalidData for any other ProcessInput failure
MediaResult VideoDecoder::processInputSoftware(const twitch::MediaSampleBuffer& input)
{
    ComPtr<IMFSample> sample;
    CHECK_HR(MFCreateSample(&sample), "VideoDecoder::processInput - Cannot create input sample for processing");

    // Media Foundation timestamps are expressed in 100-nanosecond units.
    const LONGLONG presentationTime100ns = input.presentationTime.nanoseconds().count() / 100;
    const LONGLONG duration100ns = input.duration.nanoseconds().count() / 100;

    CHECK_HR(sample->SetSampleTime(presentationTime100ns), "VideoDecoder::processInput - Error setting the input video sample time.");
    CHECK_HR(sample->SetSampleDuration(duration100ns), "VideoDecoder::processInput - Error setting input video sample duration.");

    const DWORD payloadSize = static_cast<DWORD>(input.buffer.size());

    ComPtr<IMFMediaBuffer> payloadBuffer;
    CHECK_HR(MFCreateMemoryBuffer(payloadSize, &payloadBuffer), "VideoDecoder::processInput - Failed to create memory buffer.");
    CHECK_HR(sample->AddBuffer(payloadBuffer.Get()), "Failed to add buffer to re-constructed sample.");

    // Copy the encoded bytes into the media buffer while it is locked.
    BYTE* destination = nullptr;
    DWORD maxLength = 0;
    DWORD currentLength = 0;
    CHECK_HR(payloadBuffer->Lock(&destination, &maxLength, &currentLength), "VideoDecoder::processInput - Error locking input buffer.");
    memcpy(destination, input.buffer.data(), input.buffer.size());
    CHECK_HR(payloadBuffer->Unlock(), "VideoDecoder::processInput - Error unlocking input buffer.");
    payloadBuffer->SetCurrentLength(payloadSize);

    // For decode-only samples the lower output bound is the sample's own time.
    const LONGLONG lowerBound = input.isDecodeOnly ? presentationTime100ns : MFT_OUTPUT_BOUND_LOWER_UNBOUNDED;
    m_decoderTransform->SetOutputBounds(lowerBound, MFT_OUTPUT_BOUND_UPPER_UNBOUNDED);

    const HRESULT submitResult = m_decoderTransform->ProcessInput(0, sample.Get(), 0);

    if (submitResult != MF_E_NOTACCEPTING && FAILED(submitResult)) {
        WindowsPlatform::hError("VideoDecoder::processInput - IMFTransform::ProcessInput failed", submitResult);
        return MediaResult(MediaResult::ErrorInvalidData, submitResult);
    }

    if (submitResult == MF_E_NOTACCEPTING) {
        TRACE_DEBUG("VideoDecoder: MFTTransform is not accepting samples anymore");
    }

    return MediaResult::Ok;
}

// Pulls at most one decoded sample out of the decoder MFT (software path).
//
// @param hr receives the HRESULT of the last IMFTransform::ProcessOutput call
// @param pt forwarded to handleOutputDataBuffer() to decide whether the output is kept
// @return MediaResult::Ok when a sample was handled or more input is needed;
//         ErrorInvalidData when ProcessOutput fails for any other reason
MediaResult VideoDecoder::processOutputSoftware(HRESULT& hr, ProcessType pt)
{
    MFT_OUTPUT_STREAM_INFO outputStreamInfo;
    CHECK_HR(m_decoderTransform->GetOutputStreamInfo(0, &outputStreamInfo), "VideoDecoder - Can't figure out the output stream info type");

    // Will the Transform allocate the output sample for us ?
    bool allocateForUs = (outputStreamInfo.dwFlags & MFT_OUTPUT_STREAM_PROVIDES_SAMPLES) || (outputStreamInfo.dwFlags & MFT_OUTPUT_STREAM_CAN_PROVIDE_SAMPLES);

    // Check if we got any output
    MFT_OUTPUT_DATA_BUFFER outputDataBuffer = {};
    outputDataBuffer.dwStreamID = 0;
    outputDataBuffer.pSample = nullptr;
    outputDataBuffer.dwStatus = 0;
    outputDataBuffer.pEvents = nullptr;

    if (!allocateForUs) {
        outputDataBuffer.pSample = m_currentSample;
    }

    DWORD processOutputStatus = 0;
    hr = m_decoderTransform->ProcessOutput(0, 1, &outputDataBuffer, &processOutputStatus);

    // Release any event collection here so that every return path below is covered,
    // not only the success path. pEvents is nulled, so a later release is a no-op.
    releaseEventCollection(1, &outputDataBuffer);

    if (hr == MF_E_TRANSFORM_NEED_MORE_INPUT) {
        return MediaResult::Ok;
    } else if (hr == MF_E_TRANSFORM_STREAM_CHANGE) {
        assert(processOutputStatus == MFT_PROCESS_OUTPUT_STATUS_NEW_STREAMS);

        // Set the output type to the preferred available stream and try to re-process the output
        // For reference: https://msdn.microsoft.com/en-us/library/windows/desktop/ms699875(v=vs.85).aspx
        CHECK_HR(m_decoderTransform->GetOutputAvailableType(0, 0, &m_outputMediaType), "VideoDecoder - Failed to get available input types");
        CHECK_HR(m_decoderTransform->SetOutputType(0, m_outputMediaType.Get(), 0), "VideoDecoder - Failed to set output media type");

        // Retry to process output with new output MediaType
        return processOutputSoftware(hr, pt);
    } else if (FAILED(hr)) {
        WindowsPlatform::hError("VideoDecoder - ProcessOutput failed", hr);
        return MediaResult(MediaResult::ErrorInvalidData, hr);
    }

    if (outputDataBuffer.dwStatus & MFT_OUTPUT_DATA_BUFFER_INCOMPLETE) {
        TRACE_WARN_ONCE("VideoDecoder - ProcessOutput set status to MFT_OUTPUT_DATA_BUFFER_INCOMPLETE - we ignore subsequent samples. Your output may be missing frames.");
    }

    handleOutputDataBuffer(outputDataBuffer, pt);

    return MediaResult::Ok;
}

// Post-processes the single output buffer filled by IMFTransform::ProcessOutput():
// releases its event collection, traces status flags and either queues or drops the sample.
//
// @param outputDataBuffer the one output buffer passed to ProcessOutput
// @param pt               when ProcessType::Discard is set the sample is dropped,
//                         otherwise it is queued via associateOutput()
void VideoDecoder::handleOutputDataBuffer(MFT_OUTPUT_DATA_BUFFER& outputDataBuffer, ProcessType pt)
{
    // Don't care about events, but they must be released according to the docs.
    // Exactly one buffer is handled here, so the count must be 1 (a count of 0 releases nothing).
    releaseEventCollection(1, &outputDataBuffer);

    if ((outputDataBuffer.dwStatus & MFT_OUTPUT_DATA_BUFFER_STREAM_END)) {
        TRACE_DEBUG("VideoDecoder - Signaled MFT_OUTPUT_DATA_BUFFER_STREAM_END");
    }

    // NOTE(review): MFT_PROCESS_OUTPUT_STATUS_NEW_STREAMS looks like a ProcessOutput status
    // flag rather than a data-buffer flag; confirm this check is meaningful on dwStatus.
    if ((outputDataBuffer.dwStatus & MFT_PROCESS_OUTPUT_STATUS_NEW_STREAMS)) {
        TRACE_DEBUG("VideoDecoder - Signaled MFT_PROCESS_OUTPUT_STATUS_NEW_STREAMS");
    }

    if ((outputDataBuffer.dwStatus & MFT_OUTPUT_DATA_BUFFER_FORMAT_CHANGE)) {
        TRACE_DEBUG("VideoDecoder - Signaled MFT_OUTPUT_DATA_BUFFER_FORMAT_CHANGE");
    }

    if (!(pt & ProcessType::Discard)) {
        associateOutput(outputDataBuffer.pSample);
    }
    else
    {
        // Discarded output: just forget the pointer.
        outputDataBuffer.pSample = nullptr;
    }
}

// Submits one encoded sample to the decoder using the path of the active mode.
//
// @param input encoded sample to decode
// @return the result of processInputSoftware() or processInputHardware()
MediaResult VideoDecoder::decode(const twitch::MediaSampleBuffer& input)
{
    if (m_softwareDecoder) {
        return processInputSoftware(input);
    }

    return processInputHardware(input);
}

// Runs one output-processing pass for the active mode and reports whether any
// decoded samples are queued.
//
// @param hasOutput set to true when m_outputSamples is non-empty after the pass
// @return the result of the output-processing pass
MediaResult VideoDecoder::hasOutput(bool& hasOutput)
{
    HRESULT unusedHr;
    MediaResult passResult = m_softwareDecoder ? processOutputSoftware(unusedHr) : processOutputHardware();

    const bool queueIsEmpty = m_outputSamples.empty();
    hasOutput = !queueIsEmpty;

    return passResult;
}

// Pops the oldest queued decoded sample.
//
// @param output receives the front element of m_outputSamples
// @return MediaResult::Ok when a sample was returned, ErrorInvalidState when the queue is empty
MediaResult VideoDecoder::getOutput(std::shared_ptr<twitch::MediaSample>& output)
{
    if (!m_outputSamples.empty()) {
        output = m_outputSamples.front();
        m_outputSamples.pop_front();
        return MediaResult::Ok;
    }

    return MediaResult::ErrorInvalidState;
}

// Drains the decoder in software mode: sends the DRAIN command and keeps processing
// output until the MFT asks for more input or an error occurs. Does nothing in
// hardware mode.
//
// @return MediaResult::Ok in hardware mode; otherwise the result of the last
//         processOutputSoftware() call
MediaResult VideoDecoder::flush()
{
    if (m_softwareDecoder) {
        CHECK_HR(m_decoderTransform->ProcessMessage(MFT_MESSAGE_COMMAND_DRAIN, NULL), "Failed to process DRAIN command on H.264 decoder MFT.");
        HRESULT hr = S_OK;
        MediaResult outputResult;
        while ((outputResult = processOutputSoftware(hr)) == MediaResult::Ok) {
            // The MFT has nothing left to hand out once it requests more input.
            if (hr == MF_E_TRANSFORM_NEED_MORE_INPUT) {
                break;
            }
        }

        return outputResult;
    }

    return MediaResult::Ok;
}

// Sends the FLUSH command to the decoder MFT and drops every queued output sample.
//
// @return MediaResult::Ok on success; CHECK_HR handles a failing FLUSH
MediaResult VideoDecoder::reset()
{
    CHECK_HR(m_decoderTransform->ProcessMessage(MFT_MESSAGE_COMMAND_FLUSH, NULL), "Failed to process FLUSH command on H.264 decoder MFT.");
    m_outputSamples.clear();
    return MediaResult::Ok;
}

// Switches the decoder to software mode when hr is a failure code.
//
// @param hr      result of the operation that was just attempted
// @param message name of that operation, used in the error trace
// @return true when hr failed (software mode is now active), false otherwise
bool VideoDecoder::fallbackInSoftwareModeIfFailed(HRESULT hr, const char * message)
{
    if (SUCCEEDED(hr)) {
        return false;
    }

    TRACE_ERROR("%s failed with code %X", message, hr);
    m_softwareDecoder = true;
    return true;
}

// Hardware-mode configuration: notifies the decoder MFT that streaming begins
// and that a new stream starts.
//
// @return MediaResult::Ok on success; CHECK_HR handles a failing message
MediaResult VideoDecoder::configureHardware()
{
    // Neither notification carries a parameter.
    const ULONG_PTR noParameter = 0;

    CHECK_HR(m_decoderTransform->ProcessMessage(MFT_MESSAGE_NOTIFY_BEGIN_STREAMING, noParameter), "Failed to process BEGIN_STREAMING command on H.264 decoder MFT.");
    CHECK_HR(m_decoderTransform->ProcessMessage(MFT_MESSAGE_NOTIFY_START_OF_STREAM, noParameter), "Failed to process START_OF_STREAM command on H.264 decoder MFT.");

    return MediaResult::Ok;
}

// Copies one encoded buffer into a new IMFSample and submits it to the decoder MFT
// (hardware path), then runs one output-processing pass.
//
// When the MFT reports MF_E_NOTACCEPTING, pending output is processed once and the
// same sample is submitted again, so the input is not dropped while the MFT is
// merely waiting for its output to be collected.
//
// @param input encoded sample: payload bytes, presentation time, duration, decode-only flag
// @return MediaResult::Ok when the sample was accepted, or when the MFT still does not
//         accept input after the retry; ErrorInvalidData for any other ProcessInput failure
MediaResult VideoDecoder::processInputHardware(const twitch::MediaSampleBuffer& input)
{
    ComPtr<IMFSample> inputSample;
    ComPtr<IMFMediaBuffer> inputBuffer;

    CHECK_HR(MFCreateSample(&inputSample), "VideoDecoder::processInput - Cannot create input sample for processing");

    // Honor the alignment the MFT reported for its input stream, if any.
    if (m_inputStreamInfo.cbAlignment == 0) {
        CHECK_HR(MFCreateMemoryBuffer(static_cast<DWORD>(input.buffer.size()), &inputBuffer), "VideoDecoder::processInput - Failed to create memory buffer.");
    } else {
        CHECK_HR(MFCreateAlignedMemoryBuffer(static_cast<DWORD>(input.buffer.size()), m_inputStreamInfo.cbAlignment - 1, &inputBuffer),
                                                "VideoDecoder::processInput - Failed to create aligned memory buffer.");
    }

    CHECK_HR(inputSample->AddBuffer(inputBuffer.Get()), "Failed to add buffer to re-constructed sample.");
    inputBuffer->SetCurrentLength(0);

    // Sample time is in 100-nanoseconds unit. So 0.1 microsecond, or 0.0001 ms
    LONGLONG sampleTimeInHundredsNano = input.presentationTime.nanoseconds().count() / 100;
    LONGLONG sampleDurationInHundredsNano = input.duration.nanoseconds().count() / 100;

    CHECK_HR(inputSample->SetSampleTime(sampleTimeInHundredsNano), "VideoDecoder::processInput - Error setting the input video sample time.");
    CHECK_HR(inputSample->SetSampleDuration(sampleDurationInHundredsNano), "VideoDecoder::processInput - Error setting input video sample duration.");

    BYTE* inputByteBuffer = nullptr;
    DWORD inputBufferLength = 0;
    DWORD inputBufferMaxLength = 0;
    CHECK_HR(inputBuffer->Lock(&inputByteBuffer, &inputBufferMaxLength, &inputBufferLength), "VideoDecoder::processInput - Error locking input buffer.");
    memcpy(inputByteBuffer, input.buffer.data(), input.buffer.size());
    CHECK_HR(inputBuffer->Unlock(), "VideoDecoder::processInput - Error unlocking input buffer.");
    inputBuffer->SetCurrentLength(static_cast<DWORD>(input.buffer.size()));

    m_decoderTransform->SetOutputBounds(input.isDecodeOnly ? sampleTimeInHundredsNano : MFT_OUTPUT_BOUND_LOWER_UNBOUNDED, MFT_OUTPUT_BOUND_UPPER_UNBOUNDED);

    HRESULT hrInput = m_decoderTransform->ProcessInput(0, inputSample.Get(), 0);
    if (hrInput == MF_E_NOTACCEPTING) {
        // Process pending output to free the MFT, then resubmit the same sample once.
        processOutputHardware();
        TRACE_DEBUG("VideoDecoder: MFTTransform is not accepting samples anymore");

        hrInput = m_decoderTransform->ProcessInput(0, inputSample.Get(), 0);
        if (hrInput == MF_E_NOTACCEPTING) {
            // Still not accepting after the retry: report Ok without submitting.
            return MediaResult::Ok;
        }
    }

    if (FAILED(hrInput)) {
        WindowsPlatform::hError("VideoDecoder::processInput - IMFTransform::ProcessInput failed", hrInput);
        return MediaResult(MediaResult::ErrorInvalidData, hrInput);
    }

    processOutputHardware();

    return MediaResult::Ok;
}

// Pulls one decoded sample out of the decoder MFT (hardware path) and queues it.
//
// On MF_E_TRANSFORM_STREAM_CHANGE the first available output type is applied and
// ProcessOutput is retried, up to 8 retries.
//
// @return MediaResult::Ok when a sample was queued, more input is needed, or the
//         retries ran out; ErrorInvalidData for any other ProcessOutput failure
MediaResult VideoDecoder::processOutputHardware()
{
    Microsoft::WRL::ComPtr<IMFSample> decodedSample;
    int remainingRetries = 8;

    for (;;) {
        decodedSample.Reset();

        MFT_OUTPUT_DATA_BUFFER outputBuffer = { 0 };
        DWORD processStatus = 0;
        const HRESULT hr = m_decoderTransform->ProcessOutput(0, 1, &outputBuffer, &processStatus);

        // Events are not used, but the collection must be released.
        if (outputBuffer.pEvents != nullptr) {
            outputBuffer.pEvents->Release();
        }

        // Take over the reference on whatever sample came back.
        decodedSample.Attach(outputBuffer.pSample);

        if (SUCCEEDED(hr)) {
            break;
        }

        if (hr == MF_E_TRANSFORM_NEED_MORE_INPUT) {
            return MediaResult::Ok;
        }

        if (hr != MF_E_TRANSFORM_STREAM_CHANGE) {
            WindowsPlatform::hError("VideoDecoder - ProcessOutput failed", hr);
            return MediaResult(MediaResult::ErrorInvalidData, hr);
        }

        // Stream change: apply the first available output type before trying again.
        CHECK_HR(m_decoderTransform->GetOutputAvailableType(0, 0, &m_outputMediaType), "VideoDecoder - Failed to get available input types");
        CHECK_HR(m_decoderTransform->SetOutputType(0, m_outputMediaType.Get(), 0), "VideoDecoder - Failed to set output media type");

        if (remainingRetries-- <= 0) {
            return MediaResult::Ok;
        }
    }

    // Store the sample in the m_samples ring, wrapping the index at MaxOutputSamples.
    m_samples[m_currentSampleIndex] = decodedSample;
    ++m_currentSampleIndex;
    if (m_currentSampleIndex >= MaxOutputSamples) {
        m_currentSampleIndex = 0;
    }

    m_outputSamples.push_back(std::make_shared<VideoSample>(decodedSample.Get()));

    return MediaResult::Ok;
}
